
Ability to use discrete colormaps in `mne.viz.plot_source_estimates` without color interpolation

Open · caiw opened this issue 10 months ago · 3 comments

Describe the new feature or enhancement

We are using mne to plot single-source points with a discrete colormap to show "which model fits best here" results. We have found that when we supply discrete colormaps, the plotted results end up with a little "halo" of false colour around the edges of the vertices that should be coloured:

[Image: plotted source estimate showing the halo of false colour around the coloured vertices]

I believe the reason for this is that instead of referencing into the colormap directly to get the colors to plot, mne creates a LUT from 256 sample colours and then the renderer interpolates between them.
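To make the suspected mechanism concrete, here is a numpy-only sketch of two ways a renderer could fill the edge between two neighbouring vertices that carry different classes. The 3-colour `PALETTE`, `discrete_cmap`, and `lut_lookup` are hypothetical stand-ins; only the 256-entry LUT mirrors the description above, and nothing here is mne's or VTK's actual code.

```python
import numpy as np

# Hypothetical 3-class palette (RGB), standing in for a discrete colormap
PALETTE = np.array([[1.0, 0.0, 0.0],   # class 0: red
                    [0.0, 0.5, 0.0],   # class 1: green
                    [0.0, 0.0, 1.0]])  # class 2: blue


def discrete_cmap(x):
    """Nearest-class lookup for x in [0, 1] (like ListedColormap with float input)."""
    idx = np.minimum((np.asarray(x) * len(PALETTE)).astype(int), len(PALETTE) - 1)
    return PALETTE[idx]


# Sample the colormap into a 256-entry LUT
lut = discrete_cmap(np.linspace(0.0, 1.0, 256))


def lut_lookup(x):
    """Map scalars in [0, 1] to the nearest LUT entry."""
    i = np.clip(np.rint(np.asarray(x) * 255).astype(int), 0, 255)
    return lut[i]


# Two neighbouring vertices: class 0 (value 0.0) and class 2 (value 1.0)
v0, v1 = 0.0, 1.0
t = np.linspace(0.0, 1.0, 5)  # positions along the edge between them

# (a) Interpolate the scalar, then look it up: mid-edge shows class 1 (green)
scalar_interp = lut_lookup((1 - t) * v0 + t * v1)

# (b) Look up at the vertices, then blend RGB: mid-edge is a purple not in PALETTE
c0, c1 = lut_lookup(v0), lut_lookup(v1)
rgb_interp = (1 - t)[:, None] * c0 + t[:, None] * c1
print(scalar_interp[2], rgb_interp[2])
```

Either way, a colour belonging to neither vertex appears along the boundary. Which of the two paths (if either) mne actually takes is not established by this sketch.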

I acknowledge that in most cases, where continuous values are plotted using a continuous colormap, interpolated LUTs are a great solution, but I believe they are preventing us from controlling our colors precisely.

I had two questions:

  1. Is there a specific reason, other than performance, why this approach is required?
  2. If we worked on a PR which allowed a specific Colormap object to be queried to plot all colours (thereby allowing us to precisely control color output), would you consider it?

Describe your proposed implementation

My initial approach for the PR would be to create an object which behaves like the LUT externally, but which internally queries the supplied Colormap directly.
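A minimal sketch of what such an object could look like, assuming the renderer could accept per-vertex RGBA colours computed up front. `DirectColormapLUT` and `map_scalars` are hypothetical names, not part of mne or VTK, and the toy `two_color` function stands in for a `Colormap` object (any callable taking values in [0, 1]).

```python
import numpy as np


class DirectColormapLUT:
    """Hypothetical LUT-like wrapper that queries a colormap callable directly.

    Not an mne or VTK class: it only illustrates mapping each scalar through
    the supplied colormap instead of through a fixed table of 256 samples.
    """

    def __init__(self, cmap, vmin, vmax):
        self.cmap = cmap  # any callable mapping values in [0, 1] to RGBA rows
        self.vmin, self.vmax = float(vmin), float(vmax)

    def map_scalars(self, values):
        """Normalize values to [0, 1], then ask the colormap for every value."""
        x = (np.asarray(values, dtype=float) - self.vmin) / (self.vmax - self.vmin)
        return np.asarray(self.cmap(np.clip(x, 0.0, 1.0)))


def two_color(x):
    """Toy discrete colormap: red below 0.5, blue from 0.5 upwards."""
    x = np.asarray(x, dtype=float)
    return np.where(x[..., None] < 0.5, [1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0])


lut = DirectColormapLUT(two_color, vmin=0, vmax=10)
rgba = lut.map_scalars([0.0, 4.9, 5.0, 10.0])
print(rgba)
```

Because every value goes through the colormap itself, there is no 256-sample resolution limit; whether colours would still be blended across triangle faces would depend on how the renderer shades them.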

Describe possible alternatives

The only other option I can think of would be to increase the number of samples in the LUT beyond 256. However, this doesn't seem to solve the issue "once and for all", so it would only be preferable if the proposed solution above turned out to be insufficiently performant.

Additional context

More concretely, in our specific example we have created a forced-discrete colormap like this:

import numpy as np
from matplotlib.colors import ListedColormap


class DiscreteListedColormap(ListedColormap):
    """Like ListedColormap, but without interpolation between values."""

    def __init__(self, colors: list, name="from_list", N=None, scale01: bool = False):
        """
        Args:
            scale01 (bool): True if the values will be supplied to the colormap in the range [0, 1] instead of the range
                [0, N-1].
        """
        self.scale01: bool = scale01
        super().__init__(colors=colors, name=name, N=N)

    def __call__(self, X, *args, **kwargs):
        # Convert to a float array (never modified in place below)
        X = np.asarray(X, dtype=float)
        if self.scale01:
            # Values are supplied between 0 and 1, so map them onto the index range [0, N-1]
            X = X * (self.N - 1)
        rounded = np.round(X).astype(int)
        # Integer input makes ListedColormap index the color list directly
        return super().__call__(rounded, *args, **kwargs)

But then by sampling this colormap 256 times and interpolating the results when plotting, we lose our ability to control which precise values correspond to which precise colours.

caiw · Apr 04 '25 14:04

> I believe the reason for this is that instead of referencing into the colormap directly to get the colors to plot, mne creates a LUT from 256 sample colours and then the renderer interpolates between them.

To me it looks like it might actually be a problem with `smoothing_steps`: if your STC comes from a source space that has been decimated (e.g., fewer than 10,000 points rather than more than 100,000), then `smoothing_steps` will smear the data onto neighboring vertices. If you originally had only integers, for example, your smoothed data will no longer be integers. Usually you can work around this sort of thing.
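To illustrate why integers do not survive smoothing, here is a simplified neighbour-averaging stand-in on a hypothetical 5-vertex chain where only the two end vertices carry data, as in a decimated source space. It is not mne's actual smoothing implementation; `neighbors`, `smooth_step`, and the three iterations (loosely analogous to `smoothing_steps=3`) are assumptions for illustration.

```python
import numpy as np

# Hypothetical tiny mesh: 5 vertices in a chain, edges between neighbours
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

# Integer class labels; only vertices 0 and 4 carry data (NaN = no data yet)
values = np.array([0.0, np.nan, np.nan, np.nan, 2.0])


def smooth_step(vals):
    """One step: each vertex becomes the mean of itself and its data-carrying neighbours."""
    out = vals.copy()
    for v, nbrs in neighbors.items():
        pool = [vals[u] for u in [v, *nbrs] if not np.isnan(vals[u])]
        if pool:
            out[v] = np.mean(pool)
    return out


smoothed = values
for _ in range(3):  # simplified stand-in for three smoothing steps
    smoothed = smooth_step(smoothed)
print(smoothed)
```

The filled-in vertices end up with fractional values such as 1/3 and 5/3, which a discrete colormap then assigns to whichever colour band they happen to fall into.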

For example, I once wanted to plot on fsaverage the number of subjects (between, say, 0 and 16) for which each vertex showed some effect. So I wrote a little helper function, `cmap, clim = discretize_cmap("viridis", [0, 16])`, to upsample the colormap to 256 values while preserving the discretization:

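import numpy as np
from matplotlib import colors, pyplot as plt


def discretize_cmap(colormap, lims, transparent=True):
    """Discretize a colormap so each integer in [lims[0], lims[1]] gets one color."""
    lims = np.array(lims, int)
    assert lims.shape == (2,)
    n_pts = lims[1] - lims[0] + 1  # number of distinct integer values
    assert n_pts > 0
    if n_pts == 1:
        vals = np.ones(256)
    else:
        # 256 LUT entries spanning [-0.5, n_pts - 0.5]; rounding snaps each entry
        # to the nearest integer, then dividing rescales to [0, 1] for the cmap
        vals = np.round(np.linspace(-0.5, n_pts - 0.5, 256)) / (n_pts - 1)
    colormap = plt.get_cmap(colormap)(vals)  # (256, 4) RGBA array
    if transparent:
        # Ramp alpha up over the lower part of the range
        colormap[:, 3] = np.clip((vals + 0.5 / n_pts) * 2, 0, 1)
    colormap[0, 3] = 0.  # lowest LUT entry is fully transparent
    colormap = colors.ListedColormap(colormap)
    # Pad by 0.5 so each integer sits at the center of its color band
    use_lims = [lims[0] - 0.5, (lims[0] + lims[1]) / 2., lims[1] + 0.5]
    return colormap, use_lims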

It returns a colormap plus the limits (`use_lims`) that I could pass to `mne.viz.Brain` to get it to map the values appropriately. The short version of the trick was to use `clim=dict(kind="values", lims=[-0.5, 8, 16.5])` and to construct the LUT correspondingly. Then, when `smoothing_steps` smeared my values continuously between 0 and 16, anything between 0 and 0.5 would get the first color, anything between 0.5 and 1.5 the second color, and so forth. Maybe it would work for your use case? Maybe we can/should add this to an example so people can see how LUT tricks like this can be used to solve problems?
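The effect of the rounding step in `discretize_cmap` can be checked without mne. This numpy-only sketch reproduces the integer label that the rounded `linspace` assigns to each of the 256 LUT entries, then looks up a few smoothed values, assuming a simple linear, nearest-entry lookup over `[-0.5, 16.5]` (the renderer's exact indexing may differ slightly). `lut_labels` and `lookup` are hypothetical helpers.

```python
import numpy as np


def lut_labels(n_pts, n_lut=256):
    """Integer label held by each LUT entry (the rounding step of discretize_cmap)."""
    return np.round(np.linspace(-0.5, n_pts - 0.5, n_lut)).astype(int)


def lookup(values, lo, hi, labels):
    """Assumed linear lookup: map values in [lo, hi] to the nearest LUT entry's label."""
    n_lut = len(labels)
    idx = np.rint((np.asarray(values, float) - lo) / (hi - lo) * (n_lut - 1)).astype(int)
    return labels[np.clip(idx, 0, n_lut - 1)]


n_pts = 17  # integers 0..16
labels = lut_labels(n_pts)
smoothed = np.array([0.1, 0.9, 3.3, 7.7, 16.0])
print(lookup(smoothed, -0.5, n_pts - 0.5, labels))
```

Away from the half-integer boundaries, each smoothed value lands on the label of its nearest integer; values within roughly one LUT step (17/255 ≈ 0.067) of a boundary can fall on either side, because the boundaries are quantized to the 256 entries.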

One minor thing: I'm not sure why you need `DiscreteListedColormap`, as I think `ListedColormap` already has "nearest" behavior:

>>> ListedColormap(["r", "g", "b"])(np.linspace(0, 1, 10))
array([[1. , 0. , 0. , 1. ],
       [1. , 0. , 0. , 1. ],
       [1. , 0. , 0. , 1. ],
       [0. , 0.5, 0. , 1. ],
       [0. , 0.5, 0. , 1. ],
       [0. , 0.5, 0. , 1. ],
       [0. , 0. , 1. , 1. ],
       [0. , 0. , 1. , 1. ],
       [0. , 0. , 1. , 1. ],
       [0. , 0. , 1. , 1. ]])

larsoner · Apr 04 '25 17:04

Many thanks for the reply @larsoner!

Your suggested workaround was useful in getting us to a solution.

> I'm not sure why you need DiscreteListedColormap as I think ListedColormap already has a "nearest" behavior

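It was actually meant to get around `ListedColormap`'s odd dual float/int call interface: floats are treated as fractions along the colormap in [0, 1], while integers are treated as direct indices into the color list: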

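from matplotlib.colors import ListedColormap

cmap = ListedColormap(list("rgb"))

# Float 0.5 is a fraction: 0.5 * 3 colors -> index 1 (green); int 1 indexes green directly
print(cmap(0.5) == cmap(1))  # True
# Float 1.0 is the top of the range -> last color (blue), not index 1
print(cmap(1) == cmap(1.0))  # False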
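One way to sidestep the ambiguity entirely is to convert values to integer labels yourself and index a color array, so `1` and `1.0` can never diverge. A numpy-only sketch: `palette` mirrors `ListedColormap(list("rgb"))` (where "g" is `(0, 0.5, 0)`), and `color_for_label` is a hypothetical helper, not a matplotlib API.

```python
import numpy as np

# Hypothetical RGBA palette mirroring ListedColormap(list("rgb"))
palette = np.array([[1.0, 0.0, 0.0, 1.0],   # red
                    [0.0, 0.5, 0.0, 1.0],   # green
                    [0.0, 0.0, 1.0, 1.0]])  # blue


def color_for_label(labels):
    """Treat every input, int or float, as an integer class label and index the palette."""
    idx = np.rint(np.asarray(labels, dtype=float)).astype(int)
    return palette[np.clip(idx, 0, len(palette) - 1)]


print(np.array_equal(color_for_label(1), color_for_label(1.0)))
```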

But it's no longer needed!

caiw · Apr 11 '25 16:04

FWIW I would have found an example super useful, and I imagine that plotting discrete values like subject counts or best-model indices is not uncommon!

caiw · Apr 11 '25 16:04