Points not transformed when `method="datashader"`
As per the title, I just ran into a case where datashader was chosen as the method for `render_points`, and my points were plotted without the relevant transformation being applied. I adapted the example from https://github.com/scverse/spatialdata-plot/issues/182 for testing below.
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel
from spatialdata.transformations import Scale
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import spatialdata_plot
sdata = SpatialData(
    images={
        "image1": Image2DModel.parse(
            np.full((10, 10, 3), fill_value=128), dims=("y", "x", "c")
        )
    },
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0.1, 0.1, 0.9, 0.9], "x": [0.1, 0.9, 0.9, 0.1]}),
            transformations={"global": Scale([10, 10], ("y", "x"))},
        )
    },
)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
sdata.pl.render_images("image1").pl.render_points("points1", method="datashader").pl.show(ax=ax1, title="datashader")
sdata.pl.render_images("image1").pl.render_points("points1", method="matplotlib").pl.show(ax=ax2, title="matplotlib")
With current main:
With https://github.com/scverse/spatialdata-plot/pull/309:
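For reference, here is a minimal sketch of where the four points should land once the transformation is applied. It assumes that `Scale([10, 10], ("y", "x"))` simply multiplies the `y` and `x` coordinates by 10; the names `scale` and `expected` are illustrative and not part of spatialdata. Both rendering methods should place the points at these coordinates, near the corners of the 10x10 image.

```python
import numpy as np
import pandas as pd

# Same point coordinates as in the reproducer above.
points = pd.DataFrame({"y": [0.1, 0.1, 0.9, 0.9], "x": [0.1, 0.9, 0.9, 0.1]})

# Assumption: the Scale transformation multiplies each axis by its factor.
# `scale` is a hypothetical stand-in for Scale([10, 10], ("y", "x")).
scale = pd.Series({"y": 10.0, "x": 10.0})

# DataFrame * Series aligns the Series index with the DataFrame columns,
# so each column is multiplied by its own factor.
expected = points * scale
print(expected)
```

If the points show up in the range 0 to 1 instead (clustered in one corner of the image), the transformation was not applied.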
Thanks for reporting. Please see the discussion on this issue also here: https://github.com/scverse/spatialdata-plot/issues/291.
I think I'm seeing a consequence of that in my own data. Calling
(
    sdata_cropped
    .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="matplotlib")
    .pl.show()
)
works just fine, but when using `method="datashader"`, I get
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[45], line 4
1 (
2 sdata_cropped
3 .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="datashader")
----> 4 .pl.show()
5 )
File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/basic.py:895, in PlotAccessor.show(self, coordinate_systems, legend_fontsize, legend_fontweight, legend_loc, legend_fontoutline, na_in_legend, colorbar, wspace, hspace, ncols, frameon, figsize, dpi, fig, title, share_extent, pad_extent, ax, return_ax, save)
890 wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements(
891 sdata, wanted_elements, params_copy, cs, "points"
892 )
894 if wanted_points_on_this_cs:
--> 895 _render_points(
896 sdata=sdata,
897 render_params=params_copy,
898 coordinate_system=cs,
899 ax=ax,
900 fig_params=fig_params,
901 scalebar_params=scalebar_params,
902 legend_params=legend_params,
903 )
905 elif cmd == "render_labels" and has_labels:
906 wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements(
907 sdata, wanted_elements, params_copy, cs, "labels"
908 )
File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/render.py:483, in _render_points(sdata, render_params, coordinate_system, ax, fig_params, scalebar_params, legend_params)
466 color_vector = np.asarray([x[:-2] for x in color_vector])
468 ds_result = (
469 ds.tf.shade(
470 ds.tf.spread(agg, px=px),
(...)
481 )
482 )
--> 483 rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
484 cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha)
485 if aggregate_with_sum is not None:
File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:655, in transpose(a, axes)
588 @array_function_dispatch(_transpose_dispatcher)
589 def transpose(a, axes=None):
590 """
591 Returns an array with axes transposed.
592
(...)
653
654 """
--> 655 return _wrapfunc(a, 'transpose', axes)
File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:56, in _wrapfunc(obj, method, *args, **kwds)
54 bound = getattr(obj, method, None)
55 if bound is None:
---> 56 return _wrapit(obj, method, *args, **kwds)
58 try:
59 return bound(*args, **kwds)
File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:45, in _wrapit(obj, method, *args, **kwds)
43 except AttributeError:
44 wrap = None
---> 45 result = getattr(asarray(obj), method)(*args, **kwds)
46 if wrap:
47 if not isinstance(result, mu.ndarray):
ValueError: axes don't match array
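One possible reading of this traceback, offered as an assumption rather than a confirmed diagnosis: the call goes through `_wrapit`, which NumPy only reaches when the object passed to `np.transpose` has no `transpose` method of its own. That would be consistent with `ds_result.to_numpy().base` being `None`, which is what `.base` returns for any array that owns its memory. The sketch below uses plain NumPy (no datashader) to show that this situation produces the same error message; the variable names are illustrative.

```python
import numpy as np

# An array that owns its memory has no base object.
owned = np.zeros((4, 4, 4), dtype=np.uint8)
print(owned.base is None)

# A view onto another array keeps a reference to that array in .base.
view = owned.view(np.uint32)
print(view.base is owned)

# Transposing `None` with three axes: NumPy wraps None in a 0-d object
# array, and the three requested axes cannot match zero dimensions.
message = None
try:
    np.transpose(owned.base, (0, 1, 2))
except ValueError as err:
    message = str(err)
print(message)
```

Whether `.base` is actually `None` in the failing call depends on how the datashader result is backed in memory, which the thread does not show.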
@Marius1311 thanks for reporting. How did you construct sdata_cropped? It would be helpful for us if you could please reproduce your bug using the blobs dataset.
You can access it via one of these two functions:
- https://spatialdata.scverse.org/en/latest/generated/spatialdata.datasets.blobs.html
- https://spatialdata.scverse.org/en/latest/generated/spatialdata.datasets.blobs_annotating_element.html
CC @melonora
@clwgg Thanks for reporting! I reproduced the problem without the image in the background, which led to the points being shifted by 0.5 when using datashader (because of #216).
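import pandas as pd
import spatialdata_plot  # registers the .pl accessor on SpatialData
from spatialdata import SpatialData
from spatialdata.models import PointsModel
from spatialdata.transformations import Scale

# Points only (no background image), scaled by a factor of 2 on both axes.
sdata = SpatialData(
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0, 0, 10, 10, 4, 6, 4, 6], "x": [0, 10, 10, 0, 4, 6, 6, 4]}),
            transformations={"global": Scale([2, 2], ("y", "x"))},
        )
    },
)
# Large grey matplotlib markers underneath, small red datashader markers on top:
# if both methods apply the transformation, the markers coincide.
(
    sdata
    .pl.render_points("points1", method="matplotlib", size=50, color="lightgrey")
    .pl.render_points("points1", method="datashader", size=10, color="red")
    .pl.show()
)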
With this, I get a) before:
b) after my fix (#378):
@clwgg could you verify that Sonja's branch fixes the issue for you as well? :) Thanks!