Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 58 additions & 35 deletions byaldi/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import srsly
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor
from pdf2image import convert_from_path
from PIL import Image

Expand All @@ -32,9 +32,12 @@ def __init__(
if isinstance(pretrained_model_name_or_path, Path):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)

if "colpali" not in pretrained_model_name_or_path.lower():
if (
"colpali" not in pretrained_model_name_or_path.lower()
and "colqwen2" not in pretrained_model_name_or_path.lower()
):
raise ValueError(
"This pre-release version of Byaldi only supports ColPali for now. Incorrect model name specified."
"This pre-release version of Byaldi only supports ColPali and ColQwen2 for now. Incorrect model name specified."
)

if verbose > 0:
Expand All @@ -48,9 +51,7 @@ def __init__(
device = (
device or "cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
else "mps" if torch.backends.mps.is_available() else "cpu"
)
self.index_name = index_name
self.verbose = verbose
Expand All @@ -64,26 +65,48 @@ def __init__(
self.doc_ids_to_file_names = {}
self.doc_ids = set()

self.model = ColPali.from_pretrained(
self.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
device_map=(
"cuda"
if device == "cuda"
or (isinstance(device, torch.device) and device.type == "cuda")
else None
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
)
self.model = self.model.eval()

self.processor = cast(
ColPaliProcessor,
ColPaliProcessor.from_pretrained(
if "colpali" in pretrained_model_name_or_path.lower():
self.model = ColPali.from_pretrained(
self.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
device_map=(
"cuda"
if device == "cuda"
or (isinstance(device, torch.device) and device.type == "cuda")
else None
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
),
)
)
elif "colqwen2" in pretrained_model_name_or_path.lower():
self.model = ColQwen2.from_pretrained(
self.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
device_map=(
"cuda"
if device == "cuda"
or (isinstance(device, torch.device) and device.type == "cuda")
else None
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
)
self.model = self.model.eval()

if "colpali" in pretrained_model_name_or_path.lower():
self.processor = cast(
ColPaliProcessor,
ColPaliProcessor.from_pretrained(
self.pretrained_model_name_or_path,
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
),
)
elif "colqwen2" in pretrained_model_name_or_path.lower():
self.processor = cast(
ColQwen2Processor,
ColQwen2Processor.from_pretrained(
self.pretrained_model_name_or_path,
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
),
)

self.device = device
if device != "cuda" and not (
Expand Down Expand Up @@ -240,9 +263,9 @@ def _export_index(self):
"model_name": self.model_name,
"full_document_collection": self.full_document_collection,
"highest_doc_id": self.highest_doc_id,
"resize_stored_images": True
if self.max_image_width and self.max_image_height
else False,
"resize_stored_images": (
True if self.max_image_width and self.max_image_height else False
),
"max_image_width": self.max_image_width,
"max_image_height": self.max_image_height,
"library_version": VERSION,
Expand Down Expand Up @@ -468,9 +491,9 @@ def _process_and_add_to_index(
with tempfile.TemporaryDirectory() as path:
images = convert_from_path(
item,
thread_count=os.cpu_count()-1,
thread_count=os.cpu_count() - 1,
output_folder=path,
paths_only=True
paths_only=True,
)
for i, image_path in enumerate(images):
image = Image.open(image_path)
Expand Down Expand Up @@ -613,9 +636,11 @@ def search(
page_num=int(doc_info["page_id"]),
score=float(scores[0][embed_id]),
metadata=self.doc_id_to_metadata.get(int(doc_info["doc_id"]), {}),
base64=self.collection.get(int(embed_id))
if return_base64_results
else None,
base64=(
self.collection.get(int(embed_id))
if return_base64_results
else None
),
)
query_results.append(result)

Expand Down Expand Up @@ -655,9 +680,7 @@ def encode_image(
# Process PDF
with tempfile.TemporaryDirectory() as path:
pdf_images = convert_from_path(
item,
thread_count=os.cpu_count()-1,
output_folder=path
item, thread_count=os.cpu_count() - 1, output_folder=path
)
images.extend(pdf_images)
elif item.lower().endswith(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ maintainers = [
]

dependencies = [
"colpali-engine>=0.3.0,<0.4.0",
"colpali-engine>=0.3.1,<0.4.0",
"ml-dtypes",
"mteb==1.6.35",
"ninja",
Expand Down
23 changes: 23 additions & 0 deletions tests/test_colqwen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Generator

import pytest
from colpali_engine.models import ColQwen2
from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch

from byaldi import RAGMultiModalModel
from byaldi.colpali import ColPaliModel


@pytest.fixture(scope="module")
def colqwen_rag_model() -> Generator[RAGMultiModalModel, None, None]:
device = get_torch_device("auto")
print(f"Using device: {device}")
yield RAGMultiModalModel.from_pretrained("vidore/colqwen2-v0.1", device=device)
tear_down_torch()


@pytest.mark.slow
def test_load_colqwen_from_pretrained(colqwen_rag_model: RAGMultiModalModel):
assert isinstance(colqwen_rag_model, RAGMultiModalModel)
assert isinstance(colqwen_rag_model.model, ColPaliModel)
assert isinstance(colqwen_rag_model.model.model, ColQwen2)