[BUG] Recommended way to implement EMA

Open taoisu opened this issue 3 years ago • 27 comments

Describe the bug

Hi DeepSpeed team, I have some code that uses an exponential moving average (EMA) when training a UNet model. The code relies on the model's named_parameters() and parameters() to store and update the params. A simplified implementation is below:


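# Part 0 ====================================================================

from contextlib import contextmanager
from typing import Iterable

import torch
from torch import nn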

class LitEma(nn.Module):

    def __init__(
        self,
        model:nn.Module,
        decay:float=0.9999,
        use_num_updates:bool=True,
    ):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            'num_updates',
            torch.tensor(0, dtype=torch.int) if use_num_updates else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                #remove as '.'-character is not allowed in buffers
                s_name = name.replace('.','')
                self.m_name2s_name.update({ name: s_name })
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model:nn.Module):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1+self.num_updates)/(10+self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay*(shadow_params[sname]-m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model:nn.Module):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters:Iterable[nn.Parameter]):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters:Iterable[nn.Parameter]):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

# Part 1 ====================================================================

    if self.use_ema:
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

# Part 2 ====================================================================

    @contextmanager
    def ema_scope(self, context:str=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

# Part 3 ====================================================================

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

When using DeepSpeed, .parameters() and .named_parameters() both return empty results. What is the recommended way of implementing the above LitEma class with DeepSpeed? Sorry if this seems to be a dumb question, but I'm new here, and with offload and sharding it is unclear to me how to implement it correctly.
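
For reference, the warm-up in LitEma.forward above caps the decay at (1 + num_updates) / (10 + num_updates) until that ratio exceeds the configured decay. Below is a minimal stdlib-only sketch of that schedule; effective_decay is a hypothetical helper name, not part of the original code.

```python
def effective_decay(num_updates: int, decay: float = 0.9999) -> float:
    """Decay applied on the given update count (counted from 1, as in
    LitEma.forward, which increments num_updates before computing decay)."""
    return min(decay, (1 + num_updates) / (10 + num_updates))


if __name__ == "__main__":
    # Early updates use a small decay, so the shadow weights follow the
    # model closely; later updates use the configured decay.
    for step in (1, 10, 100, 1000, 100000):
        print(step, effective_decay(step))
```

With the default decay of 0.9999, the cap only stops mattering after roughly 90,000 updates, so short debugging runs exercise the warm-up branch rather than the configured decay.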

taoisu avatar Jun 26 '22 22:06 taoisu

Hi taoisu. I have exactly the same issue. Have you solved it yet? Can you share your solutions please?

ssyang1999 avatar Mar 13 '23 11:03 ssyang1999

Hi, I tried to work out a usable EMA module with ZeRO Stage 3. See below:

import deepspeed
import torch
from torch import nn
from deepspeed.runtime.zero import GatheredParameters

class DSEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.decay = decay
        self.num_updates = 0 if use_num_updates else -1

        with GatheredParameters(model.parameters(), fwd_module=self):
            for name, p in model.named_parameters():
                if p.requires_grad:
                    # remove as '.'-character is not allowed in buffers
                    s_name = name.replace('.', '')
                    self.m_name2s_name.update({name: s_name})
                    self.register_buffer(s_name, p.clone().detach().data)
        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay
        shadow_params = dict(self.named_buffers())

        with torch.no_grad():
            with GatheredParameters(model.parameters()):
                if deepspeed.comm.get_rank() == 0:
                    m_param = dict(model.named_parameters())

                    for key in m_param:
                        if m_param[key].requires_grad:
                            sname = self.m_name2s_name[key]
                            shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                            shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                        else:
                            assert not key in self.m_name2s_name

    def copy_to(self, model):
        shadow_params = dict(self.named_buffers())
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                m_param = dict(model.named_parameters())
                for key in m_param:
                    if m_param[key].requires_grad:
                        m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
                    else:
                        assert not key in self.m_name2s_name

    def store(self, model):
        """
        Save the current parameters for restoring later.
        Args:
          model: The model whose parameters will be stored.
        """
        with GatheredParameters(model.parameters()):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                self.collected_params = [param.clone() for param in parameters]

    def restore(self, model):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          model: The model whose parameters will be restored.
        """
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                for c_param, param in zip(self.collected_params, parameters):
                    param.data.copy_(c_param.data)

ssyang1999 avatar Mar 13 '23 13:03 ssyang1999

Hi, have you solved the problem?

czczup avatar Apr 03 '23 08:04 czczup

@czczup Hi cz. It does run EMA for me in zero stage 3. But I didn't check whether it behaves exactly the same as LitEMA.

ssyang1999 avatar Apr 03 '23 08:04 ssyang1999

Hey, I'm trying to use this code myself. Could you please share how this is being implemented within your lightning module? I'm currently trying something like this and it appears not to be working:

In the init:

self.ema_model = DSEma(self.model)

Then:

def on_train_batch_end(self, *args, **kwargs):
    self.ema_model(self.model)

KirillShmilovich avatar Jun 02 '23 18:06 KirillShmilovich

@taoisu, @KirillShmilovich, @ssyang1999, @czczup,

Please note that EMA is used in the recently released DeepSpeed-Chat. This utility script contains examples of zero stage parameter usage including EMA.

Hope this helps.

tjruwase avatar Jun 02 '23 19:06 tjruwase

@tjruwase

I tried implementing the EMA used in the utility script linked above in the following manner:

in the init:

self.model = model
self.ema_model = copy.deepcopy(self.model)

Then:

def on_train_batch_end(self, *args, **kwargs):
    moving_average(self.model, self.ema_model, zero_stage=3)

I used the same moving_average utility function provided in the utility script. However, this does not appear to work: I confirmed that the model weights were indeed being updated within the moving_average function, but when performing inference the ema_model does not appear to work correctly. For example, specifying beta=0.0 corresponds to the case where the ema_model and model weights should be exactly the same (it should produce a carbon copy at each on_train_batch_end). Nevertheless, the EMA model produces entirely different results from the model, while in principle they should be the same in this setting.
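
The beta=0.0 sanity check above can be reasoned about with plain scalars. The sketch below assumes that torch.lerp(start, end, weight) computes start + weight * (end - start), and compares the form used by moving_average with the in-place update used by LitEma. The names lerp, ema_lerp and ema_sub are illustrative only.

```python
def lerp(start: float, end: float, weight: float) -> float:
    """Scalar stand-in for torch.lerp: start + weight * (end - start)."""
    return start + weight * (end - start)


def ema_lerp(param: float, ema: float, beta: float) -> float:
    # Form used by moving_average: lerp(param, ema, beta)
    # which equals (1 - beta) * param + beta * ema.
    return lerp(param, ema, beta)


def ema_sub(param: float, shadow: float, decay: float) -> float:
    # Form used by LitEma.forward: shadow -= (1 - decay) * (shadow - param)
    return shadow - (1.0 - decay) * (shadow - param)


if __name__ == "__main__":
    # With beta = 0 the EMA value is exactly the model value.
    print(ema_lerp(param=2.0, ema=5.0, beta=0.0))
    # Both forms agree for an ordinary decay value, up to rounding.
    print(ema_lerp(2.0, 5.0, 0.992), ema_sub(2.0, 5.0, 0.992))
```

Since the arithmetic gives an exact copy at beta=0.0, a mismatch in that setting points at how the values are written back to the parameters rather than at the EMA formula itself.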

Do you have any ideas if my implementation with lightning is correct or if there are other confounding factors that could be at play?

KirillShmilovich avatar Jun 07 '23 18:06 KirillShmilovich

@KirillShmilovich, can you confirm that the ema model produces identical outputs right after init? In other words, can you run the inference step right after init? The reason I ask is that I don't think deepcopy is the correct way to clone a zero stage 3 model because of the parameter partitioning.

tjruwase avatar Jun 08 '23 15:06 tjruwase

What would be the recommended strategy for creating a copy of the model for proper recognition of parameter partitioning?

KirillShmilovich avatar Jun 08 '23 15:06 KirillShmilovich

The following code should work (note: not tested). Please see here for a guide on manipulating z3 models.

def clone_zero_model(src_model, dst_model, zero_stage=0):
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
            # TODO: use prefiltering for efficiency
            params_to_fetch = _z3_params_to_fetch([src_param, dst_param
                                                   ]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            with deepspeed.zero.GatheredParameters(
                    params_to_fetch, enabled=should_gather_param):
                dst_param.data.copy_(src_param.data)

For a z3 model, it can be used as follows:

main_model = ...
ema_model = ... # constructed similarly to main_model

def on_train_batch_end(self, *args, **kwargs):
    clone_zero_model(src_model=main_model, dst_model=ema_model, zero_stage=3)

tjruwase avatar Jun 08 '23 20:06 tjruwase

Great, thank you, I will test this. One question: is it really necessary to call clone_zero_model on each on_train_batch_end? If I understand correctly, clone_zero_model should only need to be called once during the init, and then moving_average (from above) should be called in on_train_batch_end to update the ema_model?

KirillShmilovich avatar Jun 08 '23 20:06 KirillShmilovich

For reference, I tested the above code and it appears to yield the same problems.

Specifically:

in the init:

model = compile_model(args)
ema_model = compile_model(args)

self.model = model
self.ema_model = ema_model

clone_zero_model(self.model, self.ema_model, zero_stage=3)

then:

def on_train_batch_end(self, *args, **kwargs):
    clone_zero_model(self.model, self.ema_model, zero_stage=3)

In this case, we should expect the ema_model to be copied exactly from model after each training batch; however, the same error occurs here, where the ema_model and model outputs are wildly different. Visual inspection of the parameters appears to confirm the weights are being copied correctly. Do you have any ideas for debugging whether there is some hidden DeepSpeed behavior at play?
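
As an alternative to visual inspection, the comparison can be made numeric. This is a minimal stdlib-only sketch that works on parameters already flattened to lists of floats; max_abs_diff is a hypothetical helper, and how the values are extracted (and, under ZeRO stage 3, gathered) is left to the caller.

```python
from typing import Iterable, Sequence, Tuple


def max_abs_diff(
    params_a: Iterable[Sequence[float]],
    params_b: Iterable[Sequence[float]],
) -> Tuple[float, int]:
    """Return (largest absolute difference, number of elements compared)."""
    worst = 0.0
    compared = 0
    for a, b in zip(params_a, params_b):
        if len(a) != len(b):
            raise ValueError("parameter sizes differ")
        for x, y in zip(a, b):
            worst = max(worst, abs(x - y))
            compared += 1
    return worst, compared


if __name__ == "__main__":
    print(max_abs_diff([[1.0, 2.0]], [[1.0, 2.5]]))
    # Empty inputs compare as "equal" but with zero elements checked.
    print(max_abs_diff([[]], [[]]))
```

The element count is returned on purpose: if the parameters are partitioned and read without gathering, they may be empty placeholders, and a difference of 0.0 over zero elements says nothing about whether the copy worked.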

KirillShmilovich avatar Jun 08 '23 21:06 KirillShmilovich

At this point, it would be helpful to see how the outputs are generated. Is it possible to share an e2e code so I can repro locally?

tjruwase avatar Jun 08 '23 23:06 tjruwase

I think I put together a minimal working example here.

This uses clone_zero_model and runs on 2 GPUs. At the end I call trainer.validate(...), where we'd expect the output to be exactly the same for both the EMA and non-EMA model, but they print different results.

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from deepspeed.ops.adam import DeepSpeedCPUAdam


def _z3_params_to_fetch(param_list):
    return [
        p for p in param_list
        if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]


def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for param, param_ema in zip(model.parameters(),
                                    model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            params_to_fetch = _z3_params_to_fetch([param, param_ema
                                                   ]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            with deepspeed.zero.GatheredParameters(
                    params_to_fetch, enabled=should_gather_param):
                data = param.data
                if device is not None:
                    data = data.to(device)
                param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))

def clone_zero_model(src_model, dst_model, zero_stage=0):
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
            # TODO: use prefiltering for efficiency
            params_to_fetch = _z3_params_to_fetch([src_param, dst_param
                                                   ]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            with deepspeed.zero.GatheredParameters(
                    params_to_fetch, enabled=should_gather_param):
                dst_param.data.copy_(src_param.data)


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
    
    
num_samples = 10000


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.ema_layer = torch.nn.Linear(32, 2)

        clone_zero_model(self.layer, self.ema_layer, zero_stage=3)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        ema_loss = self.ema_layer(batch).sum()
        self.log("ema_valid_loss", ema_loss)
    
    def on_train_batch_end(self, *args, **kwargs):
        clone_zero_model(src_model=self.layer, dst_model=self.ema_layer, zero_stage=3)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.layer.parameters())

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        devices=2,
        strategy="deepspeed_stage_3_offload",
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.validate(model, dataloaders=test_data)
    trainer.test(model, dataloaders=test_data)

run()

KirillShmilovich avatar Jun 09 '23 00:06 KirillShmilovich

As another point of reference, setting devices=1 fixes the issue and results in ema_valid_loss and valid_loss being exactly equal.

KirillShmilovich avatar Jun 09 '23 15:06 KirillShmilovich

@tjruwase just want to follow up on using the moving_average to create an EMA model

  1. does the EMA model also have to have the same zero config as the original model, or can it be offload to CPU only?
  2. instead of using clone_zero_model, can we deepcopy the model before preparing with deepspeed? that would alleviate the zero3 param gather issue

for example, i'm thinking

# no deepspeed init, deepcopy
model = SomeModel()
ema_model = copy.deepcopy(model)
model_engine, *_ = deepspeed.initialize(model, ...)

# not sure if required to initialize ema model, or keeping to cpu is sufficient
# ema_model, *_ = deepspeed.initialize(ema_model, ...)

def train_step():
    ...
    moving_average(model_engine, ema_model)

Would this work? or are both conditions required? so we'd have

# deepspeed init and clone zero model
model, ema_model = SomeModel(), SomeModel()
model_engine, *_ = deepspeed.initialize(model, ...)
ema_model, *_ = deepspeed.initialize(ema_model, ...)

clone_zero_model(model_engine, ema_model, zero_stage)


def train_step():
    ...
    moving_average(model_engine, ema_model)

maxmatical avatar Aug 28 '23 15:08 maxmatical

@maxmatical,

  1. The EMA model does not need the same zero config, since it is a simpler model, e.g., it has no optimizer.
  2. copy.deepcopy() cannot replicate the behavior of gathering parameter partitions from CPU, NVMe, or remote HBM. clone_zero_model() uses deepspeed.zero.GatheredParameters() to achieve this.
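
To illustrate the point about partitioning, here is a toy, stdlib-only model of a sharded parameter. It is not DeepSpeed code: ToyShardedParam, partition and gather are made-up names, and the real implementation differs in many details. It only shows why copying what one rank holds is not the same as copying the gathered parameter.

```python
import copy
from typing import List


class ToyShardedParam:
    """Toy stand-in for a partitioned parameter (not DeepSpeed's class)."""

    def __init__(self, shard: List[float]):
        self.shard = list(shard)  # the slice owned by this rank
        self.data: List[float] = []  # placeholder that ordinary code sees


def partition(full: List[float], world_size: int) -> List[ToyShardedParam]:
    per_rank = -(-len(full) // world_size)  # ceiling division
    return [
        ToyShardedParam(full[r * per_rank:(r + 1) * per_rank])
        for r in range(world_size)
    ]


def gather(params_by_rank: List[ToyShardedParam]) -> List[float]:
    # Reassemble the full parameter from every rank's shard.
    full: List[float] = []
    for p in params_by_rank:
        full.extend(p.shard)
    return full


ranks = partition([0.1, 0.2, 0.3, 0.4], world_size=2)
local_copy = copy.deepcopy(ranks[0])  # what a single rank can copy by itself

if __name__ == "__main__":
    print(local_copy.data, local_copy.shard)
    print(gather(ranks))
```

In this toy, the deep copy made on one rank carries an empty placeholder and only that rank's slice, while the full values only exist after a gather across ranks.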

tjruwase avatar Aug 31 '23 01:08 tjruwase

hi @tjruwase

i believe we may have some misunderstanding in my questions

  1. my question is more on whether it is required to call deepspeed.initialize on the EMA model (as seen in the deepspeed chat code), or can the EMA model remain on CPU? so the example would be
model, ema_model = SomeModel(), SomeModel()
model_engine, *_ = deepspeed.initialize(model, ...)

def train_step():
    ...
    moving_average(model_engine, ema_model)

Will this raise any issues? the reasoning for this would be if we don't have to call deepspeed.initialize(ema_model), then we should save some gpu memory by keeping the ema_model entirely on cpu

  2. My question on copy.deepcopy: calling ema_model = copy.deepcopy(model) would be done before deepspeed.initialize(model, ...). In this case, since the model is still just on CPU, calling deepcopy should be fine, right? Although I suppose in this case there shouldn't be much difference between that and just initializing the same model twice via
model, ema_model = SomeModel(), SomeModel()

maxmatical avatar Aug 31 '23 23:08 maxmatical

@maxmatical, thanks for the clarification. Apologies for misunderstanding your question.

  1. You are correct, ema_model does not need deepspeed.initialize(). It is better to keep ema_model on CPU.
  2. copy.deepcopy() will work as long as model was not created using zero.Init().

tjruwase avatar Sep 01 '23 00:09 tjruwase

thinking more about it, i can see maybe some concerns with initializing the ema_model with copy.deepcopy

the scenario i'm thinking of is if i initialize ema_model and keep it on cpu, after i call deepspeed.initialize() on model, there could be a dtype mismatch due to mixed precision with fp16/bf16 that i'm training the model in, and the ema_model will be still on fp32, so maybe the best practice here is

  1. initialize ema_model however
  2. after calling model = deepspeed.initialize(model), also call clone_zero_model() to copy the weights to ema_model so the parameter types match, but keep ema_model on cpu
  3. call moving_average() in the training loop

i assume this would be more memory efficient, as we don't need to keep another copy of ema_model on gpu, but another consideration would be the latency of calling moving_average when ema_model is on cpu: would there be a speed bottleneck here that makes training less efficient?
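
One thing that may be worth checking before matching the EMA dtype to a half-precision model: with a decay close to 1, each EMA step is tiny, and a low-precision accumulator can round it away. The sketch below is a scalar simulation only, using the stdlib struct module to round to IEEE half precision after every update; to_fp16 and run_ema are hypothetical helpers, and real tensor arithmetic would differ in detail.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def run_ema(steps: int, decay: float, param: float, shadow: float, round_fn) -> float:
    """Apply the LitEma-style update repeatedly, rounding the stored value."""
    for _ in range(steps):
        shadow = round_fn(shadow - (1.0 - decay) * (shadow - param))
    return shadow


# Same update rule, different storage precision for the shadow value.
fp16_shadow = run_ema(1000, 0.9999, param=1.5, shadow=1.0, round_fn=to_fp16)
full_shadow = run_ema(1000, 0.9999, param=1.5, shadow=1.0, round_fn=lambda x: x)

if __name__ == "__main__":
    print(fp16_shadow, full_shadow)
```

In this simulation the half-precision shadow never moves, because each step is smaller than the spacing between representable values near 1.0, while the higher-precision shadow drifts toward the parameter. If the same effect shows up in practice, keeping the EMA copy in fp32 and casting only when copying weights into the model may be a reasonable alternative to matching dtypes.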

@tjruwase would love to hear your thoughts on what the best practice may be

maxmatical avatar Sep 01 '23 03:09 maxmatical