1515 Embedding ,
1616 EmbeddingGenerationOptions ,
1717 GeneratedEmbeddings ,
18+ SecretString ,
19+ UsageDetails ,
20+ load_settings ,
1821)
19- from agent_framework ._settings import SecretString , load_settings
2022from agent_framework .observability import EmbeddingTelemetryLayer
2123from boto3 .session import Session as Boto3Session
2224from botocore .client import BaseClient
2931
3032
3133logger = logging .getLogger ("agent_framework.bedrock" )
32-
3334DEFAULT_REGION = "us-east-1"
3435
3536
37+ class BedrockEmbeddingSettings (TypedDict , total = False ):
38+ """Bedrock embedding settings."""
39+
40+ region : str | None
41+ embedding_model_id : str | None
42+ access_key : SecretString | None
43+ secret_key : SecretString | None
44+ session_token : SecretString | None
45+
46+
3647class BedrockEmbeddingOptions (EmbeddingGenerationOptions , total = False ):
3748 """Bedrock-specific embedding options.
3849
@@ -61,16 +72,6 @@ class BedrockEmbeddingOptions(EmbeddingGenerationOptions, total=False):
6172)
6273
6374
64- class BedrockEmbeddingSettings (TypedDict , total = False ):
65- """Bedrock embedding settings."""
66-
67- region : str | None
68- embedding_model_id : str | None
69- access_key : SecretString | None
70- secret_key : SecretString | None
71- session_token : SecretString | None
72-
73-
7475class RawBedrockEmbeddingClient (
7576 BaseEmbeddingClient [str , list [float ], BedrockEmbeddingOptionsT ],
7677 Generic [BedrockEmbeddingOptionsT ],
@@ -80,8 +81,9 @@ class RawBedrockEmbeddingClient(
8081 Keyword Args:
8182 model_id: The Bedrock embedding model ID (e.g. "amazon.titan-embed-text-v2:0").
8283 Can also be set via environment variable BEDROCK_EMBEDDING_MODEL_ID.
83- region: AWS region. Defaults to "us-east-1".
84- Can also be set via environment variable BEDROCK_REGION.
84+ region: AWS region. Will try to load from BEDROCK_REGION env var,
85+ if not set, the regular Boto3 configuration/loading applies
86+ (which may include other env vars, config files, or instance metadata).
8587 access_key: AWS access key for manual credential injection.
8688 secret_key: AWS secret key paired with access_key.
8789 session_token: AWS session token for temporary credentials.
@@ -118,39 +120,33 @@ def __init__(
118120 env_file_path = env_file_path ,
119121 env_file_encoding = env_file_encoding ,
120122 )
121- if not settings .get ("region" ):
122- settings ["region" ] = DEFAULT_REGION
123+ resolved_region = settings .get ("region" ) or DEFAULT_REGION
123124
124125 if client is None :
125- session = boto3_session or self ._create_session (settings )
126- client = session .client (
126+ if not boto3_session :
127+ session_kwargs : dict [str , Any ] = {}
128+ if region := settings .get ("region" ):
129+ session_kwargs ["region_name" ] = region
130+ if (access_key := settings .get ("access_key" )) and (secret_key := settings .get ("secret_key" )):
131+ session_kwargs ["aws_access_key_id" ] = access_key .get_secret_value () # type: ignore[union-attr]
132+ session_kwargs ["aws_secret_access_key" ] = secret_key .get_secret_value () # type: ignore[union-attr]
133+ if session_token := settings .get ("session_token" ):
134+ session_kwargs ["aws_session_token" ] = session_token .get_secret_value () # type: ignore[union-attr]
135+ boto3_session = Boto3Session (** session_kwargs )
136+ client = boto3_session .client (
127137 "bedrock-runtime" ,
128- region_name = settings [ "region" ] ,
138+ region_name = boto3_session . region_name or resolved_region ,
129139 config = BotoConfig (user_agent_extra = AGENT_FRAMEWORK_USER_AGENT ),
130140 )
131141
132142 self ._bedrock_client = client
133- self .model_id = settings ["embedding_model_id" ]
134- self .region = settings [ "region" ]
143+ self .model_id = settings ["embedding_model_id" ] # type: ignore[assignment]
144+ self .region = resolved_region
135145 super ().__init__ (** kwargs )
136146
137- @staticmethod
138- def _create_session (settings : BedrockEmbeddingSettings ) -> Boto3Session :
139- """Create a boto3 session from settings."""
140- session_kwargs : dict [str , Any ] = {"region_name" : settings .get ("region" ) or DEFAULT_REGION }
141- if settings .get ("access_key" ) and settings .get ("secret_key" ):
142- session_kwargs ["aws_access_key_id" ] = settings ["access_key" ].get_secret_value () # type: ignore[union-attr]
143- session_kwargs ["aws_secret_access_key" ] = settings ["secret_key" ].get_secret_value () # type: ignore[union-attr]
144- if settings .get ("session_token" ):
145- session_kwargs ["aws_session_token" ] = settings ["session_token" ].get_secret_value () # type: ignore[union-attr]
146- return Boto3Session (** session_kwargs )
147-
148147 def service_url (self ) -> str :
149148 """Get the URL of the service."""
150- meta = getattr (self ._bedrock_client , "meta" , None )
151- if meta and hasattr (meta , "endpoint_url" ):
152- return str (meta .endpoint_url )
153- return f"https://bedrock-runtime.{ self .region } .amazonaws.com"
149+ return str (self ._bedrock_client .meta .endpoint_url )
154150
155151 async def get_embeddings (
156152 self ,
@@ -181,41 +177,50 @@ async def get_embeddings(
181177 if not model :
182178 raise ValueError ("model_id is required" )
183179
180+ embedding_results = await asyncio .gather (
181+ * (self ._generate_embedding_for_text (opts , model , text ) for text in values )
182+ )
184183 embeddings : list [Embedding [list [float ]]] = []
185184 total_input_tokens = 0
185+ for embedding , input_tokens in embedding_results :
186+ embeddings .append (embedding )
187+ total_input_tokens += input_tokens
186188
187- for text in values :
188- body : dict [str , Any ] = {"inputText" : text }
189- if dimensions := opts .get ("dimensions" ):
190- body ["dimensions" ] = dimensions
191- if (normalize := opts .get ("normalize" )) is not None :
192- body ["normalize" ] = normalize
193-
194- response = await asyncio .to_thread (
195- self ._bedrock_client .invoke_model ,
196- modelId = model ,
197- contentType = "application/json" ,
198- accept = "application/json" ,
199- body = json .dumps (body ),
200- )
201-
202- response_body = json .loads (response ["body" ].read ())
203- vector = response_body ["embedding" ]
204- embeddings .append (
205- Embedding (
206- vector = vector ,
207- dimensions = len (vector ),
208- model_id = model ,
209- )
210- )
211- total_input_tokens += response_body .get ("inputTextTokenCount" , 0 )
212-
213- usage_dict : dict [str , Any ] | None = None
189+ usage_dict : UsageDetails | None = None
214190 if total_input_tokens > 0 :
215- usage_dict = {"prompt_tokens " : total_input_tokens }
191+ usage_dict = {"input_token_count " : total_input_tokens }
216192
217193 return GeneratedEmbeddings (embeddings , options = options , usage = usage_dict )
218194
195+ async def _generate_embedding_for_text (
196+ self ,
197+ opts : dict [str , Any ],
198+ model : str ,
199+ text : str ,
200+ ) -> tuple [Embedding [list [float ]], int ]:
201+ body : dict [str , Any ] = {"inputText" : text }
202+ if dimensions := opts .get ("dimensions" ):
203+ body ["dimensions" ] = dimensions
204+ if (normalize := opts .get ("normalize" )) is not None :
205+ body ["normalize" ] = normalize
206+
207+ response = await asyncio .to_thread (
208+ self ._bedrock_client .invoke_model ,
209+ modelId = model ,
210+ contentType = "application/json" ,
211+ accept = "application/json" ,
212+ body = json .dumps (body ),
213+ )
214+
215+ response_body = json .loads (response ["body" ].read ())
216+ embedding = Embedding (
217+ vector = response_body ["embedding" ],
218+ dimensions = len (response_body ["embedding" ]),
219+ model_id = model ,
220+ )
221+ input_tokens = int (response_body .get ("inputTextTokenCount" , 0 ))
222+ return embedding , input_tokens
223+
219224
220225class BedrockEmbeddingClient (
221226 EmbeddingTelemetryLayer [str , list [float ], BedrockEmbeddingOptionsT ],
0 commit comments