import torch


class EMA(torch.nn.Module):
    """Maintains an exponential moving average (EMA) of a model's trainable parameters."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        super().__init__()
        self.model = model
        self.decay = decay
        if hasattr(self.model, "time_geopath"):
            self.time_geopath = self.model.time_geopath
        # Put this in a buffer so that it gets included in the state dict
        self.register_buffer("num_updates", torch.tensor(0))
        self.shadow_params = torch.nn.ParameterList(
            [
                torch.nn.Parameter(p.clone().detach(), requires_grad=False)
                for p in model.parameters()
                if p.requires_grad
            ]
        )
        self.backup_params = []
    def train(self, mode: bool = True):
        if self.training and not mode:
            # Switching from train mode to eval mode. Back up the model parameters and
            # overwrite them with the shadow (EMA) params.
            self.backup()
            self.copy_to_model()
        elif not self.training and mode:
            # Switching from eval to train mode. Restore the backed-up parameters.
            self.restore_to_model()
        return super().train(mode)
    def update_ema(self):
        self.num_updates += 1
        num_updates = self.num_updates.item()
        # Warm up the decay so early updates follow the model more closely.
        decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        with torch.no_grad():
            params = [p for p in self.model.parameters() if p.requires_grad]
            for shadow, param in zip(self.shadow_params, params):
                shadow.sub_((1 - decay) * (shadow - param))
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    def copy_to_model(self):
        # Copy the shadow (EMA) parameters to the model.
        params = [p for p in self.model.parameters() if p.requires_grad]
        for shadow, param in zip(self.shadow_params, params):
            param.data.copy_(shadow.data)
    def backup(self):
        # Back up the current model parameters.
        if len(self.backup_params) > 0:
            for p, b in zip(self.model.parameters(), self.backup_params):
                b.data.copy_(p.data)
        else:
            self.backup_params = [param.clone() for param in self.model.parameters()]
    def restore_to_model(self):
        # Restores the backed up parameters to the model.
        for param, backup in zip(self.model.parameters(), self.backup_params):
            param.data.copy_(backup.data)
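

# Minimal usage sketch (an assumption, not part of the original upload): the toy linear
# model, random data, and hyperparameters below are placeholders standing in for a real
# training setup, showing how the EMA wrapper above is intended to be driven.
if __name__ == "__main__":
    torch.manual_seed(0)
    base_model = torch.nn.Linear(4, 2)
    ema = EMA(base_model, decay=0.999)
    optimizer = torch.optim.SGD(base_model.parameters(), lr=0.1)

    for _ in range(5):
        x = torch.randn(8, 4)
        loss = ema(x).pow(2).mean()  # the forward pass runs through the wrapped model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update_ema()  # refresh the shadow (EMA) weights after each optimizer step

    ema.eval()   # swaps the EMA weights into the model for evaluation
    ema.train()  # restores the raw training weights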