-
Notifications
You must be signed in to change notification settings - Fork 20
datashader speedup and bugfixes #309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
50808b3
759ecd8
ea718bf
eef7b8b
702a226
420772a
48c1c52
fc89462
a0dac08
c8b0b34
e1662e8
d3a4c14
f0074c9
6dfaf51
3d890a6
a2b66e1
8b0b24d
982c627
0d124d7
ceb4fd2
59a19da
22cdabc
1d07544
febd424
2a95236
1d871ed
0bf8c35
4391b81
a06400b
4b52fe7
cd0f68b
214cb9d
ddc7927
73568ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| import matplotlib | ||
| import matplotlib.transforms as mtransforms | ||
| import numpy as np | ||
| import numpy.ma as ma | ||
| import pandas as pd | ||
| import scanpy as sc | ||
| import spatialdata as sd | ||
|
|
@@ -20,11 +21,12 @@ | |
| from matplotlib.colors import ListedColormap, Normalize | ||
| from scanpy._settings import settings as sc_settings | ||
| from spatialdata import get_extent | ||
| from spatialdata.models import PointsModel, get_table_keys | ||
| from spatialdata.models import Image2DModel, PointsModel, get_table_keys | ||
| from spatialdata.transformations import ( | ||
| get_transformation, | ||
| set_transformation, | ||
| ) | ||
| from spatialdata.transformations.transformations import Scale | ||
|
|
||
| from spatialdata_plot._logging import logger | ||
| from spatialdata_plot.pl.render_params import ( | ||
|
|
@@ -164,16 +166,29 @@ def _render_shapes( | |
| if method == "datashader": | ||
| trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData | ||
|
|
||
| extent = get_extent(sdata.shapes[element]) | ||
| x_ext = extent["x"][1] | ||
| y_ext = extent["y"][1] | ||
| x_range = [0, x_ext] | ||
| y_range = [0, y_ext] | ||
| # round because we need integers | ||
| plot_width = int(np.round(x_range[1] - x_range[0])) | ||
| plot_height = int(np.round(y_range[1] - y_range[0])) | ||
| extent = get_extent(sdata_filt.shapes[element], coordinate_system=coordinate_system) | ||
| x_ext = [min(0, extent["x"][0]), extent["x"][1]] | ||
| y_ext = [min(0, extent["y"][0]), extent["y"][1]] | ||
| previous_xlim = ax.get_xlim() | ||
| previous_ylim = ax.get_ylim() | ||
| # increase range if sth larger was rendered before | ||
| if _mpl_ax_contains_elements(ax): | ||
| x_ext = [min(x_ext[0], previous_xlim[0]), max(x_ext[1], previous_xlim[1])] | ||
| if ax.yaxis_inverted(): # case for e.g. images | ||
| y_ext = [min(y_ext[0], previous_ylim[1]), max(y_ext[1], previous_ylim[0])] | ||
| else: # case for e.g. labels | ||
| y_ext = [min(y_ext[0], previous_ylim[0]), max(y_ext[1], previous_ylim[1])] | ||
|
|
||
| # compute canvas size in pixels close to the actual image size to speed up computation | ||
| plot_width = x_ext[1] - x_ext[0] | ||
| plot_height = y_ext[1] - y_ext[0] | ||
| plot_width_px = int(round(fig_params.fig.get_size_inches()[0] * fig_params.fig.dpi)) | ||
| plot_height_px = int(round(fig_params.fig.get_size_inches()[1] * fig_params.fig.dpi)) | ||
| factor = np.min([plot_width / plot_width_px, plot_height / plot_height_px]) | ||
| plot_width = int(np.round(plot_width / factor)) | ||
| plot_height = int(np.round(plot_height / factor)) | ||
|
|
||
| cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_range, y_range=y_range) | ||
| cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext) | ||
|
|
||
| _geometry = shapes["geometry"] | ||
| is_point = _geometry.type == "Point" | ||
|
|
@@ -223,16 +238,44 @@ def _render_shapes( | |
| cmap=render_params.cmap_params.cmap, | ||
| ) | ||
| ) | ||
| rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2)) | ||
| _cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder) | ||
| _cax.set_transform(trans) | ||
| cax = ax.add_image(_cax) | ||
|
|
||
| # create SpatialImage to get it back to original size | ||
| rgba_image = np.transpose(ds_result.to_numpy().base, (2, 0, 1)) | ||
| rgba_image = Image2DModel.parse( | ||
| rgba_image, | ||
| dims=("c", "y", "x"), | ||
| transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))}, | ||
| ) | ||
|
|
||
| # prepare transformation | ||
| trans = get_transformation(rgba_image, get_all=True)["global"] | ||
| affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) | ||
| trans = mtransforms.Affine2D(matrix=affine_trans) | ||
| trans_data = trans + ax.transData | ||
|
|
||
| rgba_image = np.transpose(rgba_image.data.compute(), (1, 2, 0)) # type: ignore[attr-defined] | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here (and in render_shapes), we access the image as numpy array from the SpatialImage. mypy doesn't believe that |
||
| rgba_image = ma.masked_array(rgba_image) # type conversion for mypy | ||
| _cax = _ax_show_and_transform( | ||
| rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.fill_alpha | ||
| ) | ||
|
|
||
| cax = None | ||
| if aggregate_with_sum is not None: | ||
| cax = ScalarMappable( | ||
| norm=matplotlib.colors.Normalize(vmin=aggregate_with_sum[0], vmax=aggregate_with_sum[1]), | ||
| cmap=render_params.cmap_params.cmap, | ||
| ) | ||
|
|
||
| # rgba_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2)) | ||
| # _cax = ax.imshow(rgba_image, cmap=palette, zorder=render_params.zorder) | ||
| # _cax.set_transform(trans) | ||
| # cax = ax.add_image(_cax) | ||
| # if aggregate_with_sum is not None: | ||
| # cax = ScalarMappable( | ||
| # norm=matplotlib.colors.Normalize(vmin=aggregate_with_sum[0], vmax=aggregate_with_sum[1]), | ||
| # cmap=render_params.cmap_params.cmap, | ||
| # ) | ||
|
|
||
| elif method == "matplotlib": | ||
| _cax = _get_collection_shape( | ||
| shapes=shapes, | ||
|
|
@@ -416,9 +459,15 @@ def _render_points( | |
| y_ext = [min(y_ext[0], previous_ylim[1]), max(y_ext[1], previous_ylim[0])] | ||
| else: # case for e.g. labels | ||
| y_ext = [min(y_ext[0], previous_ylim[0]), max(y_ext[1], previous_ylim[1])] | ||
| # round because we need integers | ||
| plot_width = int(np.round(x_ext[1] - x_ext[0])) | ||
| plot_height = int(np.round(y_ext[1] - y_ext[0])) | ||
|
|
||
| # compute canvas size in pixels close to the actual image size to speed up computation | ||
| plot_width = x_ext[1] - x_ext[0] | ||
| plot_height = y_ext[1] - y_ext[0] | ||
| plot_width_px = int(round(fig_params.fig.get_size_inches()[0] * fig_params.fig.dpi)) | ||
| plot_height_px = int(round(fig_params.fig.get_size_inches()[1] * fig_params.fig.dpi)) | ||
| factor = np.min([plot_width / plot_width_px, plot_height / plot_height_px]) | ||
| plot_width = int(np.round(plot_width / factor)) | ||
| plot_height = int(np.round(plot_height / factor)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd consider bundling this code in a private function since it's duplicate from above. |
||
|
|
||
| # use datashader for the visualization of points | ||
| cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext) | ||
|
|
@@ -456,8 +505,25 @@ def _render_points( | |
| cmap=render_params.cmap_params.cmap, | ||
| ) | ||
|
|
||
| rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2)) | ||
| cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha) | ||
| # create SpatialImage to get it back to original size | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also this code is a duplicate, I'd consider it refactoring it into a private function. |
||
| rgba_image = np.transpose(ds_result.to_numpy().base, (2, 0, 1)) | ||
| rgba_image = Image2DModel.parse( | ||
| rgba_image, | ||
| dims=("c", "y", "x"), | ||
| transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))}, | ||
| ) | ||
|
|
||
| # prepare transformation | ||
| trans = get_transformation(rgba_image, get_all=True)["global"] | ||
| affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) | ||
| trans = mtransforms.Affine2D(matrix=affine_trans) | ||
| trans_data = trans + ax.transData | ||
|
|
||
| rgba_image = np.transpose(rgba_image.data.compute(), (1, 2, 0)) # type: ignore[attr-defined] | ||
| rgba_image = ma.masked_array(rgba_image) # type conversion for mypy | ||
| _ax_show_and_transform(rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.alpha) | ||
|
|
||
| cax = None | ||
| if aggregate_with_sum is not None: | ||
| cax = ScalarMappable( | ||
| norm=matplotlib.colors.Normalize(vmin=aggregate_with_sum[0], vmax=aggregate_with_sum[1]), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.