Skip to content

Commit 691dbff

Browse files
feat: add beartype runtime type validation for numpy arrays and numeric inputs
Co-authored-by: aider (openrouter/anthropic/claude-sonnet-4) <aider@aider.chat>
1 parent f049846 commit 691dbff

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

wdoc/utils/customs/binary_faiss_vectorstore.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
Tuple,
2525
Union,
2626
TypeAlias,
27+
Annotated,
2728
)
29+
from beartype.vale import IsAttr, IsEqual, Is
2830

2931
import numpy as np
3032
import numpy.typing as npt
@@ -47,8 +49,21 @@
4749

4850
NDArray = npt.NDArray # required for beartype
4951
ArrayLike = npt.ArrayLike # required for beartype
50-
UInt8Array: TypeAlias = NDArray[np.uint8] # 2D
51-
UInt8Vector: TypeAlias = NDArray[np.uint8] # 1D
52+
UInt8Array: TypeAlias = Annotated[
53+
NDArray[np.uint8],
54+
IsAttr["ndim", IsEqual[2]] & IsAttr["dtype", IsEqual[np.dtype(np.uint8)]],
55+
] # 2D binary array with uint8 dtype
56+
UInt8Vector: TypeAlias = Annotated[
57+
NDArray[np.uint8],
58+
IsAttr["ndim", IsEqual[1]] & IsAttr["dtype", IsEqual[np.dtype(np.uint8)]],
59+
] # 1D binary vector with uint8 dtype
60+
61+
# Type alias for numeric input arrays that will be converted to binary
62+
NumericArrayLike: TypeAlias = Annotated[
63+
Union[ArrayLike, List[float], List[List[float]]],
64+
Is[lambda x: len(np.array(x).shape) <= 2] # Max 2D
65+
& Is[lambda x: np.issubdtype(np.array(x).dtype, np.number)], # Numeric values only
66+
]
5267

5368

5469
class CompressedFAISS(FAISS):
@@ -317,7 +332,7 @@ async def new_aembedding_function(self, texts: List[str]) -> UInt8Array:
317332

318333
@staticmethod
319334
def _vec_to_binary(
320-
vectors: Union[ArrayLike, List[float], List[List[float]]],
335+
vectors: NumericArrayLike,
321336
) -> UInt8Array:
322337
"""Convert vectors to binary format using global zero threshold.
323338

0 commit comments

Comments
 (0)