PyTorch: Exponential Moving Average (EMA) Example


This example carefully replicates the behavior of TensorFlow’s tf.train.ExponentialMovingAverage.
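For reference, the TensorFlow update rule is `shadow_variable -= (1 - decay) * (shadow_variable - variable)`, which is algebraically the same as the more familiar `decay * shadow + (1 - decay) * variable`. Here is a minimal sketch on plain tensors showing that the two forms agree; the decay and tensor values are made up for illustration.

```python
import torch

decay = 0.9  # illustrative value, not taken from the post
shadow = torch.tensor([1.0, 2.0])
variable = torch.tensor([3.0, 6.0])

# conventional form: decay * shadow + (1 - decay) * variable
expected = decay * shadow + (1.0 - decay) * variable

# in-place form, as documented for tf.train.ExponentialMovingAverage
shadow.sub_((1.0 - decay) * (shadow - variable))

# both forms give approximately [1.2, 2.4]
print(torch.allclose(shadow, expected))
```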

Note that EMA should only be applied to the trainable parameters; in PyTorch, these are exposed by model.parameters() or model.named_parameters(), where model is a torch.nn.Module.
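To make the distinction concrete, here is a small sketch on a hypothetical toy model. One caveat: named_parameters() returns every registered parameter, including frozen ones, so if some layers are frozen you may want to filter on requires_grad as shown.

```python
import torch
from torch import nn

# hypothetical toy model: BatchNorm1d has both parameters and buffers
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
model[0].bias.requires_grad_(False)  # freeze one parameter for illustration

all_params = [name for name, _ in model.named_parameters()]
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
buffers = [name for name, _ in model.named_buffers()]

print(all_params)  # ['0.weight', '0.bias', '1.weight', '1.bias']
print(trainable)   # ['0.weight', '1.weight', '1.bias']
print(buffers)     # ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```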

Since my implementation creates a copy of the input model (i.e. shadow), the buffers need to be copied to shadow whenever update() is invoked.
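A typical example of a buffer is the running statistics of a batch normalization layer, which change during training forward passes rather than through gradient updates. The sketch below, using a lone BatchNorm1d layer as a stand-in model, shows the deep copy going stale and then being re-synchronized with an in-place copy.

```python
import torch
from copy import deepcopy
from torch import nn

torch.manual_seed(0)
model = nn.BatchNorm1d(2)  # stand-in for a real model
shadow = deepcopy(model)

model.train()
model(torch.randn(8, 2) + 5.0)  # a train-mode forward pass updates the running stats

# the copy still holds the initial statistics
stale = not torch.allclose(model.running_mean, shadow.running_mean)

# copy every buffer from the model into the shadow, in place
with torch.no_grad():
    shadow_buffers = dict(shadow.named_buffers())
    for name, buffer in model.named_buffers():
        shadow_buffers[name].copy_(buffer)

synced = torch.equal(model.running_mean, shadow.running_mean)
print(stale, synced)  # True True
```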

Alternative Implementation

You could implement shadow as a dict; for details of this version, see "[Model Training Tips] Exponential Moving Average (EMA): Principle and PyTorch Implementation" (【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现). One problem with that implementation is that shadow must be saved manually, since the shadow parameters are not stored in state_dict; a simple fix is to register each shadow parameter by calling register_parameter(<parameter name>, <parameter>).
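Here is a rough sketch of that fix, not the referenced article's code. The class name DictEMA and the "shadow_" prefix are hypothetical. It also swaps in register_buffer for register_parameter, on the assumption that the shadow values should be saved in state_dict but should never be returned by parameters() and handed to an optimizer.

```python
import torch
from torch import nn


class DictEMA(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, model: nn.Module, decay: float):
        super().__init__()
        self.model = model
        self.decay = decay
        self.names = {}  # parameter name -> registered buffer name
        for name, param in model.named_parameters():
            if param.requires_grad:
                # registered names may not contain ".", so replace it
                key = "shadow_" + name.replace(".", "_")
                self.register_buffer(key, param.detach().clone())
                self.names[name] = key

    @torch.no_grad()
    def update(self):
        for name, param in self.model.named_parameters():
            if name in self.names:
                shadow = getattr(self, self.names[name])
                shadow.sub_((1.0 - self.decay) * (shadow - param))


ema = DictEMA(nn.Linear(2, 1), decay=0.99)
ema.update()
print(sorted(ema.state_dict().keys()))
# ['model.bias', 'model.weight', 'shadow_bias', 'shadow_weight']
```

Replacing "." with "_" can in principle make two different parameter names collide, so a real implementation may need a more careful naming scheme.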


import torch
from torch import nn

from copy import deepcopy
from collections import OrderedDict
from sys import stderr

# for type hint
from torch import Tensor

class EMA(nn.Module):
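    def __init__(self, model: nn.Module, decay: float):
        # nn.Module must be initialized before submodules can be assigned
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        # the shadow copy is updated manually, never by backpropagation
        for param in self.shadow.parameters():
            param.detach_()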

    @torch.no_grad()
    def update(self):
        if not self.training:
            print("EMA update should only be called during training", file=stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check that both models contain the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # same update rule as tf.train.ExponentialMovingAverage:
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check that both models contain the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied
            shadow_buffers[name].copy_(buffer)

    def forward(self, inputs: Tensor, return_feature: bool = False) -> Tensor:
        # use the live model while training and the averaged copy for evaluation
        if self.training:
            return self.model(inputs, return_feature)
        else:
            return self.shadow(inputs, return_feature)


Zijian Hu
Research Staff

My research interests include computer vision, machine learning, and robotics.