
Commit 2906f2d

Support event type when streaming chat. Remove deprecated parameters. (#287)
* Initial additions to chat functionality. Checks for streaming event type
* Update version
* Optional conversation id
1 parent 07e6684 commit 2906f2d

7 files changed: 107 additions & 88 deletions


CHANGELOG.md: 5 additions & 0 deletions

```diff
@@ -1,5 +1,10 @@
 # Changelog
 
+## 4.21
+- [#287](https://github.com/cohere-ai/cohere-python/pull/287)
+    - Remove deprecated chat "query" parameter including inside chat_history parameter
+    - Support event-type for chat streaming
+
 ## 4.20.2
 - [#284](https://github.com/cohere-ai/cohere-python/pull/284)
     - Rename dataset urls to download_urls
```

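Since the deprecated "query" parameter and the chat_history "text" key are removed in this release, here is a minimal migration sketch; the client construction and the API key are illustrative assumptions, not part of this diff:

```python
import cohere

co = cohere.Client("YOUR_API_KEY")  # illustrative setup; the key is a placeholder

# Removed in 4.21 (previously only deprecated):
#   res = co.chat(query="Hey! How are you doing today?")
# Use the "message" parameter instead:
res = co.chat(message="Hey! How are you doing today?")
print(res.text)

# Likewise, chat_history entries should now carry their text under the
# "message" key; the old "text" key is no longer rewritten for you.
```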
cohere/client.py: 15 additions & 32 deletions

```diff
@@ -227,7 +227,6 @@ def generate(
     def chat(
         self,
         message: Optional[str] = None,
-        query: Optional[str] = None,
         conversation_id: Optional[str] = "",
         model: Optional[str] = None,
         return_chatlog: Optional[bool] = False,
@@ -246,30 +245,33 @@ def chat(
         """Returns a Chat object with the query reply.
 
         Args:
-            query (str): Deprecated. Use message instead.
             message (str): The message to send to the chatbot.
-            conversation_id (str): (Optional) The conversation id to continue the conversation.
-            model (str): (Optional) The model to use for generating the next reply.
-            return_chatlog (bool): (Optional) Whether to return the chatlog.
-            return_prompt (bool): (Optional) Whether to return the prompt.
-            return_preamble (bool): (Optional) Whether to return the preamble.
-            chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.
-            preamble_override (str): (Optional) A string to override the preamble.
-            user_name (str): (Optional) A string to override the username.
-            temperature (float): (Optional) The temperature to use for the next reply. The higher the temperature, the more random the reply.
-            max_tokens (int): (Optional) The max tokens generated for the next reply.
+
             stream (bool): Return streaming tokens.
+            conversation_id (str): (Optional) To store a conversation then create a conversation id and use it for every related request.
+
+            preamble_override (str): (Optional) A string to override the preamble.
+            chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.
+
+            model (str): (Optional) The model to use for generating the response.
+            temperature (float): (Optional) The temperature to use for the response. The higher the temperature, the more random the response.
             p (float): (Optional) The nucleus sampling probability.
             k (float): (Optional) The top-k sampling probability.
             logit_bias (Dict[int, float]): (Optional) A dictionary of logit bias values to use for the next reply.
+            max_tokens (int): (Optional) The max tokens generated for the next reply.
+
+            return_chatlog (bool): (Optional) Whether to return the chatlog.
+            return_prompt (bool): (Optional) Whether to return the prompt.
+            return_preamble (bool): (Optional) Whether to return the preamble.
+
+            user_name (str): (Optional) A string to override the username.
         Returns:
             a Chat object if stream=False, or a StreamingChat object if stream=True
 
         Examples:
             A simple chat message:
                 >>> res = co.chat(message="Hey! How are you doing today?")
                 >>> print(res.text)
-                >>> print(res.conversation_id)
             Continuing a session using a specific model:
                 >>> res = co.chat(
                 >>>     message="Hey! How are you doing today?",
@@ -295,25 +297,6 @@ def chat(
             >>> print(res.text)
             >>> print(res.prompt)
         """
-        if chat_history is not None:
-            should_warn = True
-            for entry in chat_history:
-                if "text" in entry:
-                    entry["message"] = entry["text"]
-
-                if "text" in entry and should_warn:
-                    logger.warning(
-                        "The 'text' parameter is deprecated and will be removed in a future version of this function. "
-                        + "Use 'message' instead.",
-                    )
-                    should_warn = False
-
-        if query is not None:
-            logger.warning(
-                "The chat_history 'text' key is deprecated and will be removed in a future version of this function. "
-                + "Use 'message' instead.",
-            )
-            message = query
 
         json_body = {
             "message": message,
```

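Taken together with the responses and test changes further down, a minimal sketch of consuming the event-typed stream from the synchronous client; the client construction and the API key are assumptions for illustration, not part of this diff:

```python
import cohere
from cohere.responses.chat import StreamStart, StreamTextGeneration

co = cohere.Client("YOUR_API_KEY")  # assumed setup, not part of this diff

# With stream=True, chat() returns a StreamingChat that now yields typed
# event objects instead of StreamingText tuples.
stream = co.chat(message="Hey! How are you doing today?", stream=True)

for event in stream:
    if isinstance(event, StreamStart):
        # emitted once at the beginning of the stream
        print("generation_id:", event.generation_id)
    elif isinstance(event, StreamTextGeneration):
        # one chunk of generated text
        print(event.text, end="")

# The final "stream-end" event populates the aggregate fields.
print()
print(stream.texts, stream.finish_reason)
```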
cohere/client_async.py: 2 additions & 23 deletions

```diff
@@ -209,7 +209,6 @@ async def generate(
     async def chat(
         self,
         message: Optional[str] = None,
-        query: Optional[str] = None,
         conversation_id: Optional[str] = "",
         model: Optional[str] = None,
         return_chatlog: Optional[bool] = False,
@@ -225,28 +224,8 @@ async def chat(
         k: Optional[float] = None,
         logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[AsyncChat, StreamingChat]:
-        if chat_history is not None:
-            should_warn = True
-            for entry in chat_history:
-                if "text" in entry:
-                    entry["message"] = entry["text"]
-
-                if "text" in entry and should_warn:
-                    logger.warning(
-                        "The 'text' parameter is deprecated and will be removed in a future version of this function. "
-                        + "Use 'message' instead.",
-                    )
-                    should_warn = False
-
-        if query is None and message is None:
-            raise CohereError("Either 'query' or 'message' must be provided.")
-
-        if query is not None:
-            logger.warning(
-                "The 'query' parameter is deprecated and will be removed in a future version of this function. "
-                + "Use 'message' instead.",
-            )
-            message = query
+        if message is None:
+            raise CohereError("'message' must be provided.")
 
         json_body = {
             "message": message,
```

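The same pattern with the async client, mirroring the updated async test below; the AsyncClient construction and the API key are assumptions for illustration:

```python
import asyncio

import cohere
from cohere.responses.chat import StreamStart, StreamTextGeneration


async def main():
    co = cohere.AsyncClient("YOUR_API_KEY")  # assumed setup, not part of this diff
    res = await co.chat(message="Hey! How are you doing today?", stream=True)

    async for event in res:
        if isinstance(event, StreamStart):
            print("conversation_id:", event.conversation_id)
        elif isinstance(event, StreamTextGeneration):
            print(event.text, end="")

    print()
    print(res.texts)  # filled in by the stream-end event
    # client/session cleanup is omitted in this sketch


asyncio.run(main())
```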
cohere/responses/chat.py: 67 additions & 24 deletions

```diff
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, Generator, List, NamedTuple, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 import requests
 
@@ -25,10 +25,9 @@ def __init__(
         super().__init__(**kwargs)
         self.response_id = response_id
         self.generation_id = generation_id
-        self.query = message  # to be deprecated
         self.message = message
         self.text = text
-        self.conversation_id = conversation_id
+        self.conversation_id = conversation_id  # optional
         self.prompt = prompt  # optional
         self.chatlog = chatlog  # optional
         self.preamble = preamble  # optional
@@ -47,7 +46,7 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat":
             text=response.get("text"),
             prompt=response.get("prompt"),  # optional
             chatlog=response.get("chatlog"),  # optional
-            preamble=response.get("preamble"),  # option
+            preamble=response.get("preamble"),  # optional
             client=client,
             token_count=response.get("token_count"),
             meta=response.get("meta"),
@@ -76,7 +75,38 @@ async def respond(self, response: str, max_tokens: int = None) -> "AsyncChat":
         )
 
 
-StreamingText = NamedTuple("StreamingText", [("index", Optional[int]), ("text", str), ("is_finished", bool)])
+class StreamResponse(CohereObject):
+    def __init__(
+        self,
+        is_finished: bool,
+        index: Optional[int],
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.is_finished = is_finished
+        self.index = index
+
+
+class StreamStart(StreamResponse):
+    def __init__(
+        self,
+        generation_id: str,
+        conversation_id: Optional[str],
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.generation_id = generation_id
+        self.conversation_id = conversation_id
+
+
+class StreamTextGeneration(StreamResponse):
+    def __init__(
+        self,
+        text: str,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.text = text
 
 
 class StreamingChat(CohereObject):
@@ -85,34 +115,47 @@ def __init__(self, response):
         self.texts = []
         self.response_id = None
         self.conversation_id = None
+        self.generation_id = None
         self.preamble = None
         self.prompt = None
         self.chatlog = None
         self.finish_reason = None
+        self.token_count = None
+        self.meta = None
 
     def _make_response_item(self, index, line) -> Any:
         streaming_item = json.loads(line)
-        is_finished = streaming_item.get("is_finished")
-        text = streaming_item.get("text")
-
-        if not is_finished:
-            return StreamingText(text=text, is_finished=is_finished, index=index)
-
-        response = streaming_item.get("response")
-
-        if response is None:
+        event_type = streaming_item.get("event_type")
+
+        if event_type == "stream-start":
+            self.conversation_id = streaming_item.get("conversation_id")
+            self.generation_id = streaming_item.get("generation_id")
+            return StreamStart(
+                conversation_id=self.conversation_id, generation_id=self.generation_id, is_finished=False, index=index
+            )
+        elif event_type == "text-generation":
+            text = streaming_item.get("text")
+            return StreamTextGeneration(text=text, is_finished=False, index=index)
+        elif event_type == "stream-end":
+            response = streaming_item.get("response")
+            self.finish_reason = streaming_item.get("finish_reason")
+
+            if response is None:
+                return None
+
+            self.response_id = response.get("response_id")
+            self.conversation_id = response.get("conversation_id")
+            self.texts = [response.get("text")]
+            self.generation_id = response.get("generation_id")
+            self.preamble = response.get("preamble")
+            self.prompt = response.get("prompt")
+            self.chatlog = response.get("chatlog")
+            self.token_count = response.get("token_count")
+            self.meta = response.get("meta")
             return None
-
-        self.response_id = response.get("response_id")
-        self.conversation_id = response.get("conversation_id")
-        self.preamble = response.get("preamble")
-        self.prompt = response.get("prompt")
-        self.chatlog = response.get("chatlog")
-        self.finish_reason = streaming_item.get("finish_reason")
-        self.texts = [response.get("text")]
         return None
 
-    def __iter__(self) -> Generator[StreamingText, None, None]:
+    def __iter__(self) -> Generator[StreamResponse, None, None]:
         if not isinstance(self.response, requests.Response):
             raise ValueError("For AsyncClient, use `async for` to iterate through the `StreamingChat`")
 
@@ -121,7 +164,7 @@ def __iter__(self) -> Generator[StreamingText, None, None]:
             if item is not None:
                 yield item
 
-    async def __aiter__(self) -> Generator[StreamingText, None, None]:
+    async def __aiter__(self) -> Generator[StreamResponse, None, None]:
         index = 0
         async for line in self.response.content:
             item = self._make_response_item(index, line)
```

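To make the new routing concrete, a self-contained sketch that feeds synthetic event payloads through the parser above. The payload values (ids, text, finish_reason) are invented for illustration, and only the keys handled by _make_response_item are used:

```python
import json

from cohere.responses.chat import StreamingChat, StreamStart, StreamTextGeneration

# Synthetic stream lines shaped like the three event types handled above.
lines = [
    json.dumps({"event_type": "stream-start", "generation_id": "gen-1", "conversation_id": "conv-1"}),
    json.dumps({"event_type": "text-generation", "text": "Hello"}),
    json.dumps(
        {
            "event_type": "stream-end",
            "finish_reason": "COMPLETE",
            "response": {"response_id": "res-1", "generation_id": "gen-1", "text": "Hello"},
        }
    ),
]

stream = StreamingChat(response=None)  # the HTTP response object is unused when calling the parser directly
for index, line in enumerate(lines):
    item = stream._make_response_item(index, line)
    if isinstance(item, StreamStart):
        print("start:", item.generation_id)
    elif isinstance(item, StreamTextGeneration):
        print("text:", item.text)

# The stream-end event returns None but fills in the aggregate fields.
print(stream.texts, stream.finish_reason)
```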
pyproject.toml: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "cohere"
-version = "4.20.2"
+version = "4.21"
 description = ""
 authors = ["Cohere"]
 readme = "README.md"
```

tests/async/test_async_chat.py: 10 additions & 3 deletions

```diff
@@ -1,5 +1,7 @@
 import pytest
 
+import cohere
+
 
 @pytest.mark.asyncio
 async def test_async_multi_replies(async_client):
@@ -36,12 +38,17 @@ async def test_async_chat_stream(async_client):
     expected_index = 0
     expected_text = ""
     async for token in res:
-        if token.text:
+        if isinstance(token, cohere.responses.chat.StreamStart):
+            assert token.generation_id is not None
+            assert not token.is_finished
+        elif isinstance(token, cohere.responses.chat.StreamTextGeneration):
             assert isinstance(token.text, str)
             assert len(token.text) > 0
-            assert token.index == expected_index
-
             expected_text += token.text
+            assert not token.is_finished
+
+        assert isinstance(token.index, int)
+        assert token.index == expected_index
         expected_index += 1
 
     assert res.texts == [expected_text]
```

tests/sync/test_chat.py: 7 additions & 5 deletions

```diff
@@ -91,14 +91,16 @@ def test_stream(self):
         expected_index = 0
         expected_text = ""
         for token in prediction:
-            if token.text:
+            if isinstance(token, cohere.responses.chat.StreamStart):
+                self.assertIsNotNone(token.generation_id)
+                self.assertFalse(token.is_finished)
+            elif isinstance(token, cohere.responses.chat.StreamTextGeneration):
                 self.assertIsInstance(token.text, str)
                 self.assertGreater(len(token.text), 0)
-
-                self.assertIsInstance(token.index, int)
-                self.assertEqual(token.index, expected_index)
-
                 expected_text += token.text
+                self.assertFalse(token.is_finished)
+            self.assertIsInstance(token.index, int)
+            self.assertEqual(token.index, expected_index)
             expected_index += 1
 
         self.assertEqual(prediction.texts, [expected_text])
```
