ColabFold

Code for coloring by PAE

Status: Open — seanrjohnson opened this issue 3 years ago · 0 comments

Here is some code for coloring the predicted structures by PAE (predicted aligned error), which I think is useful for visualizing domain boundaries.

Maybe others will also find it useful.

The legend code is not great; a gradient (colorbar-style) drawing would be better.

I'm not sure whether it's within the scope of what you want to do with this view in ColabFold, but mouseover tooltips showing the residue ID would also be nice.

Also, I realized I haven't tested this with multimers. I'll do that and update if needed.

#@title Display 3D structure {run: "auto"}
import py3Dmol
import glob
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import to_hex
from colabfold.colabfold import plot_plddt_legend
import json

def plot_pae_legend(dpi=100, max_pae=30):
  """Draw a compact one-row legend for the blue-white-red PAE coloring.

  Args:
    dpi: Resolution of the (1 x 0.1 inch) legend figure.
    max_pae: PAE value at the top (red) end of the scale. Keep this in sync
      with the value the structure colors are normalized by. The default of
      30 reproduces the original hard-coded '0 / 15 / 30' labels.

  Returns:
    The `matplotlib.pyplot` module with the legend as the current figure,
    so callers can chain `.show()`.
  """
  # Format with :g so 30 -> '30' and 15.0 -> '15', matching the old labels.
  labels = ['PAE:', '0', f'{max_pae / 2:g}', f'{max_pae:g}']
  # The first swatch is white, so the 'PAE:' entry reads as a title. The other
  # three approximate the 'bwr' colormap at 0, 0.5 and 1: blue, near-white, red.
  swatches = ["#FFFFFF", "#0000ff", "#fffefe", "#ff0000"]
  plt.figure(figsize=(1, 0.1), dpi=dpi)
  for swatch in swatches:
    # Zero-height bars only create legend handles; nothing visible is drawn.
    plt.bar(0, 0, color=swatch)
  plt.legend(labels, frameon=False,
             loc='center', ncol=len(labels),
             handletextpad=1,
             columnspacing=1,
             markerscale=0.5,)
  plt.axis(False)
  return plt



rank_num = 1 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
color = "PAE" #@param ["chain", "lDDT", "rainbow", "PAE"]
PAE_position =  20#@param {type:"integer"}
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# PAE scores are divided by this before indexing `pae_cmap`, so MAX_PAE maps
# to the top end of the colormap.
MAX_PAE = 30.0

# The form takes a 1-based residue position; PAE rows are indexed from 0.
PAE_position = PAE_position - 1

pae_cmap = get_cmap("bwr")

# `msa_mode`, `use_amber` and `jobname` are expected to be set by earlier
# notebook cells.
jobname_prefix = ".custom" if msa_mode == "custom" else ""
model_kind = "relaxed" if use_amber else "unrelaxed"
pdb_filename = f"{jobname}{jobname_prefix}_{model_kind}_rank_{rank_num}_model_*.pdb"

# glob order is arbitrary; sort so the displayed file is deterministic if the
# pattern ever matches more than one model.
pdb_file = sorted(glob.glob(pdb_filename))
if not pdb_file:
  # Fail here with a clear message instead of an IndexError inside show_pdb.
  raise FileNotFoundError(f"No PDB file matches {pdb_filename!r}")

def _pdb_residue_order(pdb_str):
  """Return the (chain_id, residue_number) of each residue in `pdb_str`, in file order.

  Reads the fixed-column PDB ATOM records: chain ID is column 22, residue
  number is columns 23-26, and insertion code is column 27. Each residue
  appears once, even though it has many atom lines. A blank chain ID is
  returned as ''.
  """
  keys = [
      (line[21:22].strip(), int(line[22:26]), line[26:27].strip())
      for line in pdb_str.splitlines()
      if line.startswith("ATOM")
  ]
  # dict.fromkeys de-duplicates while preserving first-seen order.
  return [(chain_id, resi) for chain_id, resi, _icode in dict.fromkeys(keys)]


def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color="lDDT", PAE_position=0):
  """Build a py3Dmol view of the structure in `pdb_file[0]`.

  Args:
    rank_num: Unused; kept for backward compatibility. The file shown is
      always `pdb_file[0]`, which the module-level glob selects by rank.
    show_sidechains: Overlay side-chain sticks (and a CA sphere for GLY).
    show_mainchains: Overlay backbone (C, O, N, CA) sticks.
    color: "lDDT" (B-factor gradient, 50-90), "rainbow", "PAE" or "chain".
      Any other value leaves the 3Dmol default style.
    PAE_position: 0-based row of the PAE matrix to color by. Out-of-range
      values are clamped. Only used when color == "PAE".

  Returns:
    The py3Dmol view; call `.show()` to render it.
  """
  with open(pdb_file[0], 'r') as pdb_fh:
    pdb_str = pdb_fh.read()
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  view.addModel(pdb_str, 'pdb')

  if color == "lDDT":
    # Color by the B-factor column. Presumably it holds per-residue pLDDT in
    # ColabFold output; confirm.
    view.setStyle({'cartoon': {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': 50, 'max': 90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color': 'spectrum'}})
  elif color == "PAE":
    with open(pdb_file[0][:-4] + "_scores.json", 'r') as scores_fh:
      pae_scores = json.load(scores_fh)["pae"]
    PAE_position = max(0, min(PAE_position, len(pae_scores) - 1))
    # PAE is indexed over all residues in order. Numbering may restart in each
    # chain of a multimer, so select by (chain, resi) rather than resi alone.
    # Assumes PAE rows follow the ATOM-record residue order; TODO confirm.
    residues = _pdb_residue_order(pdb_str)
    for (chain_id, resi), pae in zip(residues, pae_scores[PAE_position]):
      selection = {'resi': resi}
      if chain_id:
        selection['chain'] = chain_id
      view.setStyle(selection, {'cartoon': {'color': to_hex(pae_cmap(pae / MAX_PAE))}})
  elif color == "chain":
    # `queries` and `is_complex` come from earlier notebook cells.
    n_chains = len(queries[0][1]) + 1 if is_complex else 1
    palette = ["lime", "cyan", "magenta", "yellow", "salmon", "white", "blue", "orange"]
    for chain_id, chain_color in zip("ABCDEFGH"[:n_chains], palette):
      view.setStyle({'chain': chain_id}, {'cartoon': {'color': chain_color}})

  if show_sidechains:
    BB = ['C', 'O', 'N']
    view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': BB, 'invert': True}]},
                  {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    # Glycine has no side chain, so mark its CA with a sphere instead.
    view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
                  {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    # Proline's ring bonds to the backbone N, so keep N (and CA) for it.
    view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
                  {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
  if show_mainchains:
    BB = ['C', 'O', 'N', 'CA']
    view.addStyle({'atom': BB}, {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})

  view.zoomTo()
  return view


# Render the structure, then the legend for schemes that have one.
viewer = show_pdb(rank_num, show_sidechains, show_mainchains, color, PAE_position)
viewer.show()
legend_for_scheme = {"lDDT": plot_plddt_legend, "PAE": plot_pae_legend}
if color in legend_for_scheme:
  legend_for_scheme[color]().show()

— seanrjohnson, Jul 9, 2022, 03:07