mmengine icon indicating copy to clipboard operation
mmengine copied to clipboard

[Feature] Nested initialization implementation of pure Python style configuration files

Open yinaoxiong opened this issue 2 years ago • 3 comments

What is the feature?

我非常喜欢mmengine的纯 Python 风格的配置的这个功能,使用这种配置文件可以自由的跳转并且可以任意定义和使用而不用再注册了。但是目前的配置文件在初始化只能进行单层初始化,但是对于模型的定义来说我们经常会进行嵌套的定义,下面是一个例子。

model = dict(
    type=xxModel,
    ...
    encoder=dict(
        type=Encoder,
        feature_extractor=dict(
            ....
        ),
        position_encoder=dict(
            type=SinusoidalPositionalEncoding,
            d_model=d_model,
            dropout=0.1,
            max_len=max_len,
        ),
        .....
    ),
    decoder=dict(
        type=Decoder,
        position_encoder=dict(
            type=SinusoidalPositionalEncoding,
            d_model=d_model,
            dropout=0.1,
            max_len=max_len,
        ),
        .....
    ),
)

如上所示,一个模型可能由encoder,和decoder组成,encoder再有其他部件组合而成,通常我们会希望可以任意替换组件的类型,只要外部forward表现一致就可以。对于这种嵌套初始化的需求可以尝试用下面这个装饰器实现。感觉这个需求还是挺常见的,要是感觉有用我可以发起一个pr,以及看看目前这个实现有没有什么问题。

import functools

def build_from_cfg(cfg: Union[dict, ConfigDict, Config]) -> Any:
    """
    Builds an object from a configuration dictionary.

    Args:
        cfg (Union[dict, ConfigDict, Config]): The configuration dictionary, which must contain the "type" key.

    Returns:
        Any: The built object.

    Raises:
        TypeError: If cfg is not a dict, ConfigDict, or Config type.
        KeyError: If "type" key is not in cfg.
        TypeError: If type is not a class or function type.
    """
    if not isinstance(cfg, (dict, ConfigDict, Config)):
        raise TypeError(
            f"cfg should be a dict, ConfigDict or Config, but got {type(cfg)}"
        )
    if "type" not in cfg:
        raise KeyError('`cfg`  must contain the key "type"')
    args = cfg.copy()
    obj_type = args.pop("type")
    if (
        inspect.isclass(obj_type)
        or inspect.isfunction(obj_type)
        or inspect.ismethod(obj_type)
    ):
        obj_cls = obj_type
    else:
        raise TypeError(
            f"type must be a class, function or method, but got {type(obj_type)}"
        )
    return obj_cls(**args)


def warp_cfg_args(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator that wraps a function to convert its arguments from ConfigDict or Config objects to their corresponding
    Python objects.

    Args:
        func: The function to be wrapped.

    Returns:
        The wrapped function.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        new_args = []

        def build_arg(arg: Any) -> Any:
            """
            Recursively converts ConfigDict or Config objects to their corresponding Python objects.

            Args:
                arg: The argument to be converted.

            Returns:
                The converted argument.
            """
            if isinstance(arg, (ConfigDict, Config)):
                arg = build_from_cfg(arg)
                return arg
            if (
                isinstance(arg, list)
                and len(arg) > 0
                and isinstance(arg[0], (ConfigDict, Config))
            ):
                arg = [build_from_cfg(a) for a in arg]
                return arg
            return arg

        for arg in args:
            arg = build_arg(arg=arg)
            new_args.append(arg)
        new_kwargs = {}
        for key, value in kwargs.items():
            value = build_arg(arg=value)
            new_kwargs[key] = value
        return func(*new_args, **new_kwargs)

    return wrapper

class xxModel(BaseModel):
    @warp_cfg_args
    def __init__(

Any other context?

No response

yinaoxiong avatar Jan 04 '24 12:01 yinaoxiong

非常感谢你的反馈,事实上我们也考虑过使用 dict(type=XXX) 的协议来自动 build 各个组件,这样就能免于在代码里通过 build_from_cfg 来显式地构建模块了。这样的想法确实很方便,但是在实际操作过程会遇到一些问题。以你提到的这种方式为例:

真的可以做到所有的模块,都遵循这种写法,通过在 __init__ 上加装饰器的方式来免于在 __init__ 内部调用 build function 么

这其实是很困难的,以大家都熟悉的 Dataloader 为例,Dataloader 需要接受 dataset 参数,sampler 也需要接受 dataset 参数,batch_sampler 需要 sampler 参数。对于这类构造参数互相耦合的情况,通过装饰器来自动化构建实例的方式就很难走通了,我们可能需要引入更多的概念,例如占位符。但是过于复杂语法涉及实际上是违反了 Pthon Style Config 的设计初衷。

但是如果我们选择绕开这个问题,让部分组件(参数互相独立)配置的写法用装饰器,而部分组件保持 build_function 的写法,我想这也是不合理的,因为你没法从配置文件里看出来哪些配置被 “区别对待” 了。

基于以上原因,我们目前仍然采取了在组件内部调用 build function 的做法,希望由用户来把握模块构建的过程。如果你有什么更好地想法,也非常欢迎在这个 issue 里继续讨论!

HAOCHENYE avatar Jan 07 '24 18:01 HAOCHENYE

build_function 确实可以方便用户使用非常灵活的方式 实现对于某一类组件的构建。 但是在实际使用中用户很少会去创建新的组件类别,大多都是直接定义或者注册一个mmengine中已有组件。 感觉这类装饰器可以看作是对build function 做法的一个补充的 辅助函数 而非替代, 这样用户在处理某些需要解耦的配置,例如 model 组件时 可以自由的选择是否使用(对于model 组件来说,实际上都使用也不会对原有的行为造成影响

yinaoxiong avatar Jan 18 '24 11:01 yinaoxiong

detectron2的lazyconfig比较容易实现嵌套实例化,对于共享模块,可以定义一个SharedCall,保证每次实例化返回的都是同一个对象即可

Asthestarsfalll avatar Feb 20 '24 08:02 Asthestarsfalll