Skip to content

Commit 74ef169

Browse files
authored
[embed] add compression parameter (#331)
* add compression parameter to embed * remove compress codebook * update changelog and toml
1 parent 085cb98 commit 74ef169

6 files changed

Lines changed: 13 additions & 102 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# Changelog
2+
## 4.32
3+
- [#331] (https://github.com/cohere-ai/cohere-python/pull/331)
4+
- Embed: add `compression` parameter for embed models
25

36
## 4.31
47
- [#324] (https://github.com/cohere-ai/cohere-python/pull/324)

cohere/client.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,7 @@ def embed(
393393
texts: List[str],
394394
model: Optional[str] = None,
395395
truncate: Optional[str] = None,
396-
compress: Optional[bool] = False,
397-
compression_codebook: Optional[str] = "default",
396+
compression: Optional[str] = None,
398397
input_type: Optional[str] = None,
399398
) -> Embeddings:
400399
"""Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.
@@ -403,8 +402,7 @@ def embed(
403402
text (List[str]): A list of strings to embed.
404403
model (str): (Optional) The model ID to use for embedding the text.
405404
truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
406-
compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255].
407-
compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default".
405+
compression (str): (Optional) One of "int8" or "binary". The type of compression to use for the embeddings.
408406
input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
409407
"""
410408
responses = {
@@ -420,8 +418,7 @@ def embed(
420418
"model": model,
421419
"texts": texts_batch,
422420
"truncate": truncate,
423-
"compress": compress,
424-
"compression_codebook": compression_codebook,
421+
"compression": compression,
425422
"input_type": input_type,
426423
}
427424
)
@@ -1047,8 +1044,6 @@ def create_embed_job(
10471044
name: Optional[str] = None,
10481045
model: Optional[str] = None,
10491046
truncate: Optional[str] = None,
1050-
compress: Optional[bool] = None,
1051-
compression_codebook: Optional[str] = None,
10521047
text_field: Optional[str] = None,
10531048
) -> EmbedJob:
10541049
"""Create embed job.
@@ -1058,8 +1053,6 @@ def create_embed_job(
10581053
name (Optional[str], optional): The name of the embed job. Defaults to None.
10591054
model (Optional[str], optional): The model ID to use for embedding the text. Defaults to None.
10601055
truncate (Optional[str], optional): How the API handles text longer than the maximum token length. Defaults to None.
1061-
compress (Optional[bool], optional): Use embedding compression. Defaults to None.
1062-
compression_codebook (Optional[str], optional): Embedding compression codebook. Defaults to None.
10631056
text_field (Optional[str], optional): Name of the column containing text to embed. Defaults to None.
10641057
10651058
Returns:
@@ -1078,8 +1071,6 @@ def create_embed_job(
10781071
"name": name,
10791072
"model": model,
10801073
"truncate": truncate,
1081-
"compress": compress,
1082-
"compression_codebook": compression_codebook,
10831074
"text_field": text_field,
10841075
"output_format": "avro",
10851076
}

cohere/client_async.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,7 @@ async def embed(
271271
texts: List[str],
272272
model: Optional[str] = None,
273273
truncate: Optional[str] = None,
274-
compress: Optional[bool] = False,
275-
compression_codebook: Optional[str] = "default",
274+
compression: Optional[str] = None,
276275
input_type: Optional[str] = None,
277276
) -> Embeddings:
278277
"""Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.
@@ -281,17 +280,15 @@ async def embed(
281280
text (List[str]): A list of strings to embed.
282281
model (str): (Optional) The model ID to use for embedding the text.
283282
truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
284-
compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255].
285-
compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default".
283+
compression (str): (Optional) One of "int8" or "binary". The type of compression to use for the embeddings.
286284
input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
287285
"""
288286
json_bodys = [
289287
dict(
290288
texts=texts[i : i + cohere.COHERE_EMBED_BATCH_SIZE],
291289
model=model,
292290
truncate=truncate,
293-
compress=compress,
294-
compression_codebook=compression_codebook,
291+
compression=compression,
295292
input_type=input_type,
296293
)
297294
for i in range(0, len(texts), cohere.COHERE_EMBED_BATCH_SIZE)
@@ -301,7 +298,9 @@ async def embed(
301298

302299
return Embeddings(
303300
embeddings=[e for res in responses for e in res["embeddings"]],
304-
compressed_embeddings=[e for res in responses for e in res["compressed_embeddings"]] if compress else None,
301+
compressed_embeddings=[e for res in responses for e in res["compressed_embeddings"]]
302+
if compression
303+
else None,
305304
meta=meta,
306305
)
307306

@@ -725,8 +724,6 @@ async def create_embed_job(
725724
name: Optional[str] = None,
726725
model: Optional[str] = None,
727726
truncate: Optional[str] = None,
728-
compress: Optional[bool] = None,
729-
compression_codebook: Optional[str] = None,
730727
text_field: Optional[str] = None,
731728
) -> AsyncEmbedJob:
732729
"""Create embed job.
@@ -736,8 +733,6 @@ async def create_embed_job(
736733
name (Optional[str], optional): The name of the embed job. Defaults to None.
737734
model (Optional[str], optional): The model ID to use for embedding the text. Defaults to None.
738735
truncate (Optional[str], optional): How the API handles text longer than the maximum token length. Defaults to None.
739-
compress (Optional[bool], optional): Use embedding compression. Defaults to None.
740-
compression_codebook (Optional[str], optional): Embedding compression codebook. Defaults to None.
741736
text_field (Optional[str], optional): Name of the column containing text to embed. Defaults to None.
742737
743738
Returns:
@@ -756,8 +751,6 @@ async def create_embed_job(
756751
"name": name,
757752
"model": model,
758753
"truncate": truncate,
759-
"compress": compress,
760-
"compression_codebook": compression_codebook,
761754
"text_field": text_field,
762755
"output_format": "avro",
763756
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "cohere"
3-
version = "4.31"
3+
version = "4.32"
44
description = ""
55
authors = ["Cohere"]
66
readme = "README.md"

tests/async/test_async_codebook.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

tests/sync/test_codebook.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

0 commit comments

Comments
 (0)