Why are prompt parameters transferred to new indices during training?
Hi authors, thanks for the great work. I have a question about the training process. As I understand it, prompts are selected from the pool exclusively based on the task id, and the prompt pool is shared across tasks. Why do we need to transfer the previous task's prompts to the new indices? I would expect the prompts at the new indices to be trained on the next task automatically.
The relevant code:
    # Transfer previous learned prompt params to the new prompt
    if config.prompt_pool and config.prompt_pool_param.shared_prompt_pool:
      if task_id > 0:
        prev_start = (task_id - 1) * config.prompt_pool_param.top_k
        prev_end = task_id * config.prompt_pool_param.top_k
        cur_start = prev_end
        cur_end = (task_id + 1) * config.prompt_pool_param.top_k
        if (prev_end > config.prompt_pool_param.pool_size) or (
            cur_end > config.prompt_pool_param.pool_size):
          pass
        else:
          param_dict = state.optimizer.target
          prompt_pool_para = param_dict["prompt_pool"]["prompt"]
          if config.use_prefix_tune_for_e_prompt:
            prompt_pool_para = prompt_pool_para.at[:, :, cur_start:cur_end].set(
                prompt_pool_para[:, :, prev_start:prev_end])
          else:
            prompt_pool_para = prompt_pool_para.at[:, cur_start:cur_end].set(
                prompt_pool_para[:, prev_start:prev_end])
          param_dict, _ = utils.replace_prompt_pool(param_dict, prompt_pool_para)
          state = utils.state_with_new_param(state, param_dict)
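To make sure I am reading the index arithmetic correctly, here is a minimal sketch of what I understand the snippet to do. It is not the repository's code: `transfer_slices` and `transfer` are hypothetical helpers I wrote for illustration, and a plain Python list stands in for the pool axis of the JAX array. If this reading is right, each task owns a block of `top_k` consecutive pool entries, and before task `t` starts, its block is overwritten with a copy of task `t - 1`'s block.

```python
from typing import Optional, Tuple


def transfer_slices(task_id: int, top_k: int,
                    pool_size: int) -> Optional[Tuple[slice, slice]]:
    """Return (source, destination) slices along the pool axis.

    Returns None when no copy would happen (hypothetical helper).
    """
    if task_id <= 0:
        return None  # first task: there is no previous block to copy from
    prev_start = (task_id - 1) * top_k
    prev_end = task_id * top_k
    cur_start = prev_end
    cur_end = (task_id + 1) * top_k
    if prev_end > pool_size or cur_end > pool_size:
        return None  # mirrors the `pass` branch in the snippet
    return slice(prev_start, prev_end), slice(cur_start, cur_end)


def transfer(pool: list, task_id: int, top_k: int) -> list:
    """Copy the previous task's block over the current task's block.

    Returns a new list, loosely mimicking the functional `.at[...].set(...)`
    update; the input list is left unchanged.
    """
    slices = transfer_slices(task_id, top_k, len(pool))
    new_pool = list(pool)
    if slices is not None:
        src, dst = slices
        new_pool[dst] = pool[src]
    return new_pool


# Pool of 6 entries, 2 prompts per task: task 1 starts from a copy of task 0's block.
print(transfer(["a", "b", "c", "d", "e", "f"], task_id=1, top_k=2))
```

So the new block does not start from its original initialization but from the previous task's learned values, and it is then trained on the new task. My question is why this copy is needed rather than letting the new block train from its own initialization.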