22
33from __future__ import annotations
44
5+ import io
6+ import json
57from dataclasses import dataclass
68from typing import Iterator , List , Optional , Union
79
@@ -44,50 +46,58 @@ def __init__(self, response: httpx.Response, batch_texts: Optional[List[str]] =
4446 def iter_embeddings (self ) -> Iterator [StreamedEmbedding ]:
4547 """
4648 Iterate over embeddings one at a time without loading all into memory.
47-
49+
4850 Yields:
4951 StreamedEmbedding objects as they are parsed from the response
5052 """
5153 if not IJSON_AVAILABLE :
5254 # Fallback to regular parsing if ijson not available
5355 yield from self ._iter_embeddings_fallback ()
5456 return
55-
57+
58+ # Buffer response content first to allow fallback if ijson fails
59+ # This prevents partial parsing issues where ijson yields some embeddings then fails
60+ response_content = self .response .content
61+
5662 try :
5763 # Use ijson for memory-efficient parsing
58- parser = ijson .parse (self . response . iter_bytes ( chunk_size = 65536 ))
64+ parser = ijson .parse (io . BytesIO ( response_content ))
5965 yield from self ._parse_with_ijson (parser )
6066 except Exception :
61- # If ijson parsing fails, fallback to regular parsing
62- yield from self ._iter_embeddings_fallback ()
67+ # If ijson parsing fails, fallback to regular parsing using buffered content
68+ data = json .loads (response_content )
69+ yield from self ._iter_embeddings_fallback_from_dict (data )
6370
6471 def _parse_with_ijson (self , parser ) -> Iterator [StreamedEmbedding ]:
6572 """Parse embeddings using ijson incremental parser."""
6673 current_path : List [str ] = []
6774 current_embedding = []
68- embedding_index = 0
75+ # Track text index separately per embedding type
76+ # When multiple types requested, each text gets multiple embeddings
77+ type_text_indices : dict = {}
6978 embedding_type = "float"
7079 response_type = None
7180 in_embeddings = False
72-
81+
7382 for prefix , event , value in parser :
7483 # Track current path
7584 if event == 'map_key' :
7685 if current_path and current_path [- 1 ] == 'embeddings' :
7786 # This is an embedding type key (float_, int8, etc.)
7887 embedding_type = value .rstrip ('_' )
79-
88+
8089 # Detect response type
8190 if prefix == 'response_type' :
8291 response_type = value
83-
92+
8493 # Handle embeddings based on response type
8594 if response_type == 'embeddings_floats' :
8695 # Simple float array format
8796 if prefix .startswith ('embeddings.item.item' ):
8897 current_embedding .append (value )
8998 elif prefix .startswith ('embeddings.item' ) and event == 'end_array' :
9099 # Complete embedding
100+ embedding_index = type_text_indices .get ('float' , 0 )
91101 text = self .batch_texts [embedding_index ] if embedding_index < len (self .batch_texts ) else None
92102 yield StreamedEmbedding (
93103 index = self .embeddings_yielded ,
@@ -96,18 +106,21 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
96106 text = text
97107 )
98108 self .embeddings_yielded += 1
99- embedding_index += 1
109+ type_text_indices [ 'float' ] = embedding_index + 1
100110 current_embedding = []
101-
111+
102112 elif response_type == 'embeddings_by_type' :
103113 # Complex format with multiple embedding types
104114 # Pattern: embeddings.<type>.item.item
115+ # ijson adds underscore to Python keywords like 'float'
105116 for emb_type in ['float_' , 'int8' , 'uint8' , 'binary' , 'ubinary' ]:
106117 type_name = emb_type .rstrip ('_' )
107118 if prefix .startswith (f'embeddings.{ emb_type } .item.item' ):
108119 current_embedding .append (value )
109120 elif prefix .startswith (f'embeddings.{ emb_type } .item' ) and event == 'end_array' :
110121 # Complete embedding of this type
122+ # Track index per type - same text can have multiple embedding types
123+ embedding_index = type_text_indices .get (type_name , 0 )
111124 text = self .batch_texts [embedding_index ] if embedding_index < len (self .batch_texts ) else None
112125 yield StreamedEmbedding (
113126 index = self .embeddings_yielded ,
@@ -116,11 +129,12 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
116129 text = text
117130 )
118131 self .embeddings_yielded += 1
119- embedding_index += 1
132+ type_text_indices [ type_name ] = embedding_index + 1
120133 current_embedding = []
121-
134+
122135 # Handle base64 embeddings (string format)
123136 if prefix .startswith ('embeddings.base64.item' ) and event == 'string' :
137+ embedding_index = type_text_indices .get ('base64' , 0 )
124138 text = self .batch_texts [embedding_index ] if embedding_index < len (self .batch_texts ) else None
125139 yield StreamedEmbedding (
126140 index = self .embeddings_yielded ,
@@ -129,7 +143,7 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
129143 text = text
130144 )
131145 self .embeddings_yielded += 1
132- embedding_index += 1
146+ type_text_indices [ 'base64' ] = embedding_index + 1
133147
134148 def _iter_embeddings_fallback (self ) -> Iterator [StreamedEmbedding ]:
135149 """Fallback method using regular JSON parsing."""
@@ -140,34 +154,40 @@ def _iter_embeddings_fallback(self) -> Iterator[StreamedEmbedding]:
140154 data = self .response ._response .json () # type: ignore
141155 else :
142156 raise ValueError ("Response object does not have a json() method" )
157+
158+ yield from self ._iter_embeddings_fallback_from_dict (data )
159+
160+ def _iter_embeddings_fallback_from_dict (self , data : dict ) -> Iterator [StreamedEmbedding ]:
161+ """Parse embeddings from a dictionary (used by fallback methods)."""
143162 response_type = data .get ('response_type' , '' )
144-
163+
145164 if response_type == 'embeddings_floats' :
146165 embeddings = data .get ('embeddings' , [])
147166 texts = data .get ('texts' , [])
148167 for i , embedding in enumerate (embeddings ):
149168 yield StreamedEmbedding (
150- index = i ,
169+ index = self . embeddings_yielded + i ,
151170 embedding = embedding ,
152171 embedding_type = 'float' ,
153172 text = texts [i ] if i < len (texts ) else None
154173 )
155-
174+
156175 elif response_type == 'embeddings_by_type' :
157176 embeddings_obj = data .get ('embeddings' , {})
158177 texts = data .get ('texts' , [])
159-
178+
160179 # Iterate through each embedding type
161180 for emb_type , embeddings_list in embeddings_obj .items ():
162181 type_name = emb_type .rstrip ('_' )
163182 if isinstance (embeddings_list , list ):
164183 for i , embedding in enumerate (embeddings_list ):
165184 yield StreamedEmbedding (
166- index = i ,
185+ index = self . embeddings_yielded ,
167186 embedding = embedding ,
168187 embedding_type = type_name ,
169188 text = texts [i ] if i < len (texts ) else None
170189 )
190+ self .embeddings_yielded += 1
171191
172192
173193def stream_embed_response (response : httpx .Response , texts : List [str ]) -> Iterator [StreamedEmbedding ]:
0 commit comments