torchtitan
[Draft] PP tracing works
The purpose of this PR is to show:
- Only a one-line change is needed: remove this line from `forward`:
  ```python
  self.freqs_cis = self.freqs_cis.to(h.device)
  ```
  Reason 1: compile does not support in-place attribute mutation. Reason 2: moving tensors between devices inside `forward` is not good practice anyway; it is better done during `__init__`.
  Alternatively, if `freqs_cis` is registered as a buffer with `register_buffer` in `__init__`, the in-place mutation is okay.
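- Show how to use `annotate_split_points` to mark pipeline stage boundaries:
  ```python
  # `model`, `pp_degree` and `layers_per_rank` are assumed to be set up beforehand;
  # `annotate_split_points` and `SplitPoint` come from PiPPy (import path varies by version).
  # Stage i (i >= 1) begins at layer i * layers_per_rank; stage 0 starts at the model input.
  for i in range(1, pp_degree):
      annotate_split_points(
          model, {f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING}
      )
  ```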
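The split-point loop relies on simple arithmetic: stage `i` begins at layer `i * layers_per_rank`, and stage 0 needs no marker because it starts at the model input. The sketch below computes the layer names that end up as keys in the `annotate_split_points` calls, without touching a real model. `stage_start_layers` and `stage_layer_ranges` are hypothetical helpers, not part of this PR or of PiPPy, and they assume transformer blocks named `layers.0`, `layers.1`, ... and a layer count that divides evenly by `pp_degree`.

```python
from typing import Dict, List


def _layers_per_rank(n_layers: int, pp_degree: int) -> int:
    # Assumption: every pipeline stage owns the same number of layers.
    if pp_degree < 1 or n_layers % pp_degree != 0:
        raise ValueError(
            f"{n_layers} layers cannot be split evenly into {pp_degree} stages"
        )
    return n_layers // pp_degree


def stage_start_layers(n_layers: int, pp_degree: int) -> List[str]:
    """Names of the layers where stages 1..pp_degree-1 begin (the split points)."""
    layers_per_rank = _layers_per_rank(n_layers, pp_degree)
    # Mirrors `for i in range(1, pp_degree)`: stage 0 gets no split point.
    return [f"layers.{i * layers_per_rank}" for i in range(1, pp_degree)]


def stage_layer_ranges(n_layers: int, pp_degree: int) -> Dict[int, range]:
    """Map each stage index to the range of layer indices it owns."""
    layers_per_rank = _layers_per_rank(n_layers, pp_degree)
    return {
        stage: range(stage * layers_per_rank, (stage + 1) * layers_per_rank)
        for stage in range(pp_degree)
    }


print(stage_start_layers(8, 4))  # ['layers.2', 'layers.4', 'layers.6']
```

With 8 layers and `pp_degree = 4`, three split points are annotated and each stage owns two consecutive layers; `pp_degree = 1` yields no split points at all.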
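For the first item (dropping the device move from `forward`), here is a minimal sketch of the buffer-based alternative, assuming a toy model. `ToyModel` and `precompute_freqs_cis` below are hypothetical, simplified stand-ins, not torchtitan's actual code. Because `freqs_cis` is registered with `register_buffer` in `__init__`, `model.to(device)` moves it together with the parameters, so `forward` never has to reassign the attribute.

```python
import torch
import torch.nn as nn


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Hypothetical simplified RoPE table: complex rotations of shape (end, dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end).float()
    angles = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(angles), angles)


class ToyModel(nn.Module):
    def __init__(self, vocab_size: int = 32, dim: int = 16, max_seq_len: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # Registered as a buffer in __init__: Module.to()/.cuda() move it along
        # with the parameters. persistent=False keeps it out of state_dict,
        # since it can always be recomputed.
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(dim, max_seq_len), persistent=False
        )

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # (bsz, seqlen, dim)
        # No `self.freqs_cis = self.freqs_cis.to(h.device)` here: the buffer is
        # already on the model's device and forward never reassigns attributes.
        seqlen = tokens.shape[1]
        freqs_cis = self.freqs_cis[:seqlen]  # (seqlen, dim // 2)
        # Rotate by viewing adjacent feature pairs as complex numbers.
        h_complex = torch.view_as_complex(h.float().reshape(*h.shape[:-1], -1, 2))
        rotated = h_complex * freqs_cis  # broadcasts over the batch dimension
        return torch.view_as_real(rotated).flatten(2).type_as(h)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ToyModel().to(device)  # moves parameters and the freqs_cis buffer
out = model(torch.randint(0, 32, (2, 5), device=device))
assert model.freqs_cis.device == out.device
```

The sketch does not invoke the compiler itself; it only illustrates the initialization pattern that makes the forward-time device move unnecessary.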