diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index c8d68089..150d3a49 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -51,6 +51,7 @@ _FontWeight, ) from spatialdata_plot.pl.utils import ( + _RENDER_CMD_TO_CS_FLAG, _get_cs_contents, _get_elements_to_be_rendered, _get_valid_cs, @@ -993,9 +994,11 @@ def show( ax_x_min, ax_x_max = ax.get_xlim() ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left - coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems + cs_was_auto = coordinate_systems is None + coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems if isinstance(coordinate_systems, str): coordinate_systems = [coordinate_systems] + assert coordinate_systems is not None for cs in coordinate_systems: if cs not in sdata.coordinate_systems: @@ -1019,14 +1022,32 @@ def show( elements=elements_to_be_rendered, ) - # catch error in ruff-friendly way - if ax is not None: # we'll generate matching number then + # When CS was auto-detected and ax is provided, keep only CS that have + # element types for ALL render commands (workaround for upstream #176). + if ax is not None: n_ax = 1 if isinstance(ax, Axes) else len(ax) + if cs_was_auto and len(coordinate_systems) > n_ax: + required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG] + strict_cs = [ + cs_name + for cs_name in coordinate_systems + if all(cs_contents.query(f"cs == '{cs_name}'").iloc[0][flag] for flag in required_flags) + ] + if strict_cs: + coordinate_systems = strict_cs + if len(coordinate_systems) != n_ax: - raise ValueError( + msg = ( f"Mismatch between number of matplotlib axes objects ({n_ax}) " f"and number of coordinate systems ({len(coordinate_systems)})." ) + if cs_was_auto: + msg += ( + " This can happen when elements have transformations to multiple " + "coordinate systems (e.g. after filter_by_coordinate_system). " + "Pass `coordinate_systems=` explicitly to select which ones to plot." + ) + raise ValueError(msg) # set up canvas fig_params, scalebar_params = _prepare_params_plot( diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 6068533f..b77144d9 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -95,6 +95,13 @@ _GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name." +_RENDER_CMD_TO_CS_FLAG: dict[str, str] = { + "render_images": "has_images", + "render_shapes": "has_shapes", + "render_points": "has_points", + "render_labels": "has_labels", +} + def _gate_palette_and_groups( element_params: dict[str, Any], @@ -264,6 +271,7 @@ def _prepare_params_plot( if ax is not None and len(ax) != num_panels: raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.") if fig is None: + # TODO(#579): infer fig from ax[0].get_figure() instead of requiring it raise ValueError( f"Invalid value of `fig`: {fig}. If a list of `Axes` is passed, a `Figure` must also be specified." ) @@ -2080,17 +2088,11 @@ def _get_elements_to_be_rendered( List of names of the SpatialElements to be rendered in the plot. """ elements_to_be_rendered: list[str] = [] - render_cmds_map = { - "render_images": "has_images", - "render_shapes": "has_shapes", - "render_points": "has_points", - "render_labels": "has_labels", - } cs_query = cs_contents.query(f"cs == '{cs}'") for cmd, params in render_cmds: - key = render_cmds_map.get(cmd) + key = _RENDER_CMD_TO_CS_FLAG.get(cmd) if key and cs_query[key][0]: elements_to_be_rendered += [params.element] diff --git a/tests/conftest.py b/tests/conftest.py index ac50b959..b70dc567 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -631,6 +631,31 @@ def _get_sdata_with_multiple_images(share_coordinate_system: str = "all"): return _get_sdata_with_multiple_images +@pytest.fixture +def sdata_multi_cs(): + """SpatialData with an image in one CS and shapes in two CS. + + Useful for testing behaviour when elements have transformations to + different sets of coordinate systems (e.g. after + ``filter_by_coordinate_system``). + """ + from shapely.geometry import Point + + image = Image2DModel.parse( + np.zeros((1, 10, 10)), + dims=("c", "y", "x"), + transformations={"aligned": sd.transformations.Identity()}, + ) + shapes = ShapesModel.parse( + GeoDataFrame(geometry=[Point(5, 5)], data={"radius": [2]}), + transformations={ + "aligned": sd.transformations.Identity(), + "global": sd.transformations.Identity(), + }, + ) + return SpatialData(images={"img": image}, shapes={"shp": shapes}) + + @pytest.fixture def sdata_hexagonal_grid_spots(): """Create a hexagonal grid of points for testing visium_hex functionality.""" diff --git a/tests/pl/test_render.py b/tests/pl/test_render.py index 4ada7268..83c6ee3c 100644 --- a/tests/pl/test_render.py +++ b/tests/pl/test_render.py @@ -62,3 +62,38 @@ def test_keyerror_when_shape_element_does_not_exist(request): with pytest.raises(KeyError): sdata.pl.render_shapes(element="not_found").pl.show() + + +# Regression tests for #176: plotting with user-supplied ax when elements +# have transformations to multiple coordinate systems. + + +def test_single_ax_after_filter_by_coordinate_system(sdata_multi_cs): + """After filter_by_coordinate_system, single ax should work without specifying CS.""" + sdata_filt = sdata_multi_cs.filter_by_coordinate_system("aligned") + + _, ax = plt.subplots(1, 1) + sdata_filt.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax) + assert ax.get_title() == "aligned" + + +def test_single_ax_with_explicit_cs(sdata_multi_cs): + """Explicit coordinate_systems with single ax should work.""" + _, ax = plt.subplots(1, 1) + sdata_multi_cs.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems="aligned") + assert ax.get_title() == "aligned" + + +def test_single_ax_explicit_multi_cs_raises(sdata_multi_cs): + """Explicitly requesting more CS than axes should still raise.""" + _, ax = plt.subplots(1, 1) + with pytest.raises(ValueError, match="Mismatch"): + sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems=["aligned", "global"]) + + +def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs): + """When strict filtering can't resolve the mismatch, error includes hint.""" + _, ax = plt.subplots(1, 1) + with pytest.raises(ValueError, match="coordinate_systems="): + # Only render shapes (present in both CS), so strict filter can't narrow down + sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax)