Commit f3cdfdb

Added track_model_request for litellm models

1 parent: b4b271e

3 files changed: 12 additions & 1 deletion


sygra/core/models/lite_llm/azure_openai_model.py

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,7 @@ def __init__(self, model_config: dict[str, Any]) -> None:
     def _get_model_prefix(self) -> str:
         return "azure"
 
+    @track_model_request
     async def _generate_native_structured_output(
         self,
         input: ChatPromptValue,
@@ -68,6 +69,7 @@ async def _generate_native_structured_output(
                 api_version=self.api_version,
                 **all_params,
             )
+            self._extract_token_usage(completion)
             resp_text = completion.choices[0].model_dump()["message"]["content"]
             tool_calls = completion.choices[0].model_dump()["message"]["tool_calls"]
             # Check if the request was successful based on the response status
@@ -159,6 +161,7 @@ async def _generate_text(
                 api_version=self.api_version,
                 **self.generation_params,
             )
+            self._extract_token_usage(completion)
             resp_text = completion.choices[0].model_dump()["message"]["content"]
             tool_calls = completion.choices[0].model_dump()["message"]["tool_calls"]
         except openai.RateLimitError as e:
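
The track_model_request decorator itself is defined elsewhere in sygra and does not appear in this diff. As a rough illustration only, a minimal async-aware version could look like the sketch below; the counter names and timing logic are assumptions, not the actual sygra implementation.

# Hypothetical sketch; sygra's real track_model_request is defined elsewhere.
import functools
import time

def track_model_request(func):
    """Wrap an async generation method and record basic per-instance stats."""
    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        start = time.monotonic()
        try:
            return await func(self, *args, **kwargs)
        finally:
            # Illustrative counters; a real implementation might forward
            # these to a metrics backend or structured logger instead.
            elapsed = time.monotonic() - start
            self._request_count = getattr(self, "_request_count", 0) + 1
            self._total_request_seconds = (
                getattr(self, "_total_request_seconds", 0.0) + elapsed
            )
    return wrapper

Applying it above _generate_native_structured_output means every structured-output request is counted and timed, litellm round trip included.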

sygra/core/models/lite_llm/openai_model.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@ def __init__(self, model_config: dict[str, Any]) -> None:
         self.model_config = model_config
         self.model_name = self.model_config.get("model", self.name())
 
+    @track_model_request
    async def _generate_native_structured_output(
         self,
         input: ChatPromptValue,
@@ -63,6 +64,7 @@ async def _generate_native_structured_output(
                 api_key=model_params.auth_token,
                 **all_params,
             )
+            self._extract_token_usage(completion)
             resp_text = completion.choices[0].model_dump()["message"]["content"]
             tool_calls = completion.choices[0].model_dump()["message"]["tool_calls"]
             # Check if the request was successful based on the response status
@@ -153,6 +155,7 @@ async def _generate_text(
                 api_key=model_params.auth_token,
                 **self.generation_params,
             )
+            self._extract_token_usage(completion)
             resp_text = completion.choices[0].model_dump()["message"]["content"]
             tool_calls = completion.choices[0].model_dump()["message"]["tool_calls"]
         except openai.RateLimitError as e:
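
_extract_token_usage is likewise defined outside this diff. A minimal sketch, assuming litellm's completion responses expose an OpenAI-compatible usage object:

# Hypothetical sketch; assumes litellm returns an OpenAI-compatible
# `usage` object (prompt_tokens / completion_tokens) on the response.
def _extract_token_usage(self, completion) -> None:
    """Accumulate token counts from a completion response, if present."""
    usage = getattr(completion, "usage", None)
    if usage is None:
        return
    self.prompt_tokens = getattr(self, "prompt_tokens", 0) + (
        getattr(usage, "prompt_tokens", 0) or 0
    )
    self.completion_tokens = getattr(self, "completion_tokens", 0) + (
        getattr(usage, "completion_tokens", 0) or 0
    )

Calling it immediately after acompletion returns, as the diff does, records usage before the response is unpacked, so the counts survive even if the subsequent parsing raises.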

sygra/core/models/lite_llm/vllm_model.py

Lines changed: 6 additions & 1 deletion
@@ -33,6 +33,7 @@ def _validate_completions_api_model_support(self) -> None:
     def _get_model_prefix(self) -> str:
         return "hosted_vllm"
 
+    @track_model_request
     async def _generate_native_structured_output(
         self,
         input: ChatPromptValue,
@@ -72,6 +73,7 @@ async def _generate_native_structured_output(
                     api_key=model_params.auth_token,
                     **extra_params,
                 )
+                self._extract_token_usage(completion)
                 resp_text = completion.choices[0].model_dump()["text"]
             else:
                 # Convert input to messages
@@ -84,6 +86,7 @@
                     api_key=model_params.auth_token,
                     **extra_params,
                 )
+                self._extract_token_usage(completion)
                 resp_text = completion.choices[0].model_dump()["message"]["content"]
                 tool_calls = completion.choices[0].model_dump()["message"]["tool_calls"]
 
@@ -157,6 +160,7 @@ async def _generate_response(
                     api_key=model_params.auth_token,
                     **self.generation_params,
                 )
+                self._extract_token_usage(completion)
                 resp_text = completion.choices[0].model_dump()["text"]
             else:
                 # Convert input to messages
@@ -169,6 +173,7 @@
                     api_key=model_params.auth_token,
                     **self.generation_params,
                 )
+                self._extract_token_usage(completion)
                 resp_text = completion.choices[0].model_dump()["message"]["content"]
                 tool_calls = completion.choices[0].model_dump()["message"]["tool_calls"]
             # TODO: Test rate limit handling for vllm
@@ -181,7 +186,7 @@
             logger.error(f"vLLM request failed with error: {e.message}")
             ret_code = e.status_code
         except Exception as x:
-            resp_text = f"{constants.ERROR_PREFIX} Http request failed {x}"
+            resp_text = f"{constants.ERROR_PREFIX} vLLM request failed {x}"
             logger.error(resp_text)
             rcode = self._get_status_from_body(x)
             if constants.ELEMAI_JOB_DOWN in resp_text or constants.CONNECTION_ERROR in resp_text:
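
Note that vllm_model.py inserts the extraction call in both the completions-API branch (choices[0]["text"]) and the chat branch of each method, which is why it accounts for six of the twelve added lines. Putting the pieces together, the pattern applied across all three litellm model classes can be exercised end to end with a stub response; DemoModel below is illustrative, not sygra code, and it reuses the two sketches above.

# Hypothetical end-to-end illustration of this commit's pattern; the stub
# response stands in for a real litellm completion.
import asyncio
from types import SimpleNamespace

class DemoModel:
    _extract_token_usage = _extract_token_usage  # method sketched above

    @track_model_request
    async def _generate_text(self, prompt: str) -> str:
        # Stand-in for `await litellm.acompletion(...)`; carries an
        # OpenAI-style usage block like the real responses do.
        completion = SimpleNamespace(
            usage=SimpleNamespace(prompt_tokens=12, completion_tokens=30)
        )
        self._extract_token_usage(completion)
        return "generated text"

async def main() -> None:
    model = DemoModel()
    await model._generate_text("hello")
    print(model._request_count, model.prompt_tokens, model.completion_tokens)
    # expected output: 1 12 30

asyncio.run(main())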
