Ability to use discrete colormaps in `mne.viz.plot_source_estimates` without color interpolation
Describe the new feature or enhancement
We are using mne to plot single-source points with a discrete colormap to show "which model fits best here" results. We have found that when we supply discrete colormaps, the plotted results end up with a little "halo" of false colour around the edges of the vertices that should be coloured.
I believe the reason for this is that instead of referencing into the colormap directly to get the colors to plot, mne creates a LUT from 256 sample colours and then the renderer interpolates between them.
I acknowledge that in most cases, where continuous values are plotted using a continuous colormap, interpolated LUTs are a great solution, but I believe they are preventing us from controlling our colors precisely.
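To make the suspected mechanism concrete, here is a minimal sketch in plain NumPy. It is not MNE's or the renderer's actual code. It samples a two-color discrete palette into a 256-entry LUT and then linearly interpolates between neighboring entries. The helpers `lut_from_palette` and `interp_lut` are hypothetical. At the boundary, the interpolated color is a blend that matches neither palette entry, which is the kind of "halo" described above.

```python
import numpy as np

# Discrete two-color palette: red and blue (RGB in [0, 1])
palette = np.array([[1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0]])


def lut_from_palette(palette, n_lut=256):
    """Hypothetical helper: sample a discrete palette into an n_lut-entry LUT (no blending)."""
    frac = np.linspace(0.0, 1.0, n_lut)
    # Pick the palette entry whose bin each sample falls into
    idx = np.minimum((frac * len(palette)).astype(int), len(palette) - 1)
    return palette[idx]


def interp_lut(lut, frac):
    """Hypothetical helper: linearly interpolate between adjacent LUT entries at frac in [0, 1]."""
    pos = frac * (len(lut) - 1)
    lo = int(np.floor(pos))
    hi = min(lo + 1, len(lut) - 1)
    w = pos - lo
    return (1.0 - w) * lut[lo] + w * lut[hi]


lut = lut_from_palette(palette)
print(interp_lut(lut, 0.0))  # pure red, a palette color
print(interp_lut(lut, 0.5))  # halfway between the last red and first blue entry: a purple blend
```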
I had two questions:
- Is there a specific reason, other than performance, why this approach is required?
- If we worked on a PR which allowed a specific Colormap object to be queried to plot all colours (thereby allowing us to precisely control color output), would you consider it?
Describe your proposed implementation
My initial approach for the PR would be to create an object which behaves like the LUT externally, but which internally queries the supplied Colormap directly.
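Here is a rough sketch of that idea. It assumes the colormap is any callable that maps floats in [0, 1] to RGBA rows, as a matplotlib `Colormap` instance does. The class `DirectColormapLookup`, its `map_scalars` method and the toy `step_cmap` are hypothetical and are not an existing MNE or VTK API. The point is only that each value goes straight to the colormap, with no 256-entry sampling step in between.

```python
import numpy as np


class DirectColormapLookup:
    """Hypothetical LUT-like object that queries a colormap callable directly."""

    def __init__(self, colormap, vmin, vmax):
        # colormap: callable taking an array of floats in [0, 1] and returning (..., 4) RGBA
        self.colormap = colormap
        self.vmin = float(vmin)
        self.vmax = float(vmax)

    def map_scalars(self, values):
        """Return one RGBA color per value, with no intermediate LUT sampling."""
        values = np.asarray(values, dtype=float)
        frac = np.clip((values - self.vmin) / (self.vmax - self.vmin), 0.0, 1.0)
        return np.asarray(self.colormap(frac))


def step_cmap(frac):
    """Toy discrete colormap: red below the midpoint, blue from the midpoint up."""
    frac = np.asarray(frac, dtype=float)
    rgba = np.zeros(frac.shape + (4,))
    rgba[..., 3] = 1.0
    rgba[frac < 0.5, 0] = 1.0
    rgba[frac >= 0.5, 2] = 1.0
    return rgba


lookup = DirectColormapLookup(step_cmap, vmin=0, vmax=2)
colors = lookup.map_scalars([0.0, 0.99, 1.01, 2.0])  # red, red, blue, blue
```

Per-vertex colors produced this way may still be blended across triangle faces by the renderer. This sketch therefore addresses only the LUT sampling step.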
Describe possible alternatives
The only other option I can think of would be to increase the number of samples in the LUT beyond 256. However, this doesn't seem to solve the issue "once and for all"; it would only be worth doing if the proposed solution above turned out to be insufficiently performant.
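A back-of-the-envelope sketch shows why a larger LUT only shrinks the problem. It assumes the LUT is filled by nearest-color sampling and that the renderer linearly interpolates between adjacent entries. Under those assumptions, each boundary between two discrete colors produces one blended LUT interval, whose width is 1/(n_lut - 1) of the color range. The function name `blended_fraction` is hypothetical.

```python
def blended_fraction(n_colors, n_lut):
    """Fraction of the color range over which two discrete colors are blended.

    Assumes n_lut >= n_colors, nearest-color LUT filling, and linear interpolation
    between adjacent LUT entries: each of the n_colors - 1 color boundaries
    contributes exactly one interpolated interval of width 1 / (n_lut - 1).
    """
    if n_lut < n_colors:
        raise ValueError("need at least one LUT entry per color")
    return (n_colors - 1) / (n_lut - 1)


for n_lut in (256, 4096, 65536):
    # e.g. three candidate models -> three discrete colors
    print(n_lut, blended_fraction(3, n_lut))
```

The fraction falls as the LUT grows but never reaches zero for any finite LUT size.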
Additional context
More concretely, in our specific example we have created a forced-discrete colormap like this:
```python
import numpy as np
from matplotlib.colors import ListedColormap


class DiscreteListedColormap(ListedColormap):
    """Like ListedColormap, but without interpolation between values."""

    def __init__(self, colors: list, name='from_list', N=None, scale01: bool = False):
        """
        Args:
            colors (list): The discrete colors, one per index.
            scale01 (bool): True if the values will be supplied to the colormap in the
                range [0, 1] instead of the range [0, N-1].
        """
        self.scale01: bool = scale01
        super().__init__(colors=colors, name=name, N=N)

    def __call__(self, X, *args, **kwargs):
        if self.scale01:
            # Values are supplied between 0 and 1, so map them up to their corresponding
            # index in [0, N-1] (without modifying the caller's array in place)
            X = np.asarray(X, dtype=float) * (self.N - 1)
        rounded = np.round(X).astype(int)
        # Pass X positionally so that extra positional args (alpha, bytes) don't clash
        return super().__call__(rounded, *args, **kwargs)
```
But then, because this colormap is sampled 256 times and the results are interpolated when plotting, we lose control over exactly which values map to which colours.
> I believe the reason for this is that instead of referencing into the colormap directly to get the colors to plot, mne creates a LUT from 256 sample colours and then the renderer interpolates between them.
To me it looks like it might actually be a problem with `smoothing_steps`. If your STC comes from a source space that has been decimated (e.g., fewer than 10,000 vertices rather than more than 100,000), then `smoothing_steps` will smear the data onto neighboring vertices. For example, if your original data were integers, the smoothed data will not be. Usually you can work around this sort of thing.
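Here is a toy illustration of that effect. It uses plain NumPy neighbor averaging on a five-vertex chain, not MNE's actual `smoothing_steps` implementation, and the `smooth_once` helper is hypothetical. Two integer labels that meet at a boundary come out as fractional values after a single step. Rounding those values can then produce a label that did not appear anywhere in the original data.

```python
import numpy as np

# Toy 1-D "mesh": 5 vertices in a chain, each connected to its immediate neighbors
n_vert = 5
adjacency = np.zeros((n_vert, n_vert))
for i in range(n_vert - 1):
    adjacency[i, i + 1] = adjacency[i + 1, i] = 1.0

# Integer data, e.g. best-model index per vertex (model 0 on the left, model 2 on the right)
labels = np.array([0.0, 0.0, 2.0, 2.0, 2.0])


def smooth_once(values, adjacency):
    """Hypothetical smoothing step: average each vertex with itself and its neighbors."""
    weights = adjacency + np.eye(len(values))  # include the vertex itself
    return (weights @ values) / weights.sum(axis=1)


smoothed = smooth_once(labels, adjacency)  # fractional values near the boundary
print(smoothed)
print(np.round(smoothed))  # model 1 now appears even though no vertex had it
```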
For example, I once wanted to plot on fsaverage, for each vertex, the number of subjects (between, say, 0 and 16) that showed some effect. So I wrote a little helper, `cmap, clim = discretize_cmap("viridis", [0, 16])`, to upsample the cmap to 256 values while preserving the discretization:
```python
import numpy as np
from matplotlib import colors, pyplot as plt


def discretize_cmap(colormap, lims, transparent=True):
    """Discretize a colormap."""
    lims = np.array(lims, int)
    assert lims.shape == (2,)
    n_pts = lims[1] - lims[0] + 1
    assert n_pts > 0
    if n_pts == 1:
        vals = np.ones(256)
    else:
        # Each of the 256 LUT entries snaps to one of the discrete levels
        vals = np.round(np.linspace(-0.5, n_pts - 0.5, 256)) / (n_pts - 1)
    colormap = plt.get_cmap(colormap)(vals)
    if transparent:
        # Ramp alpha up over the lower half of the levels; the first entry is fully transparent
        colormap[:, 3] = np.clip((vals + 0.5 / n_pts) * 2, 0, 1)
        colormap[0, 3] = 0.
    colormap = colors.ListedColormap(colormap)
    # Limits centered on the integer levels: [min - 0.5, midpoint, max + 0.5]
    use_lims = [lims[0] - 0.5, (lims[0] + lims[1]) / 2., lims[1] + 0.5]
    return colormap, use_lims
```
It returns the `cmap`, `clim` pair that I could pass to `mne.viz.Brain` to get it to map the values appropriately. The short version of the trick is to make the color limits span [-0.5, 16.5] (i.e., `clim=dict(kind="values", lims=[-0.5, 8, 16.5])`, since `lims` takes three values) and to construct the LUT correspondingly. Then, when `smoothing_steps` smeared my values continuously between 0 and 16, anything between 0 and 0.5 would get the first color, anything between 0.5 and 1.5 the second color, and so forth. Maybe it would work for your use case? Maybe we can/should add this to an example so people can see how LUT tricks like this can be used to solve problems?
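This small NumPy sketch checks the binning arithmetic. It rebuilds the per-entry levels the same way `discretize_cmap` does (its `vals` before dividing by `n_pts - 1`) and looks up values over the limits [-0.5, 16.5]. It assumes a simple linear, nearest-entry lookup into the 256-entry LUT; the renderer's exact lookup may differ slightly. `level_for` is a hypothetical helper.

```python
import numpy as np

n_pts = 17             # integer levels 0..16, as in the subject-count example
lo, hi = -0.5, 16.5    # outer color limits, centered on the integer levels
n_lut = 256

# Discrete level stored in each LUT entry (same rounding as discretize_cmap)
lut_level = np.round(np.linspace(-0.5, n_pts - 0.5, n_lut)).astype(int)


def level_for(values):
    """Map data values to the discrete level of the nearest LUT entry."""
    frac = (np.asarray(values, dtype=float) - lo) / (hi - lo)
    idx = np.clip(np.round(frac * (n_lut - 1)).astype(int), 0, n_lut - 1)
    return lut_level[idx]


# Smoothed (non-integer) values land on the nearest integer's color
print(level_for([0.3, 0.7, 3.3, 8.1, 15.9]))
```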
One minor thing: I'm not sure why you need `DiscreteListedColormap`, as I think `ListedColormap` already has a "nearest" behavior:

```python
>>> import numpy as np
>>> from matplotlib.colors import ListedColormap
>>> ListedColormap(["r", "g", "b"])(np.linspace(0, 1, 10))
array([[1. , 0. , 0. , 1. ],
       [1. , 0. , 0. , 1. ],
       [1. , 0. , 0. , 1. ],
       [0. , 0.5, 0. , 1. ],
       [0. , 0.5, 0. , 1. ],
       [0. , 0.5, 0. , 1. ],
       [0. , 0. , 1. , 1. ],
       [0. , 0. , 1. , 1. ],
       [0. , 0. , 1. , 1. ],
       [0. , 0. , 1. , 1. ]])
```
Many thanks for the reply @larsoner!
Your suggested workaround was useful in getting us to a solution.
> I'm not sure why you need DiscreteListedColormap as I think ListedColormap already has a "nearest" behavior
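It was actually meant to get around `ListedColormap`'s odd dual float/int call interface:

```python
from matplotlib.colors import ListedColormap

cmap = ListedColormap(list("rgb"))
print(cmap(0.5) == cmap(1))  # True: the float 0.5 is scaled by N=3 to index 1, same as the int 1
print(cmap(1) == cmap(1.0))  # False: the int 1 is index 1, but the float 1.0 maps to the last index, 2
```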
But it's no longer needed!
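For reference, here is a pure-Python sketch of the lookup rule behind that output. It assumes matplotlib's documented behavior: integers are used as indices, and floats are treated as positions in [0, 1], scaled by `N` and truncated, with exactly 1.0 clamped to the last entry. The helper `listed_index` is hypothetical and skips matplotlib's under/over/bad handling, raising an error instead.

```python
def listed_index(x, n):
    """Hypothetical helper: which of n colors a ListedColormap-style lookup picks for x.

    Integers are used directly as indices. Floats are treated as fractions of the
    range, scaled by n and truncated, with exactly 1.0 clamped to the last color.
    Out-of-range inputs raise instead of using under/over colors.
    """
    if isinstance(x, float):
        scaled = x * n
        if scaled == n:
            scaled = n - 1  # 1.0 maps to the last color, not "over"
    else:
        scaled = x
    if scaled < 0 or scaled >= n:
        raise ValueError(f"{x!r} is outside the colormap range")
    return int(scaled)  # truncation, like an integer cast


# The two comparisons from the snippet above, for N = 3 colors
print(listed_index(0.5, 3), listed_index(1, 3))  # same index
print(listed_index(1, 3), listed_index(1.0, 3))  # different indices
```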
FWIW I would have found an example super useful, and I imagine that plotting discrete values like subject counts or best-model indices is not uncommon!