
get_fp32_state_dict_from_zero_checkpoint with fixed weights in pipeline

Open DA-L3 opened this issue 3 years ago • 2 comments

Hello,

I have a pipeline that looks as follows: Input --> Part A --> Part B --> output

What I want to achieve is the following: first I freeze A (requires_grad=False) and train only B. After that phase of training, I would like to resume from the resulting state (with the trained B) and now also make A trainable. For that I would like to use the get_fp32_state_dict_from_zero_checkpoint method, which returns a state dict. Unfortunately, it seems that only the weights for B are returned, even though the model checkpoint (not the optimizer state) stores the weights for both A and B.

Is there a way to load from that model checkpoint (where B was trained and A was fixed) and then, for the second part of the training, make A learnable as well?

Basically my question is: how can I resume from a (stage 2) checkpoint without resuming from the last optimizer state, since I want to "add" new weights to the optimizer that were not included in it before?
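To make the setup concrete, here is a minimal toy sketch of the two phases (plain dicts with strings standing in for tensors; all key names are invented for illustration, and the filtering behaviour is my reading of what the ZeRO checkpoint does, not the documented API):

```python
# Which parameters were trainable in phase 1 (A was frozen).
requires_grad = {"A.weight": False, "B.weight": True}

# Model state after phase 1: only B was updated.
after_phase1 = {"A.weight": "A_0", "B.weight": "B_1"}

# What get_fp32_state_dict_from_zero_checkpoint appears to return:
# only the optimizer-tracked (trainable) parameters.
fp32_state_dict = {k: v for k, v in after_phase1.items() if requires_grad[k]}

# "A.weight" is missing, so the fp32 dict alone cannot fully
# initialize phase 2, where A should become trainable again.
```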

Thanks in advance and best!

DA-L3 avatar Sep 02 '22 09:09 DA-L3

Let me restate my understanding of your question. Please correct me as needed.

You have two phases of training. Before phase 1, your model state is A_0 and B_0. Your phase 1 is as follows:

Phase 1:
  trainable = B_0
  fp16 checkpoint state = A_0 + B_1
  fp32 checkpoint state = B_1

And for phase 2, you want to achieve the following:

Phase 2:
  trainable = A_0 + B_1
  fp16 checkpoint state = A_1 + B_2
  fp32 checkpoint state = A_1 + B_2

Is this correct?

tjruwase avatar Sep 02 '22 16:09 tjruwase

Sorry for the late response and thank you for your quick reply!

Yes, this is essentially correct. I am not so sure about the fp16 checkpoint state, but the fp32 part is right. My main issue is that for phase 1, the fp32 checkpoint state seems to contain only B_1 and not A_0 (at least in the optimizer state). The model checkpoint (called mp_[...]_model_states.pt) does contain the A_0 weights, and I would like to initialize phase 2 with exactly A_0 + B_1. However, get_fp32_state_dict_from_zero_checkpoint appears to look only at the weights contained in the zero_[...]_optim_states.pt files, which hold only B_1 and not A_0.
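For what it's worth, a possible workaround sketch: merge the full model-states dict with the recovered fp32 dict, letting the fp32 (trained) weights win. The two dicts below stand in for what loading mp_[...]_model_states.pt and calling get_fp32_state_dict_from_zero_checkpoint would return; their exact contents and key names are assumptions on my part, not verified DeepSpeed behaviour:

```python
# Assumed stand-in for the full module state dict stored in
# mp_[...]_model_states.pt (frozen A_0 plus a possibly-fp16 copy of B).
model_states = {"A.weight": "A_0", "B.weight": "B_1_fp16"}

# Assumed stand-in for the trainable-only fp32 dict returned by
# get_fp32_state_dict_from_zero_checkpoint on the phase-1 checkpoint.
fp32_state_dict = {"B.weight": "B_1"}

# Start from the full model checkpoint, then overwrite with the fp32
# weights: the result has A_0 for the frozen part and B_1 for the
# trained part, which is exactly the desired phase-2 initialization.
phase2_init = {**model_states, **fp32_state_dict}
```

The merged dict could then be passed to the model (e.g. via load_state_dict) before re-enabling requires_grad on A and starting phase 2.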

DA-L3 avatar Sep 06 '22 06:09 DA-L3

Hi @CodyLDA, thanks for letting us know about the needed feature.

Could you also provide us with your training script so that we can reproduce your results and work on this new feature correctly?

GuanhuaWang avatar Nov 16 '22 18:11 GuanhuaWang

Closing this for now. Feel free to re-open it if needed.

GuanhuaWang avatar Dec 05 '22 18:12 GuanhuaWang