erichan1
It's definitely possible that this is a BT bug. Thanks for flagging! Let me look into it.
Hi! Interested in int8 as well. To clarify for int8_mode=1 (weight-only int8): during inference, are you casting the int8 weights back to fp16, doing an fp16 linear, and then deleting the fp16...
Just to confirm: for GPT models, is int8_mode=1 the one shown here https://github.com/NVIDIA/FasterTransformer/blob/main/docs/images/workflow-of-int8-inference.png? And is int8_mode=2 switched to SmoothQuant in GPT models instead of the one shown in the picture?...
> (not the author, but) I think the graph was for int8 BERT. The two int8 modes in BERT are defined differently than those in GPT. There hasn't been a...
> `int8_mode = 1` is weight-only quantization. As @erichan1 says, we quantize the weight to int8 during weight conversion. During inference, we load the int8 weight, dequantizing it to fp16...
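To make that flow concrete, here's a rough PyTorch sketch of weight-only int8 as I understand it (quantize once at weight-conversion time, dequantize to fp16 before each GEMM at inference). This is not FasterTransformer's actual kernel; the function names and the per-output-channel symmetric scaling are just my illustration:

```python
import torch

def quantize_weight_int8(w: torch.Tensor):
    """Offline step (weight conversion): per-output-channel symmetric quantization to int8."""
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0        # one scale per output channel
    w_int8 = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)
    return w_int8, scale.to(w.dtype)

def weight_only_int8_linear(x: torch.Tensor, w_int8: torch.Tensor, scale: torch.Tensor):
    """Inference step: dequantize the stored int8 weight back to the activation dtype,
    run an ordinary (fp16) GEMM, and let the temporary dequantized copy be freed."""
    w_deq = w_int8.to(x.dtype) * scale                        # int8 -> fp16 dequantization
    return x @ w_deq.t()                                      # plain fp16 linear
```

So in this mode the memory savings come from storing int8 weights, while the GEMM itself still runs in fp16.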
Closing as stale.
I believe the issue is that you need to pack your input sequences together, i.e. instead of your input looking like (batch_size, max_seq_len, ...), it should be (batch_size * seq_len, ...),...
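For reference, a minimal PyTorch sketch of what I mean by packing, assuming you have a per-sequence `lengths` tensor. The names (`pack_padded_batch`, `cu_seqlens`) are mine, and the exact offsets format the kernels expect may differ:

```python
import torch

def pack_padded_batch(x_padded: torch.Tensor, lengths: torch.Tensor):
    """Turn a (batch_size, max_seq_len, hidden) padded batch into a packed
    (total_real_tokens, hidden) tensor plus cumulative sequence offsets."""
    batch_size, max_seq_len, _ = x_padded.shape
    # mask of real (non-pad) positions
    mask = torch.arange(max_seq_len, device=x_padded.device)[None, :] < lengths[:, None]
    packed = x_padded[mask]                                   # drops the pad tokens
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=x_padded.device)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)             # boundaries of each sequence
    return packed, cu_seqlens
```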
Unfortunately I don't believe we have single-GPU training times, and it's difficult to extrapolate. So no good solution, I'm afraid. @johntran-nv thoughts?
cc @janekl Any thoughts on this?
cc @janekl @samiwilf any thoughts here on the checkpointing issue? I can tag torchrec folks if this continues to be a blocker.