DeepSpeed
only override forward if using cuda-graph
https://github.com/huggingface/transformers/pull/18261 introduces model argument validation, which is not compatible with how ds-inference was originally set up. We no longer need to do everything we previously did in the engine's forward. This PR simplifies the inference engine a bit and also guards against the case where mp>1 and CUDA graph are both enabled, which is currently not supported.
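A minimal sketch of the two ideas above, assuming a simplified engine that wraps any callable model. The class `InferenceEngineSketch`, its parameters `mp_size` and `enable_cuda_graph`, and the helper `_graph_forward` are hypothetical and only illustrate the pattern; this is not DeepSpeed's actual implementation. When CUDA graphs are off, calls go straight to the wrapped model, so the model's own argument validation sees the caller's real arguments. When model parallelism and CUDA graphs are requested together, construction fails early.

```python
class InferenceEngineSketch:
    """Hypothetical wrapper illustrating this PR's pattern (not DeepSpeed code)."""

    def __init__(self, module, mp_size=1, enable_cuda_graph=False):
        # Guard: model parallelism (mp_size > 1) combined with CUDA graph
        # capture is described as currently unsupported, so reject it up front.
        if mp_size > 1 and enable_cuda_graph:
            raise ValueError("CUDA graph is not supported with mp_size > 1")
        self.module = module
        self.mp_size = mp_size
        self.enable_cuda_graph = enable_cuda_graph
        self._graph_replays = 0
        # Only install a custom forward when CUDA graphs are enabled.
        # Otherwise forward is the wrapped model itself, so its argument
        # validation (e.g. from transformers) is not bypassed or altered.
        if enable_cuda_graph:
            self.forward = self._graph_forward
        else:
            self.forward = module

    def _graph_forward(self, *args, **kwargs):
        # Placeholder for graph capture/replay; here it only counts calls
        # and delegates, so the sketch runs without a GPU.
        self._graph_replays += 1
        return self.module(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


if __name__ == "__main__":
    # Hypothetical toy "model" standing in for a transformers model.
    def toy_model(input_ids, attention_mask=None):
        return sum(input_ids)

    plain = InferenceEngineSketch(toy_model)
    print(plain([1, 2, 3]))           # forward is toy_model itself
    graphed = InferenceEngineSketch(toy_model, enable_cuda_graph=True)
    print(graphed([1, 2], attention_mask=None))
    try:
        InferenceEngineSketch(toy_model, mp_size=2, enable_cuda_graph=True)
    except ValueError as err:
        print(err)
```

Assigning `forward` per instance keeps the non-graph path free of any wrapper logic, which is the point of only overriding forward when a CUDA graph is in use.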