muon icon indicating copy to clipboard operation
muon copied to clipboard

Enable `gene_symbols` argument in mu.pl.embedding

Open racng opened this issue 2 years ago • 1 comments

Is your feature request related to a problem? Please describe. sc.pl.embedding takes an argument gene_symbols that specifies which column in adata.var to look up the color keys in. This argument does not work in mu.pl.embedding: based on the source code, a new AnnData is created with the color values stored in adata.obs, so it has no way to access adata.var.

Describe the solution you'd like Redesign the way mu.pl.embedding generates the intermediate AnnData passed on to sc.pl.embedding. Perhaps use the sc.get.obs_df() function, which also takes layer and gene_symbols as arguments?

Describe alternatives you've considered User could write their own function to consolidate the basis and color keys into one adata and use sc.pl.embedding.

racng avatar Aug 30 '23 22:08 racng

Here is a proposed solution that I tested:

def get_uns_colors(data: Union[AnnData, MuData], key: str):
    """Look up the colour palette saved for ``key``.

    Parameters
    ----------
    data
        Object whose ``uns`` mapping is searched.
    key
        Obs key; the palette is read from ``data.uns[key + "_colors"]``.

    Returns
    -------
    The stored palette, or ``None`` when no such entry exists.
    """
    colors_key = key + '_colors'
    if colors_key not in data.uns:
        return None
    return data.uns[colors_key]

def to_dtype_list(x, dtype, n, none=True):
    """Broadcast ``x`` to a list of ``n`` values of type ``dtype``.

    Lets a keyword argument be given either once (applied to every colour key)
    or once per colour key.

    Parameters
    ----------
    x
        A single accepted value, or an iterable of ``n`` accepted values.
    dtype
        A type, or an iterable of types, that each value must be an instance
        of. A ``str`` is treated as a single entry rather than as an iterable.
    n
        Required length of the returned list.
    none
        If True, ``None`` is also accepted as a value.

    Returns
    -------
    list
        ``[x] * n`` if ``x`` is a single accepted value, otherwise ``list(x)``.

    Raises
    ------
    TypeError
        If ``x``, or one of its elements, is not of an accepted type.
    ValueError
        If ``x`` is an iterable whose length is not ``n``.
    """
    # Copy an iterable of types so appending NoneType never mutates the caller's
    # collection; previously this branch left ``dtypes`` unassigned.
    if not isinstance(dtype, Iterable) or isinstance(dtype, str):
        dtypes = [dtype]
    else:
        dtypes = list(dtype)
    if none:
        dtypes.append(type(None))
    accepted = tuple(dtypes)
    type_names = " or ".join(getattr(t, "__name__", repr(t)) for t in accepted)

    if isinstance(x, accepted):
        # A single value: repeat it once per key.
        return [x] * n
    if not isinstance(x, Iterable):
        raise TypeError(
            f"Expected {type_names} or an iterable of {n} such values, got {x!r}"
        )

    # Materialise first so generators are not exhausted by the type check.
    values = list(x)
    bad = [y for y in values if not isinstance(y, accepted)]
    if bad:
        raise TypeError(f"Expected values of type {type_names}, got {bad!r}")
    if len(values) != n:
        raise ValueError(f"Expected {n} values, got {len(values)}")
    return values


# Rewrite muon.pl.embedding to use gene_symbols
def _resolve_embedding_basis(data, basis):
    """Return ``(container, obsm_key)`` for a joint basis or a ``"mod:basis"`` key."""
    if basis not in data.obsm and "X_" + basis in data.obsm:
        basis = "X_" + basis
    if basis in data.obsm:
        return data, basis

    try:
        # maxsplit=1 so only the first colon separates modality from basis
        mod, basis_mod = basis.split(":", 1)
    except ValueError:
        raise ValueError(
            f"Basis {basis} is not present in the MuData object (.obsm)"
        ) from None

    if mod not in data.mod:
        raise ValueError(
            f"Modality {mod} is not present in the MuData object with modalities {', '.join(data.mod)}"
        )

    adata = data.mod[mod]
    if basis_mod not in adata.obsm:
        if "X_" + basis_mod in adata.obsm:
            basis_mod = "X_" + basis_mod
        elif len(adata.obsm) > 0:
            raise ValueError(
                f"Basis {basis_mod} is not present in the modality {mod} with embeddings {', '.join(adata.obsm)}"
            )
        else:
            raise ValueError(
                f"Basis {basis_mod} is not present in the modality {mod} with no embeddings"
            )
    return adata, basis_mod


def _condition_label(layer, use_raw):
    """Return the suffix that marks a feature as drawn from ``.raw`` and/or a layer."""
    labels = []
    if use_raw:
        labels.append('use_raw')
    if layer is not None:
        labels.append(layer)
    return '_'.join(labels)


def embedding(
    data: MuData,
    basis: str,
    color: Optional[Union[str, Sequence[str]]] = None,
    layer: Optional[Union[str, Sequence[str]]] = None,
    gene_symbols: Optional[Union[str, Sequence[str]]] = None,
    use_raw: Optional[Union[bool, Sequence[bool]]] = False,
    **kwargs
):
    """Scatter plot in a joint or per-modality embedding, coloured by any key.

    Each ``color`` key is looked up first in the joint ``data.obs``. Otherwise it
    must have the form ``"mod:key"`` and is fetched from modality ``mod`` with
    ``sc.get.obs_df``, honouring ``layer``, ``gene_symbols`` and ``use_raw``.
    Each of these three may be a single value (applied to every key) or a
    sequence with one entry per ``color`` key. Panels are drawn in the order
    the keys were given.

    Parameters
    ----------
    data
        MuData object. An AnnData is passed straight to ``sc.pl.embedding``.
    basis
        Key in ``data.obsm`` or ``"mod:basis"``. An ``"X_"`` prefix is added
        when only the prefixed key exists.
    color, layer, gene_symbols, use_raw
        See above.
    **kwargs
        Forwarded to ``sc.pl.embedding``.

    Returns
    -------
    Whatever ``sc.pl.embedding`` returns.

    Raises
    ------
    ValueError
        If the basis, a modality, or a non-modality colour key cannot be found.
    TypeError
        If ``color`` is neither a string nor an iterable, or a per-key argument
        has the wrong type.

    After plotting, any ``<key>_colors`` entry present in the temporary
    plotting object is copied back to ``data.uns`` (joint keys) or to
    ``data.mod[mod].uns`` (modality keys).
    """
    if isinstance(data, AnnData):
        return sc.pl.embedding(
            data, basis=basis, color=color, use_raw=use_raw, layer=layer,
            gene_symbols=gene_symbols, **kwargs
        )

    adata, basis_mod = _resolve_embedding_basis(data, basis)

    # Subset joint obs to the embedding's observations
    obs = data.obs.loc[adata.obs.index.values].copy()

    if color is None:
        ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp)
        return sc.pl.embedding(ad, basis=basis_mod, **kwargs)

    if isinstance(color, str):
        keys = [color]
    elif isinstance(color, Iterable):
        keys = list(color)
    else:
        raise TypeError("Expected color to be a string or an iterable.")

    # Broadcast per-key arguments to one entry per colour key
    n = len(keys)
    layers = to_dtype_list(layer, str, n, none=True)
    symbols = to_dtype_list(gene_symbols, str, n, none=True)
    raws = to_dtype_list(use_raw, bool, n, none=True)

    # Column of the plotting obs for each requested key, in the requested order.
    # None entries are passed through to sc.pl.embedding unchanged.
    plot_keys = [None] * n
    joint_keys = []
    # mod -> (layer, gene_symbols, use_raw) -> [(position in plot_keys, key_mod), ...]
    mod2keys = {m: defaultdict(list) for m in data.mod}
    uns = {}
    for pos, (key, key_layer, key_symbols, key_raw) in enumerate(
        zip(keys, layers, symbols, raws)
    ):
        if key is None:
            continue

        # Key in joint obs
        if key in obs:
            plot_keys[pos] = key
            joint_keys.append(key)
            palette = get_uns_colors(data, key)
            if palette is not None:
                uns[key + '_colors'] = palette
            continue

        # Key in a modality
        try:
            mod, key_mod = key.split(":", 1)
        except ValueError:
            raise ValueError(
                f"Key {key} is not present in the MuData object (.obs)"
            ) from None
        if mod not in mod2keys:
            raise ValueError(
                f"Modality {mod} is not present in the MuData object with modalities {', '.join(data.mod)}"
            )
        mod2keys[mod][(key_layer, key_symbols, key_raw)].append((pos, key_mod))

        palette = get_uns_colors(data.mod[mod], key_mod)
        if palette is not None:
            uns[f"{mod}:{key_mod}_colors"] = palette

    # Fetch modality features once per unique argument combination and add them to obs
    for m, groups in mod2keys.items():
        mod_obs = data.mod[m].obs
        for (key_layer, key_symbols, key_raw), entries in groups.items():
            # Fetch each feature once, even if it was requested repeatedly
            unique = list(dict.fromkeys(key_mod for _, key_mod in entries))
            df = sc.get.obs_df(
                data.mod[m], keys=unique, layer=key_layer,
                gene_symbols=key_symbols, use_raw=key_raw,
            )
            cond = _condition_label(key_layer, key_raw)
            # obs columns don't depend on layer/use_raw; only matrix features
            # get a suffix, and only when there is one (no bare trailing "\n").
            names = {
                key_mod: (
                    f"{m}:{key_mod}\n{cond}"
                    if cond and key_mod not in mod_obs
                    else f"{m}:{key_mod}"
                )
                for key_mod in unique
            }
            df = df.rename(columns=names)
            # Skip columns already added by another argument combination
            # (e.g. the same obs column requested under two layers), which
            # would otherwise be suffixed _x/_y by merge.
            df = df.loc[:, [c for c in df.columns if c not in obs.columns]]
            obs = obs.merge(df, left_index=True, right_index=True, how='left')
            for pos, key_mod in entries:
                plot_keys[pos] = names[key_mod]

    # Plot
    ad = AnnData(obs=obs, obsm=adata.obsm, uns=uns)
    retval = sc.pl.embedding(ad, basis=basis_mod, color=plot_keys, **kwargs)

    # Copy palettes back. This is best effort: continuous keys have no palette,
    # hence the ignored KeyError.
    for key in joint_keys:
        try:
            data.uns[f"{key}_colors"] = ad.uns[f"{key}_colors"]
        except KeyError:
            pass
    for m, groups in mod2keys.items():
        for entries in groups.values():
            for _, key_mod in entries:
                try:
                    data.mod[m].uns[f"{key_mod}_colors"] = ad.uns[f"{m}:{key_mod}_colors"]
                except KeyError:
                    pass
    return retval

Example usage:

# Example: four panels on the 'rna:umap' embedding. The per-key lists for
# gene_symbols and layer line up position by position with the colour keys,
# so 'prot:CD4' is drawn twice, once from each layer. `sw` presumably is the
# commenter's module holding the embedding() above; confirm.
sw.pl.embedding(mdata, 'rna:umap', [
	'prot:CD4', 'prot:CD4', 'rna:CD4', 'rna:sample'], 
	gene_symbols=['symbols', 'symbols', None, None], 
	layer=['raw', 'cellbender', None, None])

3077a023-96fe-4102-a263-d1489b13f9d1

Checking that the color palette was updated:

# Inspect the palettes copied back after plotting. 'rna:sample' was found in
# the joint mdata.obs, so its palette is stored in mdata.uns, not in
# mdata['rna'].uns.
mdata.uns
# Output:
# {'rna:sample_colors': ['#1f77b4', '#ff7f0e']}

There are some non-ideal behaviors that could be fixed if needed, but they don't affect the functionality:

  • The color keys are not plotted in the order given: they are regrouped by modality, with joint obs keys first
  • The color palette for a modality-specific categorical variable mod:key is added to mdata.uns['mod:key_colors'] instead of mdata[mod].uns['key_colors']. This is because mod:key can be found in mdata.obs, so it is treated as a joint obs key.

racng avatar Aug 31 '23 02:08 racng