How to run it with multiple devices?
./RWKU/KnowledgeCircuits-main/KnowledgeCircuits-main/transformer_lens/components.py:625, in AbstractAttention.forward(self, query_input, key_input, value_input, past_kv_cache_entry, additive_attention_mask, attention_mask)
616 result = self.hook_result(
617 bnb.matmul_4bit(
618 z.reshape(z.shape[0], z.shape[1], self.cfg.d_model),
(...)
622 )
623 )
624 else:
--> 625 result = self.hook_result(
626 einsum(
627 "batch pos head_index d_head, \
628 head_index d_head d_model -> \
629 batch pos head_index d_model",
630 z,
631 self.W_O,
632 )
633 )
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
If I set n_devices=2 when loading the model with HookedTransformer.from_pretrained, I get the error above. The exact call (reformatted below for readability) is:
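HookedTransformer.from_pretrained(
    model_name=LLAMA_2_7B_CHAT_PATH,
    device="cuda",
    n_devices=2,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)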
It does not currently support multi-GPU; I plan to add this feature soon and will let you know when I get to it. In the meantime, you can try moving the tensors to the same device as a temporary workaround. Thanks!
Thank you for your interest! I hope this message finds you well.
Memory Requirements for the LLaMA2-7B-Chat Model:
- Running on a single GPU requires about 57,116 MiB of GPU memory.
- Running the knowledge_eap.ipynb file takes approximately 3-4 minutes.
Regarding Multi-GPU Support:
- Because the code builds on TransformerLens, there are some issues in multi-GPU environments.
- Therefore, multi-GPU support is currently not available; for now, please load the model on a single GPU (a minimal sketch follows below).
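A minimal single-GPU loading sketch, assuming the same flags as in the question above and that a single GPU has enough memory (the variable name model is only illustrative):
model = HookedTransformer.from_pretrained(
    model_name=LLAMA_2_7B_CHAT_PATH,  # path to the LLaMA2-7B-Chat weights
    device="cuda",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)  # omitting n_devices keeps the whole model on one device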
Future Optimizations:
- If the opportunity arises, we may address this in a future update.
Suggestions for User Custom Modifications:
- If you are interested, you can modify TransformerLens yourself.
- A simple approach is to use the .to() method to move tensors to the same device wherever you see the error "Expected all tensors to be on the same device, but found at least two devices."
Example Code:
# Move z onto W_O's device so both einsum operands live on the same GPU.
result = self.hook_result(
    einsum(
        "batch pos head_index d_head, \
        head_index d_head d_model -> \
        batch pos head_index d_model",
        z.to(self.W_O.device),
        self.W_O,
    )
)
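If you end up patching several call sites, a small helper keeps the edits uniform. This is only a sketch; the name ensure_same_device is hypothetical and not part of TransformerLens:
import torch

def ensure_same_device(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Copy `tensor` to the device of `reference` only when the devices differ.
    if tensor.device != reference.device:
        return tensor.to(reference.device)
    return tensor

Inside AbstractAttention.forward, calling z = ensure_same_device(z, self.W_O) before the einsum has the same effect as the inline .to() call above.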
Thank you for your understanding and support! If you manage to resolve the issue, feel free to submit a pull request. Your contributions are always welcome!
Hi, do you have any further issues?