PyTorch: Exponential Moving Average (EMA) Example
Introduction
This example carefully replicates the behavior of TensorFlow’s tf.train.ExponentialMovingAverage.
Notice that when applying EMA, only the trainable parameters should be updated; in PyTorch, the trainable parameters can be obtained via model.parameters() or model.named_parameters(), where model is a torch.nn.Module.
Since my implementation creates a copy of the input model (i.e. shadow), the buffers need to be copied from the model to shadow whenever update() is invoked.
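For reference, the update rule applied to each trainable parameter is shadow = decay * shadow + (1 - decay) * param, which TensorFlow (and the implementation below) writes as an in-place subtraction. A tiny stand-alone sketch of one update step; the tensors and the decay value here are made up for illustration:

import torch

decay = 0.999
param = torch.tensor([1.0, 2.0])    # current value of a model parameter
shadow = torch.tensor([0.5, 1.5])   # running EMA of that parameter

# shadow = decay * shadow + (1 - decay) * param, in subtract-in-place form
shadow.sub_((1.0 - decay) * (shadow - param))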
Alternative Implementation
You could also implement shadow as a dict; for details of that variant, see 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 (a Chinese post on the principle of EMA and its PyTorch implementation). One problem with that implementation is that the shadow weights have to be saved manually, since they are not stored in state_dict; a simple fix is to register each shadow weight on the module (e.g. with register_parameter(name, param)) so that it is included in state_dict.
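As a rough illustration of that variant, here is a minimal sketch. The class name DictEMA is hypothetical, and I use register_buffer instead of register_parameter since the shadow weights are never trained; either way they end up in state_dict:

import torch
from torch import nn

class DictEMA(nn.Module):
    # hypothetical sketch of the dict-based variant; names are illustrative
    def __init__(self, model: nn.Module, decay: float):
        super().__init__()
        self.decay = decay
        self.model = model
        self._shadow_keys = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                # '.' is not allowed in buffer names, so replace it
                key = "shadow_" + name.replace(".", "_")
                # register_buffer stores the tensor in state_dict without making it trainable
                self.register_buffer(key, param.detach().clone())
                self._shadow_keys[name] = key

    @torch.no_grad()
    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                shadow = getattr(self, self._shadow_keys[name])
                # same rule as above: shadow -= (1 - decay) * (shadow - param)
                shadow.sub_((1.0 - self.decay) * (shadow - param))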
Implementations
import torch
from torch import nn
from copy import deepcopy
from collections import OrderedDict
from sys import stderr

# for type hint
from torch import Tensor


class EMA(nn.Module):
    def __init__(self, model: nn.Module, decay: float):
        super().__init__()
        self.decay = decay
        self.model = model
        self.shadow = deepcopy(self.model)

        for param in self.shadow.parameters():
            # shadow weights are never trained directly, so they do not need gradients
            param.detach_()

    @torch.no_grad()
    def update(self):
        if not self.training:
            print("EMA update should only be called during training", file=stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check if both models contain the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check if both models contain the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied (not averaged), e.g. BatchNorm running statistics
            shadow_buffers[name].copy_(buffer)

    def forward(self, inputs: Tensor, return_feature: bool = False) -> Tensor:
        if self.training:
            return self.model(inputs, return_feature)
        else:
            return self.shadow(inputs, return_feature)
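Below is a minimal usage sketch; the TinyNet model, optimizer settings, and random data are made up for illustration (the wrapped model must accept the return_feature argument that EMA.forward passes through):

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x: Tensor, return_feature: bool = False) -> Tensor:
        # return_feature is accepted only so the signature matches EMA.forward
        return self.fc(x)


model = TinyNet()
ema = EMA(model, decay=0.999)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

ema.train()
for _ in range(100):
    x = torch.randn(32, 10)
    y = torch.randint(0, 2, (32,))
    loss = criterion(ema(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()    # update the shadow weights after each optimizer step

ema.eval()          # forward passes now go through the shadow (EMA) weights
with torch.no_grad():
    preds = ema(torch.randn(5, 10)).argmax(dim=1)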