erichan1

13 comments by erichan1

It's definitely possible that this is a BT bug. Thanks for flagging! Let me look into it.

Hi! Interested in int8 as well. To clarify, for int8_mode=1 (weight-only int8): during inference, are you casting int8 weights back to fp16, doing an fp16 linear, then deleting the fp16...

Just to confirm: for GPT models, int8_mode=1 is the one shown here https://github.com/NVIDIA/FasterTransformer/blob/main/docs/images/workflow-of-int8-inference.png? And int8_mode=2 has been switched to SmoothQuant in GPT models, rather than the mode shown in the picture?...
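For context on the SmoothQuant mention: this is not FasterTransformer's actual kernel code, just a minimal sketch of the SmoothQuant idea (migrate activation outliers into the weights with a per-input-channel scale before quantizing both to int8). The function name `smoothquant_scales` and the calibration setup are illustrative assumptions.

```python
import torch

def smoothquant_scales(act_absmax: torch.Tensor, weight: torch.Tensor, alpha: float = 0.5):
    """Per-input-channel smoothing factors s_j = max|X_j|^alpha / max|W_j|^(1-alpha).
    act_absmax: per-channel activation maxima collected offline on calibration data.
    weight: (out_features, in_features)."""
    w_absmax = weight.abs().amax(dim=0)                        # per input channel
    s = act_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)
    return s.clamp(min=1e-5)

# The linear layer is algebraically unchanged:
#   y = (x / s) @ (W * s)^T  ==  x @ W^T
# but x/s has fewer outliers, so both x/s and W*s quantize well to int8.
torch.manual_seed(0)
W = torch.randn(256, 64)
x = torch.randn(4, 64) * torch.linspace(0.1, 10.0, 64)        # simulate outlier channels
s = smoothquant_scales(x.abs().amax(dim=0), W)
y_ref = x @ W.t()
y_smooth = (x / s) @ (W * s).t()                              # s is folded into the weight offline
print(torch.allclose(y_ref, y_smooth, atol=1e-4))             # True
```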

> (not the author, but) I think the graph was for int8 BERT. The two int8 modes in BERT are defined differently than those in GPT. There hasn't been a...

> `int8_mode = 1` is weight only quantization. As @erichan1 says, we quantize the weight to int8 during converting weight. During inference, we load int8 weight, dequantizing it to fp16...
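Putting the two comments above together, here is a minimal sketch of what weight-only int8 (int8_mode=1) amounts to, assuming per-output-channel symmetric quantization: quantize the fp16 weight to int8 once at conversion time, then at inference dequantize it back to fp16 and run an ordinary fp16 GEMM. FasterTransformer fuses this in CUDA kernels; the function names below are illustrative, not its API.

```python
import torch

def quantize_weight_int8(w_fp16: torch.Tensor):
    """Symmetric per-output-channel int8 quantization of an fp16 weight (out_features, in_features).
    Done once, when converting the checkpoint."""
    scale = w_fp16.abs().amax(dim=1, keepdim=True).float() / 127.0
    w_int8 = torch.clamp((w_fp16.float() / scale).round(), -128, 127).to(torch.int8)
    return w_int8, scale.half()

def weight_only_int8_linear(x_fp16: torch.Tensor, w_int8: torch.Tensor, scale_fp16: torch.Tensor):
    """Dequantize the int8 weight to fp16 and run a plain fp16 matmul.
    Only the int8 weight stays resident; the fp16 copy is temporary."""
    w_fp16 = w_int8.half() * scale_fp16          # int8 -> fp16 (dequantize)
    return x_fp16 @ w_fp16.t()                   # ordinary fp16 GEMM

# usage (in practice these tensors live on the GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.randn(4096, 1024, dtype=torch.float16, device=device)
x = torch.randn(8, 1024, dtype=torch.float16, device=device)
w_q, s = quantize_weight_int8(w)                 # at weight-conversion time
y = weight_only_int8_linear(x, w_q, s)           # every forward pass
```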

I believe the issue is that you need to pack your input sequences together, i.e., instead of your input having shape (batch_size, max_seq_len, ...), it should be (batch_size * seq_len, ...),...
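A minimal sketch of what "packing" means here, assuming a padded batch plus an attention mask; the helper `pack_padded_batch` and the cu_seqlens layout are illustrative of what padding-free kernels typically expect, not a specific library's API.

```python
import torch

def pack_padded_batch(hidden: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten a padded batch (batch, max_seq_len, hidden) into a packed
    (total_tokens, hidden) tensor plus cumulative sequence lengths."""
    seq_lens = attention_mask.sum(dim=1)                 # real length of each sequence
    packed = hidden[attention_mask.bool()]               # drop padded positions
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return packed, cu_seqlens

# usage: 3 sequences of lengths 5, 2, 4 padded to max_seq_len=5
hidden = torch.randn(3, 5, 8)
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 0]])
packed, cu_seqlens = pack_padded_batch(hidden, mask)
print(packed.shape)       # torch.Size([11, 8])
print(cu_seqlens)         # tensor([ 0,  5,  7, 11], dtype=torch.int32)
```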

Unfortunately, I don't believe we have one-GPU training times, and it's difficult to extrapolate, so no good solution I'm afraid. @johntran-nv thoughts?

cc @janekl @samiwilf any thoughts here on the checkpointing issue? I can tag torchrec folks if this continues to be a blocker.