Discussion on attention score visualization for the Pi0 model on the RoboTwin benchmark
Pi0 is great work and I have seen great results on real devices. I want to explore the model's attention and visualize it to get a better understanding of Pi0. I ran the simulation evaluation environment provided by RoboTwin and returned the attention scores from inside Pi0. I got some nice-looking results, but perhaps my parameters are not tuned correctly: the attention is not well focused and is instead scattered around the padding. Taking the bottle-grasping task as an example, here are some of my visualization results:
Here is the code where I process the attention score matrix:
pred_actions, attn_maps = model.get_action()

import numpy as np
from openpi_client import image_tools

# Resize each camera image to the model's input resolution so the heatmaps
# can be aligned with the image patches.
image_size = 224
base_img = {}
for cam_name in ['head_cam', 'left_cam', 'right_cam']:
    img_chw = obs[cam_name]                      # (3, 240, 320)
    img_hwc = np.transpose(img_chw, (1, 2, 0))   # (240, 320, 3)
    img_hwc_uint8 = (img_hwc * 255).astype(np.uint8)
    img = image_tools.resize_with_pad(img_hwc_uint8, image_size, image_size)
    base_img[cam_name] = img

# Aggregate the attention scores over the selected diffusion timesteps,
# layers, heads, and suffix query tokens into one heatmap per camera.
heatmaps = self.get_aggregated_heatmaps(
    attn_maps=attn_maps,
    timesteps=list(range(cfg["attention"]['timesteps']['start'],
                         cfg["attention"]['timesteps']['end'])),
    layers=list(range(cfg["attention"]['layers']['start'],
                      cfg["attention"]['layers']['end'])),
    heads=list(range(cfg["attention"]['heads']['start'],
                     cfg["attention"]['heads']['end'])),
    query_indices=list(range(cfg["attention"]['query_indices']['start'],
                             cfg["attention"]['query_indices']['end'])),
    patch_grid=16,
    image_count=3,
)

# Overlay each heatmap on its camera image and save the result.
self.overlay_heatmaps_on_images(
    base_img=base_img,
    heatmaps=heatmaps,
    cam_names=['head_cam', 'left_cam', 'right_cam'],
    image_size=224,
    colormap='jet',
    alpha=1.0,
    save_dir='~/RoboTwin/attn_visual/' + str(args['task_name']) + '_' + str(time_stamp),
    step_cnt=step_cnt,
)
Here attn_maps is returned layer by layer from gemma.py. Its shape is (10, 18, 1, 8, 51, 867), where the axes are the diffusion timestep, layer, batch, head, suffix (query) tokens, and all (key) tokens, respectively. The original probs are computed as:
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
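To make the question concrete, here is a minimal sketch of the kind of aggregation I have in mind (simplified, not my exact code; in particular, the key-axis layout, i.e. that the first 3 x 256 key positions are the image patch tokens for the three cameras followed by the language and suffix tokens, is an assumption that should be checked against how openpi builds the prefix):

import numpy as np

def aggregate_heatmaps_sketch(attn_maps, timesteps, layers, heads, query_indices,
                              patch_grid=16, image_count=3):
    # attn_maps: (T, L, B, H, Q, K) = (10, 18, 1, 8, 51, 867)
    # Assumption: the first image_count * patch_grid**2 key positions are the
    # image patch tokens (3 cameras x 256 patches each).
    attn = np.asarray(attn_maps)
    sel = attn[np.ix_(timesteps, layers)][:, :, 0]        # (t, l, H, Q, K), batch 0
    sel = sel[:, :, heads][:, :, :, query_indices]        # (t, l, h, q, K)
    mean_over_keys = sel.mean(axis=(0, 1, 2, 3))          # (K,)
    n_patches = patch_grid * patch_grid
    heatmaps = {}
    for i, cam in enumerate(['head_cam', 'left_cam', 'right_cam'][:image_count]):
        patch_attn = mean_over_keys[i * n_patches:(i + 1) * n_patches]
        patch_attn = patch_attn / (patch_attn.sum() + 1e-8)  # renormalize over image tokens only
        heatmaps[cam] = patch_attn.reshape(patch_grid, patch_grid)
    return heatmaps

Renormalizing over the image tokens alone (rather than over all 867 keys) at least keeps the heatmap from being dominated by the attention mass that goes to the non-image tokens.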
Does anyone have similar experience visualizing Pi0 attention scores? What might be wrong with my processing, and how should I choose which layers to visualize? Any help would be greatly appreciated!
Hi, I just got the Pi0 model running and would like to evaluate it on the RoboTwin platform. Could you describe the pipeline you used? Thank you!
This is cool! I have no experience visualizing the attention matrices -- one thing I'd suggest is to look at individual matrices rather than aggregates, to see whether any of them (e.g. for specific timesteps / layers / attention heads) look more informative.
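Untested, but something along these lines could work, assuming the axis order you described (timestep, layer, batch, head, query, key) and that the first 16 x 16 key positions correspond to the head camera's patches:

import numpy as np

attn = np.asarray(attn_maps)               # (T, L, B, H, Q, K) = (10, 18, 1, 8, 51, 867)
n_patches = 16 * 16                        # patches per camera, assuming a 16x16 grid
query = 0                                  # e.g. the first suffix/action token

for t in range(attn.shape[0]):
    for layer in range(attn.shape[1]):
        for head in range(attn.shape[3]):
            scores = attn[t, layer, 0, head, query]             # (867,) for one query token
            head_cam_map = scores[:n_patches].reshape(16, 16)   # first camera's patch grid
            # plot/save head_cam_map tagged with (t, layer, head) and compare by eye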
Hi, were you able to get the heatmap visualization working?