
Counting how many forward passes/steps were done when using PAIN

Open jivanph opened this issue 2 years ago • 4 comments

I wanted to ask if there's a way to count how many forward passes/steps are done when using PAIN, to contrast it with standard decoding.

jivanph avatar Jan 30 '24 12:01 jivanph

On a different note, what are the parameters for the tree object? How many branches are made and how deep does the tree go?

jivanph avatar Jan 30 '24 14:01 jivanph

You can count the steps in two ways. One is to turn on debug_lookahead, which prints debug info for each step so you can count the steps manually. The other is to turn on return_dict_in_generate in the model.generate method; the kwargs of the outputs will then contain a decoding summary, and len(kwargs['dls']) is the step count.

We use different parameters for different tasks. As mentioned in the readme of our repo, we use decoding_length=128 (i.e., forward token count) and branch_length=32 (i.e., tree depth) for RAG tasks, and decoding_length=64 and branch_length=8 for the dolly and GSM8K tasks. We do not use a branch count parameter, as we care more about the actual token count in a forward pass than about logical branches.

zheyishine avatar Feb 01 '24 02:02 zheyishine
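
To contrast this with standard decoding, here is a minimal sketch that works on the dls list described above. It assumes dls holds one entry per forward pass, and that standard autoregressive decoding needs one forward pass per generated token; the helper names and the sample values are hypothetical and not part of the repo.

```python
def count_steps(dls):
    """Number of forward passes: dls is assumed to hold one entry per step."""
    return len(dls)


def step_reduction(num_generated_tokens, dls):
    """Ratio of standard-decoding steps to lookahead steps.

    Assumption: standard decoding uses one forward pass per generated token.
    """
    steps = count_steps(dls)
    if steps == 0:
        raise ValueError("dls is empty, so no decoding steps were recorded")
    return num_generated_tokens / steps


# Hypothetical decoding lengths, one per forward pass.
example_dls = [64, 64, 40, 64]
print(count_steps(example_dls))         # forward passes with lookahead
print(step_reduction(20, example_dls))  # 20 generated tokens in 4 passes
```

A ratio above 1 means fewer forward passes than standard decoding for the same output length.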

Thank you so much for your response. This helped me greatly. If I understand correctly, if I want to count how many draft tokens in total were used when using PAIN, I could just compute sum(kwargs['dls'])?

jivanph avatar Feb 01 '24 15:02 jivanph

It should be sum(kwargs['dls'])-len(kwargs['dls']). Each decoding_length (i.e., each entry of dls) is composed of the next token plus the draft tokens, so we should subtract 1 for every step.

zheyishine avatar Feb 02 '24 05:02 zheyishine
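
The correction above can be sketched as a small helper. It assumes each entry of dls counts one next token plus that step's draft tokens, as stated in the reply; the function name and the sample list are hypothetical.

```python
def draft_token_count(dls):
    """Total draft tokens across all steps.

    Each entry of dls is assumed to include exactly one next token,
    so one token is subtracted per step.
    """
    return sum(dls) - len(dls)


# Hypothetical decoding lengths, one per forward pass.
example_dls = [64, 64, 40, 64]
print(draft_token_count(example_dls))  # 232 tokens in total, minus 4 next tokens
```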