Ajitg25 jaishankar101 commited on
Commit
68d22e0
·
1 Parent(s): 43879f6

Run all 3 tasks, clamp scores to (0,1) exclusive (#6)

Browse files

- Run all 3 tasks, clamp scores to (0,1) exclusive (6cbbcf933fbe4805ec252506be7344190b293085)


Co-authored-by: Jai Shankar K S <jaishankar101@users.noreply.huggingface.co>

Files changed (1) hide show
  1. inference.py +32 -21
inference.py CHANGED
@@ -31,8 +31,8 @@ API_KEY = os.environ["API_KEY"]
31
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
32
  SERVER_URL = os.getenv("OPENENV_SERVER_URL", "http://localhost:7860")
33
 
34
- TASK_NAME = os.getenv("SHOP_SKU_TASK", "easy")
35
  BENCHMARK = "shop_sku_manager"
 
36
  MAX_STEPS = 30
37
  TEMPERATURE = 0.3
38
  MAX_TOKENS = 200
@@ -66,6 +66,11 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
66
  )
67
 
68
 
 
 
 
 
 
69
  # ---------------------------------------------------------------------------
70
  # LLM-powered ordering agent
71
  # ---------------------------------------------------------------------------
@@ -140,41 +145,28 @@ def get_order(client: OpenAI, obs) -> OrderAction:
140
 
141
 
142
  # ---------------------------------------------------------------------------
143
- # Main episode loop
144
  # ---------------------------------------------------------------------------
145
 
146
- async def main() -> None:
147
- print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
148
- print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)
149
- print(f"[DEBUG] SERVER_URL={SERVER_URL}", flush=True)
150
- print(f"[DEBUG] API_KEY set={bool(API_KEY)}", flush=True)
151
-
152
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
153
-
154
- env = ShopSKUManagerEnv(base_url=SERVER_URL)
155
-
156
  rewards: List[float] = []
157
  steps_taken = 0
158
  score = 0.0
159
  success = False
160
 
161
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
162
 
 
163
  try:
164
- print("[DEBUG] Calling env.reset()...", flush=True)
165
  result = await env.reset()
166
- print(f"[DEBUG] env.reset() done. done={result.done}", flush=True)
167
 
168
  for step in range(1, MAX_STEPS + 1):
169
  if result.done:
170
- print(f"[DEBUG] Episode done at step {step}", flush=True)
171
  break
172
 
173
  obs = result.observation
174
- print(f"[DEBUG] Step {step}: calling LLM...", flush=True)
175
  action = get_order(client, obs)
176
  action_str = json.dumps(action.model_dump(), separators=(",", ":"))
177
- print(f"[DEBUG] Step {step}: LLM returned, calling env.step()...", flush=True)
178
 
179
  result = await env.step(action)
180
 
@@ -191,13 +183,16 @@ async def main() -> None:
191
  break
192
 
193
  if rewards:
194
- score = sum(rewards) / MAX_TOTAL_REWARD
195
- score = min(max(score, 0.0), 1.0)
 
 
196
  success = score >= SUCCESS_SCORE_THRESHOLD
197
 
198
  except Exception as e:
199
- print(f"[DEBUG] Episode error: {e}", flush=True)
200
  traceback.print_exc()
 
201
 
202
  finally:
203
  try:
@@ -207,5 +202,21 @@ async def main() -> None:
207
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
208
 
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if __name__ == "__main__":
211
  asyncio.run(main())
 
31
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
32
  SERVER_URL = os.getenv("OPENENV_SERVER_URL", "http://localhost:7860")
33
 
 
34
  BENCHMARK = "shop_sku_manager"
35
+ TASKS = ["easy", "medium", "hard"]
36
  MAX_STEPS = 30
37
  TEMPERATURE = 0.3
38
  MAX_TOKENS = 200
 
66
  )
67
 
68
 
69
+ def clamp_score(raw: float) -> float:
70
+ """Clamp score to strictly between 0 and 1 (exclusive)."""
71
+ return min(max(raw, 0.01), 0.99)
72
+
73
+
74
  # ---------------------------------------------------------------------------
75
  # LLM-powered ordering agent
76
  # ---------------------------------------------------------------------------
 
145
 
146
 
147
  # ---------------------------------------------------------------------------
148
+ # Run one task (one [START] / [END] block)
149
  # ---------------------------------------------------------------------------
150
 
151
+ async def run_task(client: OpenAI, task: str) -> None:
 
 
 
 
 
 
 
 
 
152
  rewards: List[float] = []
153
  steps_taken = 0
154
  score = 0.0
155
  success = False
156
 
157
+ log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
158
 
159
+ env = ShopSKUManagerEnv(base_url=SERVER_URL)
160
  try:
 
161
  result = await env.reset()
 
162
 
163
  for step in range(1, MAX_STEPS + 1):
164
  if result.done:
 
165
  break
166
 
167
  obs = result.observation
 
168
  action = get_order(client, obs)
169
  action_str = json.dumps(action.model_dump(), separators=(",", ":"))
 
170
 
171
  result = await env.step(action)
172
 
 
183
  break
184
 
185
  if rewards:
186
+ raw_score = sum(rewards) / MAX_TOTAL_REWARD
187
+ score = clamp_score(raw_score)
188
+ else:
189
+ score = 0.01
190
  success = score >= SUCCESS_SCORE_THRESHOLD
191
 
192
  except Exception as e:
193
+ print(f"[DEBUG] Task {task} error: {e}", flush=True)
194
  traceback.print_exc()
195
+ score = 0.01
196
 
197
  finally:
198
  try:
 
202
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
203
 
204
 
205
+ # ---------------------------------------------------------------------------
206
+ # Main — run all 3 tasks
207
+ # ---------------------------------------------------------------------------
208
+
209
+ async def main() -> None:
210
+ print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
211
+ print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)
212
+ print(f"[DEBUG] SERVER_URL={SERVER_URL}", flush=True)
213
+ print(f"[DEBUG] API_KEY set={bool(API_KEY)}", flush=True)
214
+
215
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
216
+
217
+ for task in TASKS:
218
+ await run_task(client, task)
219
+
220
+
221
  if __name__ == "__main__":
222
  asyncio.run(main())