libai icon indicating copy to clipboard operation
libai copied to clipboard

libai trainer设计文档

Open CPFLAME opened this issue 4 years ago • 5 comments

调研了一下detectron2, mmdet, ColossalAI, paddledetection

paddledetection

paddledetection则是直接定义了一个object, 比较冗长, 主要有trainer.train()函数去协调各个模块, 整体下来感官和ColoraIAI的并没有本质区别 如果要加新功能, 则在这个object里面进行改动, 改动会对旧版本造成较大的影响 trainer.py

detectron2 && mmdet

这两者的设计思路其实差不多, 先定义一个HookBase

class HookBase:

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

    def state_dict(self):
        """
        Hooks are stateless by default, but can be made checkpointable by
        implementing `state_dict` and `load_state_dict`.
        """
        return {}

然后把和训练相关的一些步骤, 全部继承HookBase, 打包成一个list送到trainer里面去就可以了. 在train()函数里面进行统一的调用 比如LR_scheduler, optimizer, write_metrics, save_model, eval_metric 都可以继承HookBase, 各自分开写成一个 HookBase的子类, 这样可以一目了然的查看这个模块在训练的哪个阶段进行了什么操作, 不容易出错

class TrainerBase:

   def __init__(self) -> None:
        self._hooks: List[HookBase] = []
        self.iter: int = 0
        self.start_iter: int = 0
        self.max_iter: int
        self.storage: EventStorage
        _log_api_usage("trainer." + self.__class__.__name__)

    def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.
        Args:
            hooks (list[Optional[HookBase]]): list of hooks
        """
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)
        
   def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # self.iter == max_iter can be used by `after_train` to
                # tell whether the training successfully finished or failed
                # due to exceptions.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        self.storage.iter = self.iter
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        # Maintain the invariant that storage.iter == trainer.iter
        # for the entire execution of each step
        self.storage.iter = self.iter

        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()

    def run_step(self):
        raise NotImplementedError

ColossalAI

其中ColossaIAI的trainer和detection2以及mmdet 有一定的共同之处, 但是模块划分没有那么鲜明, 有点像介于paddledetection 和 detection2&&mmdet之间的结合体, 在train的时候仍然需要在函数中写optimizer.zero_grad()等 trainer

    def _train_epoch(self,
                     train_dataloader: DataLoader,
                     epoch: int = None,
                     display_progress: bool = False):
        # set training state
        self._engine.train()
        data_iter = iter(train_dataloader)
        progress = range(self._steps_per_epoch)
        if display_progress:
            if epoch is None:
                progress = tqdm(progress, desc='[Train]')
            else:
                progress = tqdm(progress, desc=f'[Epoch {epoch} train]')

        self._call_hooks('before_train_epoch')
        self._call_timer(action='start', item='train-epoch')
        for i in progress:
            self._call_hooks('before_train_iter')
            self._call_timer(action='start', item='train-step')

            # run 1 training step
            self.engine.zero_grad()
            logits, label, loss = self.schedule.forward_backward_step(
                self.engine, data_iter, forward_only=False, return_loss=True)
            self.engine.step()
            self._call_timer(action='stop', item='train-step', keep_in_history=True)
            self._call_hooks('after_train_iter', output=(logits, label, loss))

            self._cur_step += 1

            # stop when max iter is reached
            if self._exceed_max_step():
                break

        self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
        self._call_hooks('after_train_epoch')
        self._call_timer(action='reset', item='train-step')

个人倾向于detectron2 && mmdet的设计思路, 欢迎各位补充

CPFLAME avatar Dec 14 '21 07:12 CPFLAME

先明确trainer的需求,它只负责训练相关的,还是将创建data_loader、model、optimizer等也纳入trainer。如果只是训练相关的,那么它的功能应当包括:

  1. 训练并更新参数
  2. 验证模型,计算valid_loss
  3. 加载、保存模型
  4. 记录日志
  5. 评价指标metric(可能通过评价指标判断早停)
  6. 分布式配置,cpu、gpu切换,amp设置
  7. 恢复训练
  8. 使用模型预测(待讨论)

另外,这里的模型是eager还是graph,如果是graph,那么eager to graph的转换写在哪里,是用trainer封装,还是在models里由用户写,这个需要商量一下。

其他可参考的模型库有:pytorch-lightning,huggingface transformers。pytorch-lightning推荐看一下,很符合研究者的习惯。不过这些库都封装得太好了,外部看起来不知道从哪里下手,建议列完需求,把需求先实现即可,不用一次性做到完善。

dangkai4u avatar Dec 14 '21 08:12 dangkai4u

trainer里面的功能我在上面简单罗列了一下, 我理解trainer应该只用关心训练步骤, 我们可以先定义一个base_trainer 其实在detectron2里面都有写的都比较全了, defeult_trainer.py. 在这个base_trainer里面以下几个功能我们是可以确定的:

  1. optimizer
  2. lr_scheduler
  3. 模型验证
  4. 日志记录
  5. 打印训练信息metric
  6. 评价指标metric
  7. resume训练(这个功能不知道目前oneflow是否完善, 可以先空着)
  8. 模型加载和保存
  9. dataloader的创建在trainer里面可以写一个build_train_loader(), 然后调用写好的dataset创建接口. 如build_train_loader

eager && graph

由于oneflow的特殊性, 分为eager和graph, 所以我们可以有一个eager_trainer和graph_trainer, 来继承base_trainer. 其中eager_trainer, 基本功能几乎和base_trainer一样, 且不支持Fp16 graph_trainer, 则是要把optimizer这个功能去掉, 直接写到model里面去,

分布式和AMP

至于分布式配置在oneflow里面应该直接用launcher启动就不用管了, eager模式下用ddp包一层.可以写在trainer里面,

至于eager to graph的转换应该是由用户在models里面写好. amp设置直接在graph的model里面配置就好了.

其他参考的模型库

pytorch-lightning的trainer写的很烂, 让人点开了就想关掉, 不建议参考这个. trainer.py huggingface transformers 的trainer大体方向上和detectron2的构造比较像, 但是远不如其轻便, trainer.train()有将近500行代码, 感觉不太适合参考 trainer.train()

CPFLAME avatar Dec 14 '21 08:12 CPFLAME

eager to graph的转换不能做出抽象吗,如果能做出来,那用户就几乎感知不到eager和graph的区别了,也可以避免每写一个模型,就定义两个graph。建议问问框架组的同事。很多步骤这些都是固定的,只是无法使用**kwargs传参,想办法绕过,或者写个基类,模型传参部分由用户实现,其余部分由框架规定。

dangkai4u avatar Dec 14 '21 09:12 dangkai4u

这个功能可以先放着吧, eager to graph的转换, 我们可以给一个参考样例, 用户一看应该就是就明白了. 后续有比较好的想法我们可以做抽象

CPFLAME avatar Dec 14 '21 09:12 CPFLAME

eager to graph的转换不能做出抽象吗,如果能做出来,那用户就几乎感知不到eager和graph的区别了,也可以避免每写一个模型,就定义两个graph。建议问问框架组的同事。很多步骤这些都是固定的,只是无法使用**kwargs传参,想办法绕过,或者写个基类,模型传参部分由用户实现,其余部分由框架规定。

之前基于老的函数接口,参考lightning设计,做过一个解决该问题的设计。当时想的是一套代码,支持动静切换 + train和eval,可以参考。

例子: https://github.com/Oneflow-Inc/oneflow/blob/master/python/oneflow/compatible/single_client/test/models/test_alexnet_model.py#L68

实现: https://github.com/Oneflow-Inc/oneflow/blob/master/python/oneflow/framework/model.py#L206

虽然比keras和mindspore要灵活,但是做完后感觉还是比较笨重和受限。需要按约定的接口写代码。另外动静的优化方式也不同,还是存在很多细节问题的。所以我的观点和 @CPFLAME 类似,可以先不做太高层次的API封装,而是做可拼装的组件。

这是当时的一个调研和对比:https://github.com/Oneflow-Inc/OneTeam/issues/193#issuecomment-768853923

strint avatar Dec 14 '21 09:12 strint