Skip to content

Commit b82e7f5

Browse files
Sonja-StockhausSonja Stockhaus
andauthored
improved cmap handling for images (#194)
* improved cmap handling for images * update/add tests * update tests * changelog --------- Co-authored-by: Sonja Stockhaus <stockhaus@cip.ifi.lmu.de>
1 parent 0e7360c commit b82e7f5

File tree

10 files changed

+43
-8
lines changed

10 files changed

+43
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning][].
2020
### Fixed
2121

2222
- Now dropping index when plotting shapes after spatial query (#177)
23+
- User can now pass Colormap objects to the cmap argument in render_images. When only one cmap is given for 3 channels, it is now applied to each channel (#188, #194)
2324

2425
## [0.0.6] - 2023-11-06
2526

src/spatialdata_plot/pl/basic.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,6 @@ def render_images(
362362
sdata = _verify_plotting_tree(sdata)
363363
n_steps = len(sdata.plotting_tree.keys())
364364

365-
if channel is None and cmap is None:
366-
cmap = "brg"
367-
368365
cmap_params: list[CmapParams] | CmapParams
369366
if isinstance(cmap, list):
370367
cmap_params = [

src/spatialdata_plot/pl/render.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,10 +437,29 @@ def _render_images(
437437
if render_params.cmap_params[i].norm is not None:
438438
layers[c] = render_params.cmap_params[i].norm(layers[c])
439439

440-
# 2A) Image has 3 channels, no palette/cmap info -> use RGB
441-
if n_channels == 3 and render_params.palette is None and not got_multiple_cmaps:
440+
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
441+
if n_channels == 3 and render_params.palette is None and not isinstance(render_params.cmap_params, list):
442+
if render_params.cmap_params.is_default: # -> use RGB
443+
stacked = np.stack([layers[c] for c in channels], axis=-1)
444+
else: # -> use given cmap for each channel
445+
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
446+
# Apply cmaps to each channel, add up and normalize to [0, 1]
447+
stacked = (
448+
np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) / n_channels
449+
)
450+
# Remove alpha channel so we can overwrite it from render_params.alpha
451+
stacked = stacked[:, :, :3]
452+
logger.warning(
453+
"One cmap was given for multiple channels and is now used for each channel. "
454+
"You're blending multiple cmaps. "
455+
"If the plot doesn't look like you expect, it might be because your "
456+
"cmaps go from a given color to 'white', and not to 'transparent'. "
457+
"Therefore, the 'white' of higher layers will overlay the lower layers. "
458+
"Consider using 'palette' instead."
459+
)
460+
442461
im = ax.imshow(
443-
np.stack([layers[c] for c in channels], axis=-1),
462+
stacked,
444463
alpha=render_params.alpha,
445464
)
446465
im.set_transform(trans_data)

src/spatialdata_plot/pl/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,13 @@ def _prepare_cmap_norm(
344344
**kwargs: Any,
345345
) -> CmapParams:
346346
is_default = cmap is None
347-
cmap = copy(matplotlib.colormaps[rcParams["image.cmap"] if cmap is None else cmap])
347+
if cmap is None:
348+
cmap = rcParams["image.cmap"]
349+
if isinstance(cmap, str):
350+
cmap = matplotlib.colormaps[cmap]
351+
352+
cmap = copy(cmap)
353+
348354
cmap.set_bad("lightgray" if na_color is None else na_color)
349355

350356
if isinstance(norm, Normalize) or not norm:
32 KB
Loading
31.8 KB
Loading
-37.6 KB
Binary file not shown.
32 KB
Loading
31.8 KB
Loading

tests/pl/test_render_images.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,21 @@ class TestImages(PlotTester, metaclass=PlotTesterMeta):
2424
def test_plot_can_render_image(self, sdata_blobs: SpatialData):
2525
sdata_blobs.pl.render_images(elements="blobs_image").pl.show()
2626

27-
def test_plot_can_pass_cmap_to_render_images(self, sdata_blobs: SpatialData):
27+
def test_plot_can_pass_str_cmap(self, sdata_blobs: SpatialData):
2828
sdata_blobs.pl.render_images(elements="blobs_image", cmap="seismic").pl.show()
2929

30+
def test_plot_can_pass_cmap(self, sdata_blobs: SpatialData):
31+
sdata_blobs.pl.render_images(elements="blobs_image", cmap=matplotlib.colormaps["seismic"]).pl.show()
32+
33+
def test_plot_can_pass_str_cmap_list(self, sdata_blobs: SpatialData):
34+
sdata_blobs.pl.render_images(elements="blobs_image", cmap=["seismic", "Reds", "Blues"]).pl.show()
35+
36+
def test_plot_can_pass_cmap_list(self, sdata_blobs: SpatialData):
37+
sdata_blobs.pl.render_images(
38+
elements="blobs_image",
39+
cmap=[matplotlib.colormaps["seismic"], matplotlib.colormaps["Reds"], matplotlib.colormaps["Blues"]],
40+
).pl.show()
41+
3042
def test_plot_can_render_a_single_channel_from_image(self, sdata_blobs: SpatialData):
3143
sdata_blobs.pl.render_images(elements="blobs_image", channel=0).pl.show()
3244

0 commit comments

Comments
 (0)