Skip to content

Commit 1395fb5

Browse files
Fixes bug with extent of rotated data (#373)
1 parent b229099 commit 1395fb5

File tree

3 files changed

+276
-112
lines changed

3 files changed

+276
-112
lines changed

src/spatialdata/_core/data_extent.py

Lines changed: 112 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from collections import defaultdict
44
from functools import singledispatch
5-
from typing import Union
65

76
import numpy as np
87
import pandas as pd
@@ -15,7 +14,6 @@
1514

1615
from spatialdata._core.operations.transform import transform
1716
from spatialdata._core.spatialdata import SpatialData
18-
from spatialdata._types import ArrayLike
1917
from spatialdata.models import get_axes_names
2018
from spatialdata.models._utils import SpatialElement
2119
from spatialdata.models.models import PointsModel
@@ -86,32 +84,45 @@ def _get_extent_of_polygons_multipolygons(
8684
return extent
8785

8886

87+
def _get_extent_of_points(e: DaskDataFrame) -> BoundingBoxDescription:
88+
axes = get_axes_names(e)
89+
min_coordinates = np.array([e[ax].min().compute() for ax in axes])
90+
max_coordinates = np.array([e[ax].max().compute() for ax in axes])
91+
extent = {}
92+
for i, ax in enumerate(axes):
93+
extent[ax] = (min_coordinates[i], max_coordinates[i])
94+
return extent
95+
96+
8997
def _get_extent_of_data_array(e: DataArray, coordinate_system: str) -> BoundingBoxDescription:
9098
# lightweight conversion to SpatialImage just to fix the type of the single-dispatch
9199
_check_element_has_coordinate_system(element=SpatialImage(e), coordinate_system=coordinate_system)
92100
# also here
93101
data_axes = get_axes_names(SpatialImage(e))
94-
min_coordinates = []
95-
max_coordinates = []
96-
axes = []
102+
extent: BoundingBoxDescription = {}
97103
for ax in ["z", "y", "x"]:
98104
if ax in data_axes:
99105
i = data_axes.index(ax)
100-
axes.append(ax)
101-
min_coordinates.append(0)
102-
max_coordinates.append(e.shape[i])
106+
extent[ax] = (0, e.shape[i])
103107
return _compute_extent_in_coordinate_system(
104108
# and here
105109
element=SpatialImage(e),
106110
coordinate_system=coordinate_system,
107-
min_coordinates=np.array(min_coordinates),
108-
max_coordinates=np.array(max_coordinates),
109-
axes=tuple(axes),
111+
extent=extent,
110112
)
111113

112114

113115
@singledispatch
114-
def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global") -> BoundingBoxDescription:
116+
def get_extent(
117+
e: SpatialData | SpatialElement,
118+
coordinate_system: str = "global",
119+
exact: bool = True,
120+
has_images: bool = True,
121+
has_labels: bool = True,
122+
has_points: bool = True,
123+
has_shapes: bool = True,
124+
elements: list[str] | None = None,
125+
) -> BoundingBoxDescription:
115126
"""
116127
Get the extent (bounding box) of a SpatialData object or a SpatialElement.
117128
@@ -128,6 +139,37 @@ def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global
128139
The maximum coordinate of the bounding box.
129140
axes
130141
The names of the dimensions of the bounding box
142+
exact
143+
If True, the extent is computed exactly. If False, an approximation faster to compute is given. The
144+
approximation is guaranteed to contain all the data, see notes for details.
145+
has_images
146+
If True, images are included in the computation of the extent.
147+
has_labels
148+
If True, labels are included in the computation of the extent.
149+
has_points
150+
If True, points are included in the computation of the extent.
151+
has_shapes
152+
If True, shapes are included in the computation of the extent.
153+
elements
154+
If not None, only the elements with the given names are included in the computation of the extent.
155+
156+
Notes
157+
-----
158+
The extent of a SpatialData object is the extent of the union of the extents of all its elements. The extent of a
159+
SpatialElement is the extent of the element in the coordinate system specified by the argument `coordinate_system`.
160+
161+
If `exact` is False, first the extent of the SpatialElement before any transformation is computed. Then, the extent
162+
is transformed to the target coordinate system. This is faster than computing the extent after the transformation,
163+
since the transformation is applied to extent of the untransformed data, as opposed to transforming the data and
164+
then computing the extent.
165+
166+
The exact and approximate extent are the same if the transformation doesn't contain any rotation or shear, or in the
167+
case in which the transformation is affine but all the corners of the extent of the untransformed data
168+
(bounding box corners) are part of the dataset itself. Note that this is always the case for raster data.
169+
170+
An extreme case is a dataset composed of the two points (0, 0) and (1, 1), rotated anticlockwise by 45 degrees. The
171+
exact extent is the bounding box [minx, miny, maxx, maxy] = [0, 0, 0, 1.414], while the approximate extent is the
172+
box [minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414].
131173
"""
132174
raise ValueError("The object type is not supported.")
133175

@@ -136,11 +178,12 @@ def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global
136178
def _(
137179
e: SpatialData,
138180
coordinate_system: str = "global",
181+
exact: bool = True,
139182
has_images: bool = True,
140183
has_labels: bool = True,
141184
has_points: bool = True,
142185
has_shapes: bool = True,
143-
elements: Union[list[str], None] = None,
186+
elements: list[str] | None = None,
144187
) -> BoundingBoxDescription:
145188
"""
146189
Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements.
@@ -174,7 +217,10 @@ def _(
174217
assert isinstance(transformations, dict)
175218
coordinate_systems = list(transformations.keys())
176219
if coordinate_system in coordinate_systems:
177-
extent = get_extent(element_obj, coordinate_system=coordinate_system)
220+
if isinstance(element_obj, (DaskDataFrame, GeoDataFrame)):
221+
extent = get_extent(element_obj, coordinate_system=coordinate_system, exact=exact)
222+
else:
223+
extent = get_extent(element_obj, coordinate_system=coordinate_system)
178224
axes = list(extent.keys())
179225
for ax in axes:
180226
new_min_coordinates_dict[ax] += [extent[ax][0]]
@@ -183,8 +229,14 @@ def _(
183229
raise ValueError(
184230
f"The SpatialData object does not contain any element in the "
185231
f" coordinate system {coordinate_system!r}, "
186-
f"please pass a different coordinate system wiht the argument 'coordinate_system'."
232+
f"please pass a different coordinate system with the argument 'coordinate_system'."
187233
)
234+
if len(new_min_coordinates_dict) == 0:
235+
raise ValueError(
236+
f"The SpatialData object does not contain any element in the coordinate system {coordinate_system!r}, "
237+
"please pass a different coordinate system with the argument 'coordinate_system'."
238+
)
239+
axes = list(new_min_coordinates_dict.keys())
188240
new_min_coordinates = np.array([min(new_min_coordinates_dict[ax]) for ax in axes])
189241
new_max_coordinates = np.array([max(new_max_coordinates_dict[ax]) for ax in axes])
190242
extent = {}
@@ -193,8 +245,21 @@ def _(
193245
return extent
194246

195247

248+
def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription:
249+
# remove potentially empty geometries
250+
e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)]
251+
assert len(e_temp) > 0, "Cannot compute extent of an empty collection of geometries."
252+
253+
# separate points from (multi-)polygons
254+
first_geometry = e_temp["geometry"].iloc[0]
255+
if isinstance(first_geometry, Point):
256+
return _get_extent_of_circles(e)
257+
assert isinstance(first_geometry, (Polygon, MultiPolygon))
258+
return _get_extent_of_polygons_multipolygons(e)
259+
260+
196261
@get_extent.register
197-
def _(e: GeoDataFrame, coordinate_system: str = "global") -> BoundingBoxDescription:
262+
def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription:
198263
"""
199264
Compute the extent (bounding box) of a set of shapes.
200265
@@ -203,57 +268,33 @@ def _(e: GeoDataFrame, coordinate_system: str = "global") -> BoundingBoxDescript
203268
The bounding box description.
204269
"""
205270
_check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system)
206-
# remove potentially empty geometries
207-
e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)]
208-
209-
# separate points from (multi-)polygons
210-
e_points = e_temp[e_temp["geometry"].apply(lambda geom: isinstance(geom, Point))]
211-
e_polygons = e_temp[e_temp["geometry"].apply(lambda geom: isinstance(geom, (Polygon, MultiPolygon)))]
212-
extent = None
213-
if len(e_points) > 0:
214-
assert "radius" in e_points.columns, "Shapes that are points must have a 'radius' column."
215-
extent = _get_extent_of_circles(e_points)
216-
if len(e_polygons) > 0:
217-
extent_polygons = _get_extent_of_polygons_multipolygons(e_polygons)
218-
if extent is None:
219-
extent = extent_polygons
220-
else:
221-
# case when there are points AND (multi-)polygons in the GeoDataFrame
222-
extent["y"] = (min(extent["y"][0], extent_polygons["y"][0]), max(extent["y"][1], extent_polygons["y"][1]))
223-
extent["x"] = (min(extent["x"][0], extent_polygons["x"][0]), max(extent["x"][1], extent_polygons["x"][1]))
224-
225-
if extent is None:
226-
raise ValueError(
227-
"Unable to compute extent of GeoDataFrame. It needs to contain at least one non-empty "
228-
"Point or Polygon or Multipolygon."
271+
if not exact:
272+
extent = _get_extent_of_shapes(e)
273+
return _compute_extent_in_coordinate_system(
274+
element=e,
275+
coordinate_system=coordinate_system,
276+
extent=extent,
229277
)
230-
231-
min_coordinates = [extent["y"][0], extent["x"][0]]
232-
max_coordinates = [extent["y"][1], extent["x"][1]]
233-
axes = tuple(extent.keys())
234-
235-
return _compute_extent_in_coordinate_system(
236-
element=e_temp,
237-
coordinate_system=coordinate_system,
238-
min_coordinates=np.array(min_coordinates),
239-
max_coordinates=np.array(max_coordinates),
240-
axes=axes,
241-
)
278+
t = get_transformation(e, to_coordinate_system=coordinate_system)
279+
assert isinstance(t, BaseTransformation)
280+
transformed = transform(e, t)
281+
return _get_extent_of_shapes(transformed)
242282

243283

244284
@get_extent.register
245-
def _(e: DaskDataFrame, coordinate_system: str = "global") -> BoundingBoxDescription:
285+
def _(e: DaskDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription:
246286
_check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system)
247-
axes = get_axes_names(e)
248-
min_coordinates = np.array([e[ax].min().compute() for ax in axes])
249-
max_coordinates = np.array([e[ax].max().compute() for ax in axes])
250-
return _compute_extent_in_coordinate_system(
251-
element=e,
252-
coordinate_system=coordinate_system,
253-
min_coordinates=min_coordinates,
254-
max_coordinates=max_coordinates,
255-
axes=axes,
256-
)
287+
if not exact:
288+
extent = _get_extent_of_points(e)
289+
return _compute_extent_in_coordinate_system(
290+
element=e,
291+
coordinate_system=coordinate_system,
292+
extent=extent,
293+
)
294+
t = get_transformation(e, to_coordinate_system=coordinate_system)
295+
assert isinstance(t, BaseTransformation)
296+
transformed = transform(e, t)
297+
return _get_extent_of_points(transformed)
257298

258299

259300
@get_extent.register
@@ -275,16 +316,12 @@ def _check_element_has_coordinate_system(element: SpatialElement, coordinate_sys
275316
if coordinate_system not in coordinate_systems:
276317
raise ValueError(
277318
f"The element does not contain any coordinate system named {coordinate_system!r}, "
278-
f"please pass a different coordinate system wiht the argument 'coordinate_system'."
319+
f"please pass a different coordinate system with the argument 'coordinate_system'."
279320
)
280321

281322

282323
def _compute_extent_in_coordinate_system(
283-
element: SpatialElement | DataArray,
284-
coordinate_system: str,
285-
min_coordinates: ArrayLike,
286-
max_coordinates: ArrayLike,
287-
axes: tuple[str, ...],
324+
element: SpatialElement | DataArray, coordinate_system: str, extent: BoundingBoxDescription
288325
) -> BoundingBoxDescription:
289326
"""
290327
Transform the extent from the intrinsic coordinates of the element to the given coordinate system.
@@ -295,12 +332,8 @@ def _compute_extent_in_coordinate_system(
295332
The SpatialElement.
296333
coordinate_system
297334
The coordinate system to transform the extent to.
298-
min_coordinates
299-
Min coordinates of the extent in the intrinsic coordinates of the element, expects [y_min, x_min].
300-
max_coordinates
301-
Max coordinates of the extent in the intrinsic coordinates of the element, expects [y_max, x_max].
302-
axes
303-
The min and max coordinates refer to.
335+
extent
336+
The extent in the intrinsic coordinates of the element.
304337
305338
Returns
306339
-------
@@ -310,6 +343,11 @@ def _compute_extent_in_coordinate_system(
310343
assert isinstance(transformation, BaseTransformation)
311344
from spatialdata._core.query._utils import get_bounding_box_corners
312345

346+
axes = get_axes_names(element)
347+
if "c" in axes:
348+
axes = tuple(ax for ax in axes if ax != "c")
349+
min_coordinates = np.array([extent[ax][0] for ax in axes])
350+
max_coordinates = np.array([extent[ax][1] for ax in axes])
313351
corners = get_bounding_box_corners(
314352
axes=axes,
315353
min_coordinate=min_coordinates,

src/spatialdata/models/_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import dask.dataframe as dd
77
import geopandas
8-
from anndata import AnnData
98
from dask.dataframe import DataFrame as DaskDataFrame
109
from geopandas import GeoDataFrame
1110
from multiscale_spatial_image import MultiscaleSpatialImage
@@ -167,15 +166,15 @@ def _(e: MultiscaleSpatialImage) -> tuple[str, ...]:
167166

168167
@get_axes_names.register(GeoDataFrame)
169168
def _(e: GeoDataFrame) -> tuple[str, ...]:
170-
all_dims = (Z, Y, X)
169+
all_dims = (X, Y, Z)
171170
n = e.geometry.iloc[0]._ndim
172-
dims = all_dims[-n:]
171+
dims = all_dims[:n]
173172
_validate_dims(dims)
174173
return dims
175174

176175

177176
@get_axes_names.register(DaskDataFrame)
178-
def _(e: AnnData) -> tuple[str, ...]:
177+
def _(e: DaskDataFrame) -> tuple[str, ...]:
179178
valid_dims = (X, Y, Z)
180179
dims = tuple([c for c in valid_dims if c in e.columns])
181180
_validate_dims(dims)

0 commit comments

Comments
 (0)