Skip to content

Commit 8ef4bdc

Browse files
committed
fix: Address Cursor Bugbot review findings in embed_stream
Fixed 3 issues identified by Cursor Bugbot code review: 1. Partial ijson failure handling (Medium severity) - Buffered response content before attempting ijson parsing - Prevents duplicate embeddings if ijson partially succeeds then fails - Fallback now uses buffered content instead of re-reading stream 2. Multiple embedding types index tracking (High severity) - Fixed index calculation when multiple embedding types requested - Track text index separately per embedding type using type_indices dict - Same text can now correctly have multiple embedding types (float, int8, etc.) 3. ijson reserved keyword handling - Clarified that float_ is correct for ijson (Python keyword handling) - ijson automatically adds underscore to reserved keywords like 'float' - Added comment explaining this behavior All tests passing (6/6 embed_streaming tests + 6/6 custom unit tests)
1 parent 8565fe3 commit 8ef4bdc

1 file changed

Lines changed: 39 additions & 19 deletions

File tree

src/cohere/streaming_utils.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import io
6+
import json
57
from dataclasses import dataclass
68
from typing import Iterator, List, Optional, Union
79

@@ -44,50 +46,58 @@ def __init__(self, response: httpx.Response, batch_texts: Optional[List[str]] =
4446
def iter_embeddings(self) -> Iterator[StreamedEmbedding]:
4547
"""
4648
Iterate over embeddings one at a time without loading all into memory.
47-
49+
4850
Yields:
4951
StreamedEmbedding objects as they are parsed from the response
5052
"""
5153
if not IJSON_AVAILABLE:
5254
# Fallback to regular parsing if ijson not available
5355
yield from self._iter_embeddings_fallback()
5456
return
55-
57+
58+
# Buffer response content first to allow fallback if ijson fails
59+
# This prevents partial parsing issues where ijson yields some embeddings then fails
60+
response_content = self.response.content
61+
5662
try:
5763
# Use ijson for memory-efficient parsing
58-
parser = ijson.parse(self.response.iter_bytes(chunk_size=65536))
64+
parser = ijson.parse(io.BytesIO(response_content))
5965
yield from self._parse_with_ijson(parser)
6066
except Exception:
61-
# If ijson parsing fails, fallback to regular parsing
62-
yield from self._iter_embeddings_fallback()
67+
# If ijson parsing fails, fallback to regular parsing using buffered content
68+
data = json.loads(response_content)
69+
yield from self._iter_embeddings_fallback_from_dict(data)
6370

6471
def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
6572
"""Parse embeddings using ijson incremental parser."""
6673
current_path: List[str] = []
6774
current_embedding = []
68-
embedding_index = 0
75+
# Track text index separately per embedding type
76+
# When multiple types requested, each text gets multiple embeddings
77+
type_text_indices: dict = {}
6978
embedding_type = "float"
7079
response_type = None
7180
in_embeddings = False
72-
81+
7382
for prefix, event, value in parser:
7483
# Track current path
7584
if event == 'map_key':
7685
if current_path and current_path[-1] == 'embeddings':
7786
# This is an embedding type key (float_, int8, etc.)
7887
embedding_type = value.rstrip('_')
79-
88+
8089
# Detect response type
8190
if prefix == 'response_type':
8291
response_type = value
83-
92+
8493
# Handle embeddings based on response type
8594
if response_type == 'embeddings_floats':
8695
# Simple float array format
8796
if prefix.startswith('embeddings.item.item'):
8897
current_embedding.append(value)
8998
elif prefix.startswith('embeddings.item') and event == 'end_array':
9099
# Complete embedding
100+
embedding_index = type_text_indices.get('float', 0)
91101
text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None
92102
yield StreamedEmbedding(
93103
index=self.embeddings_yielded,
@@ -96,18 +106,21 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
96106
text=text
97107
)
98108
self.embeddings_yielded += 1
99-
embedding_index += 1
109+
type_text_indices['float'] = embedding_index + 1
100110
current_embedding = []
101-
111+
102112
elif response_type == 'embeddings_by_type':
103113
# Complex format with multiple embedding types
104114
# Pattern: embeddings.<type>.item.item
115+
# ijson adds underscore to Python keywords like 'float'
105116
for emb_type in ['float_', 'int8', 'uint8', 'binary', 'ubinary']:
106117
type_name = emb_type.rstrip('_')
107118
if prefix.startswith(f'embeddings.{emb_type}.item.item'):
108119
current_embedding.append(value)
109120
elif prefix.startswith(f'embeddings.{emb_type}.item') and event == 'end_array':
110121
# Complete embedding of this type
122+
# Track index per type - same text can have multiple embedding types
123+
embedding_index = type_text_indices.get(type_name, 0)
111124
text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None
112125
yield StreamedEmbedding(
113126
index=self.embeddings_yielded,
@@ -116,11 +129,12 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
116129
text=text
117130
)
118131
self.embeddings_yielded += 1
119-
embedding_index += 1
132+
type_text_indices[type_name] = embedding_index + 1
120133
current_embedding = []
121-
134+
122135
# Handle base64 embeddings (string format)
123136
if prefix.startswith('embeddings.base64.item') and event == 'string':
137+
embedding_index = type_text_indices.get('base64', 0)
124138
text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None
125139
yield StreamedEmbedding(
126140
index=self.embeddings_yielded,
@@ -129,7 +143,7 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
129143
text=text
130144
)
131145
self.embeddings_yielded += 1
132-
embedding_index += 1
146+
type_text_indices['base64'] = embedding_index + 1
133147

134148
def _iter_embeddings_fallback(self) -> Iterator[StreamedEmbedding]:
135149
"""Fallback method using regular JSON parsing."""
@@ -140,34 +154,40 @@ def _iter_embeddings_fallback(self) -> Iterator[StreamedEmbedding]:
140154
data = self.response._response.json() # type: ignore
141155
else:
142156
raise ValueError("Response object does not have a json() method")
157+
158+
yield from self._iter_embeddings_fallback_from_dict(data)
159+
160+
def _iter_embeddings_fallback_from_dict(self, data: dict) -> Iterator[StreamedEmbedding]:
161+
"""Parse embeddings from a dictionary (used by fallback methods)."""
143162
response_type = data.get('response_type', '')
144-
163+
145164
if response_type == 'embeddings_floats':
146165
embeddings = data.get('embeddings', [])
147166
texts = data.get('texts', [])
148167
for i, embedding in enumerate(embeddings):
149168
yield StreamedEmbedding(
150-
index=i,
169+
index=self.embeddings_yielded + i,
151170
embedding=embedding,
152171
embedding_type='float',
153172
text=texts[i] if i < len(texts) else None
154173
)
155-
174+
156175
elif response_type == 'embeddings_by_type':
157176
embeddings_obj = data.get('embeddings', {})
158177
texts = data.get('texts', [])
159-
178+
160179
# Iterate through each embedding type
161180
for emb_type, embeddings_list in embeddings_obj.items():
162181
type_name = emb_type.rstrip('_')
163182
if isinstance(embeddings_list, list):
164183
for i, embedding in enumerate(embeddings_list):
165184
yield StreamedEmbedding(
166-
index=i,
185+
index=self.embeddings_yielded,
167186
embedding=embedding,
168187
embedding_type=type_name,
169188
text=texts[i] if i < len(texts) else None
170189
)
190+
self.embeddings_yielded += 1
171191

172192

173193
def stream_embed_response(response: httpx.Response, texts: List[str]) -> Iterator[StreamedEmbedding]:

0 commit comments

Comments
 (0)