Skip to content

Commit 68c1dd6

Browse files
updates from feedback
1 parent 03524a4 commit 68c1dd6

9 files changed

Lines changed: 131 additions & 141 deletions

File tree

python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import sys
77
from collections.abc import Sequence
8+
from contextlib import suppress
89
from typing import Any, ClassVar, Generic, TypedDict
910

1011
from agent_framework import (
@@ -13,8 +14,9 @@
1314
Embedding,
1415
EmbeddingGenerationOptions,
1516
GeneratedEmbeddings,
17+
UsageDetails,
18+
load_settings,
1619
)
17-
from agent_framework._settings import load_settings
1820
from agent_framework.observability import EmbeddingTelemetryLayer
1921
from azure.ai.inference.aio import EmbeddingsClient, ImageEmbeddingsClient
2022
from azure.ai.inference.models import ImageEmbeddingInput
@@ -162,8 +164,10 @@ def __init__(
162164

163165
async def close(self) -> None:
164166
"""Close the underlying SDK clients and release resources."""
165-
await self._text_client.close()
166-
await self._image_client.close()
167+
with suppress(Exception):
168+
await self._text_client.close()
169+
with suppress(Exception):
170+
await self._image_client.close()
167171

168172
async def __aenter__(self) -> RawAzureAIInferenceEmbeddingClient[AzureAIInferenceEmbeddingOptionsT]:
169173
"""Enter the async context manager."""
@@ -204,10 +208,6 @@ async def get_embeddings(
204208
return GeneratedEmbeddings([], options=options) # type: ignore[reportReturnType]
205209

206210
opts: dict[str, Any] = dict(options) if options else {}
207-
text_model = opts.get("model_id") or self.model_id
208-
image_model = opts.get("image_model_id") or self.image_model_id
209-
if not text_model:
210-
raise ValueError("model_id is required")
211211

212212
# Separate text and image inputs, tracking original indices.
213213
text_items: list[tuple[int, str]] = []
@@ -249,12 +249,13 @@ async def get_embeddings(
249249
common_kwargs["model_extras"] = extra_parameters
250250

251251
# Allocate results array.
252-
results: list[Embedding[list[float]] | None] = [None] * len(values)
253-
total_prompt_tokens = 0
254-
total_completion_tokens = 0
252+
embeddings: list[Embedding[list[float]] | None] = [None] * len(values)
253+
usage_details: UsageDetails = {"input_token_count": 0, "output_token_count": 0}
255254

256255
# Embed text inputs.
257256
if text_items:
257+
if not (text_model := opts.get("model_id") or self.model_id):
258+
raise ValueError("A model_id is required, either in the client or options, for text inputs.")
258259
text_inputs = [t for _, t in text_items]
259260
response = await self._text_client.embed(
260261
input=text_inputs,
@@ -263,18 +264,19 @@ async def get_embeddings(
263264
)
264265
for i, item in enumerate(response.data):
265266
original_idx = text_items[i][0]
266-
vector: list[float] = [float(v) for v in item.embedding]
267-
results[original_idx] = Embedding(
268-
vector=vector,
269-
dimensions=len(vector),
267+
embeddings[original_idx] = Embedding(
268+
vector=item.embedding,
269+
dimensions=len(item.embedding),
270270
model_id=response.model or text_model,
271271
)
272272
if response.usage:
273-
total_prompt_tokens += response.usage.prompt_tokens
274-
total_completion_tokens += getattr(response.usage, "completion_tokens", 0) or 0
273+
usage_details["input_token_count"] += response.usage.prompt_tokens
274+
usage_details["output_token_count"] += getattr(response.usage, "completion_tokens", 0) or 0
275275

276276
# Embed image inputs.
277277
if image_items:
278+
if not (image_model := opts.get("image_model_id") or self.image_model_id):
279+
raise ValueError("An image_model_id is required, either in the client or options, for image inputs.")
278280
image_inputs = [img for _, img in image_items]
279281
response = await self._image_client.embed(
280282
input=image_inputs,
@@ -283,25 +285,16 @@ async def get_embeddings(
283285
)
284286
for i, item in enumerate(response.data):
285287
original_idx = image_items[i][0]
286-
img_vector: list[float] = [float(v) for v in item.embedding]
287-
results[original_idx] = Embedding(
288-
vector=img_vector,
289-
dimensions=len(img_vector),
288+
embeddings[original_idx] = Embedding(
289+
vector=item.embedding,
290+
dimensions=len(item.embedding),
290291
model_id=response.model or image_model,
291292
)
292293
if response.usage:
293-
total_prompt_tokens += response.usage.prompt_tokens
294-
total_completion_tokens += getattr(response.usage, "completion_tokens", 0) or 0
294+
usage_details["input_token_count"] += response.usage.prompt_tokens
295+
usage_details["output_token_count"] += getattr(response.usage, "completion_tokens", 0) or 0
295296

296-
embeddings = [r for r in results if r is not None]
297-
298-
usage_dict: dict[str, Any] | None = None
299-
if total_prompt_tokens > 0 or total_completion_tokens > 0:
300-
usage_dict = {"prompt_tokens": total_prompt_tokens}
301-
if total_completion_tokens > 0:
302-
usage_dict["completion_tokens"] = total_completion_tokens
303-
304-
return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) # type: ignore[reportReturnType]
297+
return GeneratedEmbeddings(embeddings, options=options, usage=usage_details) # type: ignore[reportReturnType]
305298

306299

307300
class AzureAIInferenceEmbeddingClient(

python/packages/azure-ai/tests/azure_ai/test_azure_ai_inference_embedding_client.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ async def test_model_override_in_options(
169169
call_kwargs = mock_text_client.embed.call_args
170170
assert call_kwargs.kwargs["model"] == "custom-model"
171171

172-
async def test_unsupported_content_type_raises(
173-
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]
174-
) -> None:
172+
async def test_unsupported_content_type_raises(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
175173
"""Non-text, non-image Content raises ValueError."""
176174
error_content = Content("error", message="fail")
177175
with pytest.raises(ValueError, match="Unsupported Content type"):
@@ -181,12 +179,10 @@ async def test_usage_metadata(
181179
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
182180
) -> None:
183181
"""Usage metadata is populated from the response."""
184-
mock_text_client.embed.return_value = _make_embed_response(
185-
[[0.1, 0.2]], prompt_tokens=42
186-
)
182+
mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2]], prompt_tokens=42)
187183
result = await raw_client.get_embeddings(["hello"])
188184
assert result.usage is not None
189-
assert result.usage["prompt_tokens"] == 42
185+
assert result.usage["input_token_count"] == 42
190186

191187
def test_service_url(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
192188
"""service_url returns the configured endpoint."""
@@ -229,9 +225,7 @@ def test_image_model_id_from_env(self) -> None:
229225
assert client.model_id == "text-model"
230226
assert client.image_model_id == "image-model"
231227

232-
def test_image_model_id_explicit(
233-
self, mock_text_client: AsyncMock, mock_image_client: AsyncMock
234-
) -> None:
228+
def test_image_model_id_explicit(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
235229
"""image_model_id can be set explicitly."""
236230
client = RawAzureAIInferenceEmbeddingClient(
237231
model_id="text-model",
@@ -277,9 +271,7 @@ async def test_otel_provider_name_default(self) -> None:
277271
"""Default OTEL provider name is azure.ai.inference."""
278272
assert AzureAIInferenceEmbeddingClient.OTEL_PROVIDER_NAME == "azure.ai.inference"
279273

280-
async def test_otel_provider_name_override(
281-
self, mock_text_client: AsyncMock, mock_image_client: AsyncMock
282-
) -> None:
274+
async def test_otel_provider_name_override(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
283275
"""OTEL provider name can be overridden."""
284276
client = AzureAIInferenceEmbeddingClient(
285277
model_id="test-model",

python/packages/bedrock/agent_framework_bedrock/_embedding_client.py

Lines changed: 69 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
Embedding,
1616
EmbeddingGenerationOptions,
1717
GeneratedEmbeddings,
18+
SecretString,
19+
UsageDetails,
20+
load_settings,
1821
)
19-
from agent_framework._settings import SecretString, load_settings
2022
from agent_framework.observability import EmbeddingTelemetryLayer
2123
from boto3.session import Session as Boto3Session
2224
from botocore.client import BaseClient
@@ -29,10 +31,19 @@
2931

3032

3133
logger = logging.getLogger("agent_framework.bedrock")
32-
3334
DEFAULT_REGION = "us-east-1"
3435

3536

37+
class BedrockEmbeddingSettings(TypedDict, total=False):
38+
"""Bedrock embedding settings."""
39+
40+
region: str | None
41+
embedding_model_id: str | None
42+
access_key: SecretString | None
43+
secret_key: SecretString | None
44+
session_token: SecretString | None
45+
46+
3647
class BedrockEmbeddingOptions(EmbeddingGenerationOptions, total=False):
3748
"""Bedrock-specific embedding options.
3849
@@ -61,16 +72,6 @@ class BedrockEmbeddingOptions(EmbeddingGenerationOptions, total=False):
6172
)
6273

6374

64-
class BedrockEmbeddingSettings(TypedDict, total=False):
65-
"""Bedrock embedding settings."""
66-
67-
region: str | None
68-
embedding_model_id: str | None
69-
access_key: SecretString | None
70-
secret_key: SecretString | None
71-
session_token: SecretString | None
72-
73-
7475
class RawBedrockEmbeddingClient(
7576
BaseEmbeddingClient[str, list[float], BedrockEmbeddingOptionsT],
7677
Generic[BedrockEmbeddingOptionsT],
@@ -80,8 +81,9 @@ class RawBedrockEmbeddingClient(
8081
Keyword Args:
8182
model_id: The Bedrock embedding model ID (e.g. "amazon.titan-embed-text-v2:0").
8283
Can also be set via environment variable BEDROCK_EMBEDDING_MODEL_ID.
83-
region: AWS region. Defaults to "us-east-1".
84-
Can also be set via environment variable BEDROCK_REGION.
84+
region: AWS region. If not passed, it is loaded from the BEDROCK_REGION env var;
85+
if that is not set either, the regular Boto3 configuration/loading applies
86+
(which may include other env vars, config files, or instance metadata), with "us-east-1" as the final fallback.
8587
access_key: AWS access key for manual credential injection.
8688
secret_key: AWS secret key paired with access_key.
8789
session_token: AWS session token for temporary credentials.
@@ -118,39 +120,33 @@ def __init__(
118120
env_file_path=env_file_path,
119121
env_file_encoding=env_file_encoding,
120122
)
121-
if not settings.get("region"):
122-
settings["region"] = DEFAULT_REGION
123+
resolved_region = settings.get("region") or DEFAULT_REGION
123124

124125
if client is None:
125-
session = boto3_session or self._create_session(settings)
126-
client = session.client(
126+
if not boto3_session:
127+
session_kwargs: dict[str, Any] = {}
128+
if region := settings.get("region"):
129+
session_kwargs["region_name"] = region
130+
if (access_key := settings.get("access_key")) and (secret_key := settings.get("secret_key")):
131+
session_kwargs["aws_access_key_id"] = access_key.get_secret_value() # type: ignore[union-attr]
132+
session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value() # type: ignore[union-attr]
133+
if session_token := settings.get("session_token"):
134+
session_kwargs["aws_session_token"] = session_token.get_secret_value() # type: ignore[union-attr]
135+
boto3_session = Boto3Session(**session_kwargs)
136+
client = boto3_session.client(
127137
"bedrock-runtime",
128-
region_name=settings["region"],
138+
region_name=boto3_session.region_name or resolved_region,
129139
config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT),
130140
)
131141

132142
self._bedrock_client = client
133-
self.model_id = settings["embedding_model_id"]
134-
self.region = settings["region"]
143+
self.model_id = settings["embedding_model_id"] # type: ignore[assignment]
144+
self.region = resolved_region
135145
super().__init__(**kwargs)
136146

137-
@staticmethod
138-
def _create_session(settings: BedrockEmbeddingSettings) -> Boto3Session:
139-
"""Create a boto3 session from settings."""
140-
session_kwargs: dict[str, Any] = {"region_name": settings.get("region") or DEFAULT_REGION}
141-
if settings.get("access_key") and settings.get("secret_key"):
142-
session_kwargs["aws_access_key_id"] = settings["access_key"].get_secret_value() # type: ignore[union-attr]
143-
session_kwargs["aws_secret_access_key"] = settings["secret_key"].get_secret_value() # type: ignore[union-attr]
144-
if settings.get("session_token"):
145-
session_kwargs["aws_session_token"] = settings["session_token"].get_secret_value() # type: ignore[union-attr]
146-
return Boto3Session(**session_kwargs)
147-
148147
def service_url(self) -> str:
149148
"""Get the URL of the service."""
150-
meta = getattr(self._bedrock_client, "meta", None)
151-
if meta and hasattr(meta, "endpoint_url"):
152-
return str(meta.endpoint_url)
153-
return f"https://bedrock-runtime.{self.region}.amazonaws.com"
149+
return str(self._bedrock_client.meta.endpoint_url)
154150

155151
async def get_embeddings(
156152
self,
@@ -181,41 +177,50 @@ async def get_embeddings(
181177
if not model:
182178
raise ValueError("model_id is required")
183179

180+
embedding_results = await asyncio.gather(
181+
*(self._generate_embedding_for_text(opts, model, text) for text in values)
182+
)
184183
embeddings: list[Embedding[list[float]]] = []
185184
total_input_tokens = 0
185+
for embedding, input_tokens in embedding_results:
186+
embeddings.append(embedding)
187+
total_input_tokens += input_tokens
186188

187-
for text in values:
188-
body: dict[str, Any] = {"inputText": text}
189-
if dimensions := opts.get("dimensions"):
190-
body["dimensions"] = dimensions
191-
if (normalize := opts.get("normalize")) is not None:
192-
body["normalize"] = normalize
193-
194-
response = await asyncio.to_thread(
195-
self._bedrock_client.invoke_model,
196-
modelId=model,
197-
contentType="application/json",
198-
accept="application/json",
199-
body=json.dumps(body),
200-
)
201-
202-
response_body = json.loads(response["body"].read())
203-
vector = response_body["embedding"]
204-
embeddings.append(
205-
Embedding(
206-
vector=vector,
207-
dimensions=len(vector),
208-
model_id=model,
209-
)
210-
)
211-
total_input_tokens += response_body.get("inputTextTokenCount", 0)
212-
213-
usage_dict: dict[str, Any] | None = None
189+
usage_dict: UsageDetails | None = None
214190
if total_input_tokens > 0:
215-
usage_dict = {"prompt_tokens": total_input_tokens}
191+
usage_dict = {"input_token_count": total_input_tokens}
216192

217193
return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
218194

195+
async def _generate_embedding_for_text(
196+
self,
197+
opts: dict[str, Any],
198+
model: str,
199+
text: str,
200+
) -> tuple[Embedding[list[float]], int]:
201+
body: dict[str, Any] = {"inputText": text}
202+
if dimensions := opts.get("dimensions"):
203+
body["dimensions"] = dimensions
204+
if (normalize := opts.get("normalize")) is not None:
205+
body["normalize"] = normalize
206+
207+
response = await asyncio.to_thread(
208+
self._bedrock_client.invoke_model,
209+
modelId=model,
210+
contentType="application/json",
211+
accept="application/json",
212+
body=json.dumps(body),
213+
)
214+
215+
response_body = json.loads(response["body"].read())
216+
embedding = Embedding(
217+
vector=response_body["embedding"],
218+
dimensions=len(response_body["embedding"]),
219+
model_id=model,
220+
)
221+
input_tokens = int(response_body.get("inputTextTokenCount", 0))
222+
return embedding, input_tokens
223+
219224

220225
class BedrockEmbeddingClient(
221226
EmbeddingTelemetryLayer[str, list[float], BedrockEmbeddingOptionsT],

python/packages/bedrock/tests/bedrock/test_bedrock_embedding_client.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@ def invoke_model(self, **kwargs: Any) -> dict[str, Any]:
2626
dimensions = body.get("dimensions", 3)
2727
return {
2828
"body": MagicMock(
29-
read=lambda: json.dumps(
30-
{
31-
"embedding": [0.1 * (i + 1) for i in range(dimensions)],
32-
"inputTextTokenCount": 5,
33-
}
34-
).encode()
29+
read=lambda: json.dumps({
30+
"embedding": [0.1 * (i + 1) for i in range(dimensions)],
31+
"inputTextTokenCount": 5,
32+
}).encode()
3533
),
3634
}
3735

@@ -73,14 +71,12 @@ async def test_bedrock_embedding_get_embeddings() -> None:
7371
assert len(result[0].vector) == 3
7472
assert len(result[1].vector) == 3
7573
assert result[0].model_id == "amazon.titan-embed-text-v2:0"
76-
assert result.usage == {"prompt_tokens": 10}
74+
assert result.usage == {"input_token_count": 10}
7775

7876
# Two calls since Titan processes one input at a time
7977
assert len(stub.calls) == 2
80-
body0 = json.loads(stub.calls[0]["body"])
81-
assert body0["inputText"] == "hello"
82-
body1 = json.loads(stub.calls[1]["body"])
83-
assert body1["inputText"] == "world"
78+
call_texts = {json.loads(call["body"])["inputText"] for call in stub.calls}
79+
assert call_texts == {"hello", "world"}
8480

8581

8682
async def test_bedrock_embedding_get_embeddings_empty_input() -> None:

0 commit comments

Comments
 (0)