Skip to content

Commit 606a714

Browse files
authored
fix: revert np.where to xarray.where when adding vars/ constraints (#575)
* fix: revert np.where to xarray.where * trigger
1 parent 97ed0c0 commit 606a714

1 file changed

Lines changed: 2 additions & 7 deletions

File tree

linopy/model.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,7 @@ def add_variables(
571571
self._xCounter += data.labels.size
572572

573573
if mask is not None:
574-
# Use numpy where for speed (38x faster than xarray where)
575-
data.labels.values = np.where(mask.values, data.labels.values, -1)
574+
data.labels.values = data.labels.where(mask, -1).values
576575

577576
data = data.assign_attrs(
578577
label_range=(start, end), name=name, binary=binary, integer=integer
@@ -750,9 +749,6 @@ def add_constraints(
750749
assert set(mask.dims).issubset(data.dims), (
751750
"Dimensions of mask not a subset of resulting labels dimensions."
752751
)
753-
# Broadcast mask to match data shape for correct numpy where behavior
754-
if mask.shape != data.labels.shape:
755-
mask, _ = xr.broadcast(mask, data.labels)
756752

757753
# Auto-mask based on null expressions or NaN RHS (use numpy for speed)
758754
if self.auto_mask:
@@ -785,8 +781,7 @@ def add_constraints(
785781
self._cCounter += data.labels.size
786782

787783
if mask is not None:
788-
# Use numpy where for speed (38x faster than xarray where)
789-
data.labels.values = np.where(mask.values, data.labels.values, -1)
784+
data.labels.values = data.labels.where(mask, -1).values
790785

791786
data = data.assign_attrs(label_range=(start, end), name=name)
792787

0 commit comments

Comments
 (0)