spatialdata-plot

Points not transformed when `method="datashader"`

Open clwgg opened this issue 1 year ago • 3 comments

As per the title, I just ran into a case where datashader was chosen as the `method` for `render_points`, and as a result my points were plotted without the relevant transformation applied. I borrowed the example from https://github.com/scverse/spatialdata-plot/issues/182 for the test below.

```python
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel
from spatialdata.transformations import Scale
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import spatialdata_plot

sdata = SpatialData(
    images={
        "image1": Image2DModel.parse(
            np.full((10, 10, 3), fill_value=128), dims=("y", "x", "c")
        )
    },
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0.1, 0.1, 0.9, 0.9], "x": [0.1, 0.9, 0.9, 0.1]}),
            transformations={"global": Scale([10, 10], ("y", "x"))},
        )
    },
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
sdata.pl.render_images("image1").pl.render_points("points1", method="datashader").pl.show(ax=ax1, title="datashader")
sdata.pl.render_images("image1").pl.render_points("points1", method="matplotlib").pl.show(ax=ax2, title="matplotlib")
```

With current `main`: [screenshot, 2024-08-28 10:01:40 PM]

With https://github.com/scverse/spatialdata-plot/pull/309: [screenshot, 2024-08-28 10:02:09 PM]
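For reference, the expected positions can be worked out by hand. Below is a minimal NumPy sketch (plain arithmetic, not spatialdata-plot's own code path) that applies the `global` `Scale([10, 10], ("y", "x"))` transformation as a homogeneous affine matrix. The points, defined between 0.1 and 0.9, should end up at 1 and 9 in the coordinates of the 10 x 10 image; untransformed, they would all stay inside the unit square at the image's top-left corner.

```python
import numpy as np

# Point coordinates from the example above, as (y, x) pairs in the
# points' intrinsic coordinate system.
yx = np.array([[0.1, 0.1], [0.1, 0.9], [0.9, 0.9], [0.9, 0.1]])

# Scale([10, 10], ("y", "x")) multiplies y and x by 10. Written as a 3x3
# affine matrix acting on homogeneous (y, x, 1) coordinates:
affine = np.array([
    [10.0, 0.0, 0.0],
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 1.0],
])

# Append a column of ones, apply the matrix, and drop the homogeneous column.
homogeneous = np.hstack([yx, np.ones((len(yx), 1))])
yx_global = (homogeneous @ affine.T)[:, :2]
print(yx_global)
```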

clwgg, Aug 29 '24 05:08

Thanks for reporting. Please also see the related discussion in https://github.com/scverse/spatialdata-plot/issues/291.

LucaMarconato, Sep 08 '24 10:09

I think I'm seeing a consequence of that in my own data. Calling

```python
(
    sdata_cropped
    .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="matplotlib")
    .pl.show()
)
```

works fine, but with `method="datashader"` I get

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[45], line 4
      1 (
      2     sdata_cropped
      3     .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="datashader")
----> 4     .pl.show()
      5 )

File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/basic.py:895, in PlotAccessor.show(self, coordinate_systems, legend_fontsize, legend_fontweight, legend_loc, legend_fontoutline, na_in_legend, colorbar, wspace, hspace, ncols, frameon, figsize, dpi, fig, title, share_extent, pad_extent, ax, return_ax, save)
    890     wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements(
    891         sdata, wanted_elements, params_copy, cs, "points"
    892     )
    894     if wanted_points_on_this_cs:
--> 895         _render_points(
    896             sdata=sdata,
    897             render_params=params_copy,
    898             coordinate_system=cs,
    899             ax=ax,
    900             fig_params=fig_params,
    901             scalebar_params=scalebar_params,
    902             legend_params=legend_params,
    903         )
    905 elif cmd == "render_labels" and has_labels:
    906     wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements(
    907         sdata, wanted_elements, params_copy, cs, "labels"
    908     )

File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/render.py:483, in _render_points(sdata, render_params, coordinate_system, ax, fig_params, scalebar_params, legend_params)
    466     color_vector = np.asarray([x[:-2] for x in color_vector])
    468 ds_result = (
    469     ds.tf.shade(
    470         ds.tf.spread(agg, px=px),
   (...)
    481     )
    482 )
--> 483 rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
    484 cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha)
    485 if aggregate_with_sum is not None:

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:655, in transpose(a, axes)
    588 @array_function_dispatch(_transpose_dispatcher)
    589 def transpose(a, axes=None):
    590     """
    591     Returns an array with axes transposed.
    592 
   (...)
    653 
    654     """
--> 655     return _wrapfunc(a, 'transpose', axes)

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:56, in _wrapfunc(obj, method, *args, **kwds)
     54 bound = getattr(obj, method, None)
     55 if bound is None:
---> 56     return _wrapit(obj, method, *args, **kwds)
     58 try:
     59     return bound(*args, **kwds)

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:45, in _wrapit(obj, method, *args, **kwds)
     43 except AttributeError:
     44     wrap = None
---> 45 result = getattr(asarray(obj), method)(*args, **kwds)
     46 if wrap:
     47     if not isinstance(result, mu.ndarray):

ValueError: axes don't match array
```
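A note on reading this traceback: `_wrapfunc` only falls through to `_wrapit` when the object has no `transpose` method, so whatever `ds_result.to_numpy().base` returned here was not an ndarray. One plausible explanation (an assumption, not confirmed in this thread) is that `.base` is `None` whenever an array owns its memory instead of being a view of another array, and `np.transpose(None, (0, 1, 2))` fails with this same `ValueError`. The sketch reproduces that with plain NumPy and adds a hypothetical helper, `rgba_uint32_to_uint8`, that unpacks an `(h, w)` uint32 RGBA image into `(h, w, 4)` uint8 without relying on `.base`. It assumes the channels are packed little-endian with R in the lowest byte.

```python
import numpy as np

# An array that owns its data (not a view of another array) has .base == None.
owned = np.zeros((4, 5), dtype=np.uint32)
print(owned.base is None)  # True

# np.transpose(None, ...) wraps None in a 0-d array, which cannot be
# transposed with three axes, so NumPy raises ValueError.
try:
    np.transpose(owned.base, (0, 1, 2))
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)


def rgba_uint32_to_uint8(img: np.ndarray) -> np.ndarray:
    """Hypothetical helper: unpack an (h, w) uint32 RGBA image to (h, w, 4) uint8.

    Assumes the channels are packed little-endian (R in the lowest byte).
    It never touches `img.base`, so it works for views and copies alike.
    """
    # Force an explicit little-endian, C-contiguous layout so the byte view
    # below yields channels in R, G, B, A order on any machine.
    packed = np.ascontiguousarray(img, dtype="<u4")
    return packed.view(np.uint8).reshape(packed.shape + (4,))


# One pixel with R=1, G=2, B=3, A=255 packed into a single uint32.
pixel = np.array([[1 | 2 << 8 | 3 << 16 | 255 << 24]], dtype=np.uint32)
print(rgba_uint32_to_uint8(pixel)[0, 0])  # [  1   2   3 255]
```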

Marius1311, Sep 13 '24 10:09

@Marius1311 thanks for reporting. How did you construct `sdata_cropped`? It would help us if you could reproduce the bug using the blobs dataset.

You can access it via one of these two functions:

  • https://spatialdata.scverse.org/en/latest/generated/spatialdata.datasets.blobs.html
  • https://spatialdata.scverse.org/en/latest/generated/spatialdata.datasets.blobs_annotating_element.html

CC @melonora

LucaMarconato, Sep 30 '24 11:11

@clwgg Thanks for reporting! I reproduced the problem without the background image, in which case the points are shifted by 0.5 when using datashader (because of #216).

```python
import pandas as pd
import spatialdata_plot  # noqa: F401  (registers the .pl accessor)
from spatialdata import SpatialData
from spatialdata.models import PointsModel
from spatialdata.transformations import Scale

sdata = SpatialData(
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0, 0, 10, 10, 4, 6, 4, 6], "x": [0, 10, 10, 0, 4, 6, 6, 4]}),
            transformations={"global": Scale([2, 2], ("y", "x"))},
        )
    },
)
# Large grey matplotlib points as a reference, small red datashader points overlaid.
(
    sdata.pl.render_points("points1", method="matplotlib", size=50, color="lightgrey")
    .pl.render_points("points1", method="datashader", size=10, color="red")
    .pl.show()
)
```

With this, I get:

a) before: [image]

b) after my fix (#378): [image]
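The thread does not spell out what #216 is, so the following is only an assumption about one common source of a 0.5 shift. A rasteriser that bins a canvas spanning [0, W) at one unit per pixel roughly assigns a point at coordinate x to pixel floor(x), so pixel i covers [i, i + 1). Matplotlib's `imshow` without an explicit `extent` instead centres pixel i on the integer coordinate i; its default horizontal extent is (-0.5, W - 0.5). Drawing such a raster with the default extent therefore moves every point by half a pixel. The sketch computes pixel centres under both extents; `pixel_centres` is a hypothetical helper, and the canvas size and point coordinates are made up for illustration.

```python
import numpy as np


def pixel_centres(n_pixels: int, lo: float, hi: float) -> np.ndarray:
    """Hypothetical helper: data coordinates of the pixel centres when
    n_pixels pixels are stretched evenly over [lo, hi], as imshow's extent does."""
    width = (hi - lo) / n_pixels
    return lo + (np.arange(n_pixels) + 0.5) * width


n = 10  # assumed canvas: 10 pixels covering [0, 10), one unit per pixel

# Assumed binning convention: pixel i covers [i, i + 1).
x = np.array([0.0, 4.0, 6.0, 9.5])
pixel_of_x = np.floor(x).astype(int)  # pixel index of each point

# imshow's default extent for an n-pixel-wide image is (-0.5, n - 0.5),
# which puts pixel centres at 0, 1, ..., n - 1 ...
default_centres = pixel_centres(n, -0.5, n - 0.5)
# ... whereas an extent matching the binning, (0, n), puts them at 0.5, ..., n - 0.5.
matched_centres = pixel_centres(n, 0.0, float(n))

# Where each point's pixel is drawn, default extent vs. matched extent.
offset = matched_centres[pixel_of_x] - default_centres[pixel_of_x]
print(offset)  # [0.5 0.5 0.5 0.5]
```

Passing an explicit `extent` to `ax.imshow` that covers the range the raster was binned over is one way to make the two conventions agree.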

Sonja-Stockhaus, Oct 25 '24 12:10

@clwgg could you verify that Sonja's branch fixes the issue for you as well? :) Thanks!

timtreis, Oct 30 '24 14:10