Skip to content

Commit 6b2c484

Browse files
authored
Fix shapes datashader colorbar exceeding data range (#561)
1 parent f2bff29 commit 6b2c484

File tree

6 files changed

+102
-48
lines changed

6 files changed

+102
-48
lines changed

src/spatialdata_plot/_logging.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,40 @@ def logger_warns(
113113
if not any(pattern.search(r.getMessage()) for r in records):
114114
msgs = [r.getMessage() for r in records]
115115
raise AssertionError(f"Did not find log matching {match!r} in records: {msgs!r}")
116+
117+
118+
@contextmanager
119+
def logger_no_warns(
120+
caplog: LogCaptureFixture,
121+
logger: logging.Logger,
122+
match: str | None = None,
123+
level: int = logging.WARNING,
124+
) -> Iterator[None]:
125+
"""Assert that no log record matching *match* is emitted.
126+
127+
Counterpart to :func:`logger_warns`.
128+
"""
129+
initial_record_count = len(caplog.records)
130+
131+
handler = caplog.handler
132+
logger.addHandler(handler)
133+
original_level = logger.level
134+
logger.setLevel(level)
135+
136+
with caplog.at_level(level, logger=logger.name):
137+
try:
138+
yield
139+
finally:
140+
logger.removeHandler(handler)
141+
logger.setLevel(original_level)
142+
143+
records = [r for r in caplog.records[initial_record_count:] if r.levelno >= level]
144+
145+
if match is not None:
146+
pattern = re.compile(match)
147+
matching = [r.getMessage() for r in records if pattern.search(r.getMessage())]
148+
if matching:
149+
raise AssertionError(f"Found unexpected log matching {match!r}: {matching!r}")
150+
elif records:
151+
msgs = [r.getMessage() for r in records]
152+
raise AssertionError(f"Expected no log records at level>={level}, but got: {msgs!r}")

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ def render_shapes(
272272
273273
datashader_reduction : Literal[
274274
"sum", "mean", "any", "count", "std", "var", "max", "min"
275-
], default: "sum"
276-
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
275+
], default: "max"
276+
Reduction method for datashader when coloring by continuous values. Defaults to 'max'.
277277
278278
279279
Notes

src/spatialdata_plot/pl/render.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_ds_aggregate,
3535
_ds_shade_categorical,
3636
_ds_shade_continuous,
37+
_DsReduction,
3738
_render_ds_image,
3839
_render_ds_outlines,
3940
)
@@ -82,17 +83,26 @@ def _want_decorations(color_vector: Any, na_color: Color) -> bool:
8283
cv = np.asarray(color_vector)
8384
if cv.size == 0:
8485
return False
85-
# Fast check: if any value differs from the first, there is variety → show decorations.
8686
first = cv.flat[0]
8787
if not (cv == first).all():
8888
return True
89-
# All values are the same — suppress decorations when that value is the NA color.
9089
na_hex = na_color.get_hex()
9190
if isinstance(first, str) and first.startswith("#") and na_hex.startswith("#"):
9291
return _hex_no_alpha(first) != _hex_no_alpha(na_hex)
9392
return bool(first != na_hex)
9493

9594

95+
def _log_datashader_method(method: str, ds_reduction: _DsReduction | None, default: _DsReduction) -> None:
96+
"""Log the datashader backend and effective reduction being used."""
97+
effective = ds_reduction if ds_reduction is not None else default
98+
logger.info(
99+
f"Using '{method}' backend with '{effective}' as reduction"
100+
" method to speed up plotting. Depending on the reduction method, the value"
101+
" range of the plot might change. Set method to 'matplotlib' to disable"
102+
" this behaviour."
103+
)
104+
105+
96106
def _reparse_points(
97107
sdata_filt: sd.SpatialData,
98108
element: str,
@@ -437,14 +447,10 @@ def _render_shapes(
437447
if method is None:
438448
method = "datashader" if len(shapes) > 10000 else "matplotlib"
439449

450+
_default_reduction: _DsReduction = "max"
451+
440452
if method != "matplotlib":
441-
# we only notify the user when we switched away from matplotlib
442-
logger.info(
443-
f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
444-
" method to speed up plotting. Depending on the reduction method, the value"
445-
" range of the plot might change. Set method to 'matplotlib' to disable"
446-
" this behaviour."
447-
)
453+
_log_datashader_method(method, render_params.ds_reduction, _default_reduction)
448454

449455
if method == "datashader":
450456
_geometry = shapes["geometry"]
@@ -518,7 +524,7 @@ def _render_shapes(
518524
col_for_color,
519525
color_by_categorical,
520526
render_params.ds_reduction,
521-
"mean",
527+
_default_reduction,
522528
"shapes",
523529
)
524530

@@ -796,8 +802,7 @@ def _render_points(
796802
# from the registered points (see above) avoids duplicate-origin ambiguities.
797803
color_table_name = table_name
798804

799-
# When color was already loaded from a table (line 690), pass it directly
800-
# to avoid a redundant get_values() call inside _set_color_source_vec.
805+
# Reuse color data already loaded from the table to avoid a redundant get_values() call.
801806
_preloaded = points_pd_with_color[col_for_color] if added_color_from_table and col_for_color is not None else None
802807

803808
color_source_vector, color_vector, _ = _set_color_source_vec(
@@ -852,14 +857,10 @@ def _render_points(
852857
if method is None:
853858
method = "datashader" if n_points > 10000 else "matplotlib"
854859

860+
_default_reduction: _DsReduction = "sum"
861+
855862
if method == "datashader":
856-
# we only notify the user when we switched away from matplotlib
857-
logger.info(
858-
f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
859-
" method to speed up plotting. Depending on the reduction method, the value"
860-
" range of the plot might change. Set method to 'matplotlib' do disable"
861-
" this behaviour."
862-
)
863+
_log_datashader_method(method, render_params.ds_reduction, _default_reduction)
863864

864865
# NOTE: s in matplotlib is in units of points**2
865866
# use dpi/100 as a factor for cases where dpi!=100
@@ -918,7 +919,7 @@ def _render_points(
918919
col_for_color,
919920
color_by_categorical,
920921
render_params.ds_reduction,
921-
"sum",
922+
_default_reduction,
922923
"points",
923924
)
924925

src/spatialdata_plot/pl/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2507,9 +2507,6 @@ def _ensure_table_and_layer_exist_in_sdata(
25072507
if ds_reduction and (ds_reduction not in valid_ds_reduction_methods):
25082508
raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.")
25092509

2510-
if method == "datashader" and ds_reduction is None:
2511-
param_dict["ds_reduction"] = "sum"
2512-
25132510
return param_dict
25142511

25152512

tests/pl/test_render_points.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
import math
32

43
import dask.dataframe
@@ -24,7 +23,7 @@
2423
from spatialdata.transformations._utils import _set_transformations
2524

2625
import spatialdata_plot # noqa: F401
27-
from spatialdata_plot._logging import logger, logger_warns
26+
from spatialdata_plot._logging import logger, logger_no_warns, logger_warns
2827
from spatialdata_plot.pl._datashader import (
2928
_build_datashader_color_key,
3029
_ds_aggregate,
@@ -832,13 +831,8 @@ def test_ds_reduction_ignored_for_categorical(caplog):
832831
def test_ds_reduction_no_warning_when_none(caplog):
833832
"""No spurious warning when ds_reduction is None (the default)."""
834833
cvs, df = _make_ds_canvas_and_df()
835-
with caplog.at_level(logging.WARNING, logger=logger.name):
836-
logger.addHandler(caplog.handler)
837-
try:
838-
_ds_aggregate(cvs, df.copy(), "cat", True, None, "sum", "points")
839-
finally:
840-
logger.removeHandler(caplog.handler)
841-
assert not any("ignored" in r.message.lower() for r in caplog.records)
834+
with logger_no_warns(caplog, logger, match="ignored"):
835+
_ds_aggregate(cvs, df.copy(), "cat", True, None, "sum", "points")
842836

843837

844838
@pytest.mark.parametrize("reduction", ["mean", "max", "min", "count", "std", "var"])
@@ -866,13 +860,8 @@ def test_warn_groups_ignored_continuous_emits(caplog):
866860

867861
def test_warn_groups_ignored_continuous_silent_for_categorical(caplog):
868862
"""No warning when color_source_vector is present (categorical)."""
869-
with caplog.at_level(logging.WARNING, logger=logger.name):
870-
logger.addHandler(caplog.handler)
871-
try:
872-
_warn_groups_ignored_continuous(["A"], pd.Categorical(["A", "B"]), "cat_col")
873-
finally:
874-
logger.removeHandler(caplog.handler)
875-
assert not any("ignored" in r.message for r in caplog.records)
863+
with logger_no_warns(caplog, logger, match="ignored"):
864+
_warn_groups_ignored_continuous(["A"], pd.Categorical(["A", "B"]), "cat_col")
876865

877866

878867
def test_color_key_warns_on_short_color_vector(caplog):
@@ -893,13 +882,8 @@ def test_color_key_warns_on_long_color_vector(caplog):
893882
def test_color_key_no_warning_when_lengths_match(caplog):
894883
"""No warning when lengths match."""
895884
cat = pd.Categorical(["A", "B", "C"])
896-
with caplog.at_level(logging.WARNING, logger=logger.name):
897-
logger.addHandler(caplog.handler)
898-
try:
899-
_build_datashader_color_key(cat, ["#ff0000", "#00ff00", "#0000ff"], "#cccccc")
900-
finally:
901-
logger.removeHandler(caplog.handler)
902-
assert not any("color_vector length" in r.message for r in caplog.records)
885+
with logger_no_warns(caplog, logger, match="color_vector length"):
886+
_build_datashader_color_key(cat, ["#ff0000", "#00ff00", "#0000ff"], "#cccccc")
903887

904888

905889
def test_color_key_unseen_category_gets_na_color(caplog):

tests/pl/test_render_shapes.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,3 +1175,38 @@ def test_render_shapes_color_with_conflicting_index_name():
11751175

11761176
# Should not raise ValueError: cannot insert EntityID, already exists
11771177
sdata.pl.render_shapes("shapes", color="cell_type", table_name="table").pl.show()
1178+
1179+
1180+
def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
1181+
"""Datashader colorbar range must not exceed the actual data range for shapes.
1182+
1183+
Regression test for https://github.com/scverse/spatialdata-plot/issues/559.
1184+
Before the fix, shapes defaulted to 'sum' aggregation, causing overlapping
1185+
shapes to inflate the colorbar beyond the true data maximum.
1186+
"""
1187+
n = len(sdata_blobs.shapes["blobs_circles"])
1188+
rng = np.random.default_rng(0)
1189+
values = rng.uniform(0, 100, size=n)
1190+
sdata_blobs.shapes["blobs_circles"]["continuous_val"] = values
1191+
data_max = float(values.max())
1192+
data_min = float(values.min())
1193+
1194+
fig, ax = plt.subplots()
1195+
sdata_blobs.pl.render_shapes("blobs_circles", color="continuous_val", method="datashader").pl.show(ax=ax)
1196+
1197+
# Find the colorbar axis — it's a child axes with a ScalarMappable
1198+
cbar_vmax = None
1199+
cbar_vmin = None
1200+
for child in fig.get_children():
1201+
if isinstance(child, matplotlib.axes.Axes) and child is not ax:
1202+
ylim = child.get_ylim()
1203+
if ylim != (0.0, 1.0): # colorbar axes have non-default limits
1204+
cbar_vmin, cbar_vmax = ylim
1205+
1206+
assert cbar_vmax is not None, "Could not find colorbar in figure"
1207+
assert cbar_vmax <= data_max * 1.01, (
1208+
f"Colorbar max ({cbar_vmax:.2f}) exceeds data max ({data_max:.2f}); "
1209+
"datashader aggregation is likely using 'sum' instead of 'max'"
1210+
)
1211+
assert cbar_vmin >= data_min * 0.99 - 0.01, f"Colorbar min ({cbar_vmin:.2f}) is below data min ({data_min:.2f})"
1212+
plt.close(fig)

0 commit comments

Comments (0)