Medusa
Are Medusa Heads computed in parallel or serially?
Hello authors,
While reading your code, I noticed that the multiple Medusa heads you proposed appear to compute their results serially rather than in parallel:
    for i in range(self.medusa):
        medusa_logits.append(self.medusa_head[i](hidden_states))
(Although the later heads don't use the results from the earlier heads, the results are still obtained one head at a time in a for loop.)
I'm wondering whether I've misunderstood this, or whether Medusa currently computes the heads' results serially.
Could you please clarify this for me?
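
To make the question concrete, here is a minimal sketch of what I mean by heads that are independent but dispatched in a loop. It is not the Medusa implementation: `make_head` and the toy weights are hypothetical, and plain Python lists stand in for tensors. It only shows that when no head reads another head's output, the loop order is a scheduling detail and the same results could in principle be computed concurrently.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical toy "head": an independent linear map over the hidden states.
# This stands in for one entry of self.medusa_head; it is not the real module.
def make_head(scale, in_dim=3, out_dim=2):
    weights = [[scale * (r + c + 1) for c in range(in_dim)] for r in range(out_dim)]

    def head(hidden_states):
        # One dot product per output row.
        return [sum(w * h for w, h in zip(row, hidden_states)) for row in weights]

    return head

medusa_head = [make_head(s) for s in (1.0, 2.0, 3.0)]
hidden_states = [0.5, -1.0, 2.0]

# Serial dispatch, mirroring the loop quoted above.
serial_logits = []
for i in range(len(medusa_head)):
    serial_logits.append(medusa_head[i](hidden_states))

# Concurrent dispatch: valid only because every head takes the same input
# and none of them depends on another head's output.
with ThreadPoolExecutor() as pool:
    parallel_logits = list(pool.map(lambda head: head(hidden_states), medusa_head))

# Both dispatch strategies give identical per-head results, in the same order.
assert serial_logits == parallel_logits
```

So my question is really about which of these two dispatch patterns the released code is meant to follow, not about whether the heads depend on each other.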