Checks for user injection policy
This PR introduces checks for the user injection policy arguments. Right now, the code below executes without any error or warning, even though the user-supplied injection policy is never used.
import torch
import deepspeed
from transformers import pipeline
pythia_pipe = pipeline('text-generation', 'EleutherAI/pythia-160m-deduped', device=0)
# Malformed policy: neither the key nor the value is meaningful, yet no error is raised
injection_policy = {"blah blah": "some more blah blah"}
pythia_pipe.model = deepspeed.init_inference(
    pythia_pipe.model, replace_with_kernel_inject=False, injection_policy=injection_policy,
    dtype=torch.half, enable_cuda_graph=False,
)
prompt = "Pythia combines interpretability analysis and"
response = pythia_pipe(prompt, max_new_tokens=50, return_full_text=False, do_sample=False)
This PR enforces correct usage of injection_policy by adding the necessary checks.
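To illustrate the kind of check described above, here is a minimal sketch. It is not the code in this PR: validate_injection_policy is a hypothetical helper, and the accepted value shapes (a tuple of parameter-name strings, or a policy class) are an assumption based on how injection_policy is typically passed in the DeepSpeed inference examples.

```python
def validate_injection_policy(injection_policy):
    """Hypothetical helper: fail fast on a malformed injection policy.

    Assumes each key is a layer class and each value is either a policy
    class or a string / tuple of strings naming output parameters.
    """
    if not isinstance(injection_policy, dict) or not injection_policy:
        raise ValueError("injection_policy must be a non-empty dict")
    for layer_cls, policy in injection_policy.items():
        # Keys are expected to be classes (e.g. a transformer block class).
        if not isinstance(layer_cls, type):
            raise ValueError(
                f"injection_policy key {layer_cls!r} is not a class")
        # Values: a policy class, one name, or a sequence of names.
        if isinstance(policy, type):
            continue
        names = (policy,) if isinstance(policy, str) else policy
        if not isinstance(names, (tuple, list)) or not all(
                isinstance(name, str) for name in names):
            raise ValueError(
                f"injection_policy value for {layer_cls.__name__} must be "
                "a policy class or a tuple of parameter names")


class DummyBlock:
    """Stand-in for a real model layer class (illustration only)."""


# A well-formed policy passes silently.
validate_injection_policy({DummyBlock: ("attention.out_proj", "mlp.fc_out")})

# The policy from the description above is rejected.
try:
    validate_injection_policy({"blah blah": "some more blah blah"})
except ValueError as err:
    print(f"rejected: {err}")
```

With a check along these lines running before any module replacement, the snippet in the description would raise instead of silently ignoring the policy.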
Hey @lekurile, if you have some time, could you please review this?
Hi @satpalsr,
Thank you for the contribution! I left a review comment about moving the check out of replace_transformer_layer(). Could you please take a look?
Thanks, Lev
Hey @lekurile,
Thanks for checking, but I somehow can't find your review. As I understand it, you are suggesting moving the check out of the function. I will make the changes. Thanks.
Hi @satpalsr,
Sorry, I didn't submit the review earlier. It should be visible now.
Thanks, Lev
I see the formatting check failing, though there are no formatting issues in my changes. The issue I see is in general_kernals.cu. Should I still change that here?
Hi @satpalsr,
Apologies for the delayed response. I think we'd still like to merge this change, but it would be nice to merge the latest master and rerun tests again. I'll take another look at the formatting issues as well.
Thanks, Lev