Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
_FontWeight,
)
from spatialdata_plot.pl.utils import (
_RENDER_CMD_TO_CS_FLAG,
_get_cs_contents,
_get_elements_to_be_rendered,
_get_valid_cs,
Expand Down Expand Up @@ -993,9 +994,11 @@ def show(
ax_x_min, ax_x_max = ax.get_xlim()
ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left

coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
cs_was_auto = coordinate_systems is None
coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems
if isinstance(coordinate_systems, str):
coordinate_systems = [coordinate_systems]
assert coordinate_systems is not None

for cs in coordinate_systems:
if cs not in sdata.coordinate_systems:
Expand All @@ -1019,14 +1022,32 @@ def show(
elements=elements_to_be_rendered,
)

# catch error in ruff-friendly way
if ax is not None: # we'll generate matching number then
# When CS was auto-detected and ax is provided, keep only CS that have
# element types for ALL render commands (workaround for upstream #176).
if ax is not None:
n_ax = 1 if isinstance(ax, Axes) else len(ax)
if cs_was_auto and len(coordinate_systems) > n_ax:
required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG]
strict_cs = [
cs_name
for cs_name in coordinate_systems
if all(cs_contents.query(f"cs == '{cs_name}'").iloc[0][flag] for flag in required_flags)
]
if strict_cs:
coordinate_systems = strict_cs

if len(coordinate_systems) != n_ax:
raise ValueError(
msg = (
f"Mismatch between number of matplotlib axes objects ({n_ax}) "
f"and number of coordinate systems ({len(coordinate_systems)})."
)
if cs_was_auto:
msg += (
" This can happen when elements have transformations to multiple "
"coordinate systems (e.g. after filter_by_coordinate_system). "
"Pass `coordinate_systems=` explicitly to select which ones to plot."
)
raise ValueError(msg)

# set up canvas
fig_params, scalebar_params = _prepare_params_plot(
Expand Down
16 changes: 9 additions & 7 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@

_GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name."

_RENDER_CMD_TO_CS_FLAG: dict[str, str] = {
"render_images": "has_images",
"render_shapes": "has_shapes",
"render_points": "has_points",
"render_labels": "has_labels",
}


def _gate_palette_and_groups(
element_params: dict[str, Any],
Expand Down Expand Up @@ -264,6 +271,7 @@ def _prepare_params_plot(
if ax is not None and len(ax) != num_panels:
raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.")
if fig is None:
# TODO(#579): infer fig from ax[0].get_figure() instead of requiring it
raise ValueError(
f"Invalid value of `fig`: {fig}. If a list of `Axes` is passed, a `Figure` must also be specified."
)
Expand Down Expand Up @@ -2080,17 +2088,11 @@ def _get_elements_to_be_rendered(
List of names of the SpatialElements to be rendered in the plot.
"""
elements_to_be_rendered: list[str] = []
render_cmds_map = {
"render_images": "has_images",
"render_shapes": "has_shapes",
"render_points": "has_points",
"render_labels": "has_labels",
}

cs_query = cs_contents.query(f"cs == '{cs}'")

for cmd, params in render_cmds:
key = render_cmds_map.get(cmd)
key = _RENDER_CMD_TO_CS_FLAG.get(cmd)
if key and cs_query[key][0]:
elements_to_be_rendered += [params.element]

Expand Down
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,31 @@ def _get_sdata_with_multiple_images(share_coordinate_system: str = "all"):
return _get_sdata_with_multiple_images


@pytest.fixture
def sdata_multi_cs():
"""SpatialData with an image in one CS and shapes in two CS.

Useful for testing behaviour when elements have transformations to
different sets of coordinate systems (e.g. after
``filter_by_coordinate_system``).
"""
from shapely.geometry import Point

image = Image2DModel.parse(
np.zeros((1, 10, 10)),
dims=("c", "y", "x"),
transformations={"aligned": sd.transformations.Identity()},
)
shapes = ShapesModel.parse(
GeoDataFrame(geometry=[Point(5, 5)], data={"radius": [2]}),
transformations={
"aligned": sd.transformations.Identity(),
"global": sd.transformations.Identity(),
},
)
return SpatialData(images={"img": image}, shapes={"shp": shapes})


@pytest.fixture
def sdata_hexagonal_grid_spots():
"""Create a hexagonal grid of points for testing visium_hex functionality."""
Expand Down
35 changes: 35 additions & 0 deletions tests/pl/test_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,38 @@ def test_keyerror_when_shape_element_does_not_exist(request):

with pytest.raises(KeyError):
sdata.pl.render_shapes(element="not_found").pl.show()


# Regression tests for #176: plotting with user-supplied ax when elements
# have transformations to multiple coordinate systems.


def test_single_ax_after_filter_by_coordinate_system(sdata_multi_cs):
"""After filter_by_coordinate_system, single ax should work without specifying CS."""
sdata_filt = sdata_multi_cs.filter_by_coordinate_system("aligned")

_, ax = plt.subplots(1, 1)
sdata_filt.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax)
assert ax.get_title() == "aligned"


def test_single_ax_with_explicit_cs(sdata_multi_cs):
"""Explicit coordinate_systems with single ax should work."""
_, ax = plt.subplots(1, 1)
sdata_multi_cs.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems="aligned")
assert ax.get_title() == "aligned"


def test_single_ax_explicit_multi_cs_raises(sdata_multi_cs):
"""Explicitly requesting more CS than axes should still raise."""
_, ax = plt.subplots(1, 1)
with pytest.raises(ValueError, match="Mismatch"):
sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems=["aligned", "global"])


def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs):
"""When strict filtering can't resolve the mismatch, error includes hint."""
_, ax = plt.subplots(1, 1)
with pytest.raises(ValueError, match="coordinate_systems="):
# Only render shapes (present in both CS), so strict filter can't narrow down
sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax)
Loading