pe.callback.image.dpimagebench_lib.ema module

Adapted from DPImageBench: https://github.com/fjxmlzn/DPImageBench/blob/main/evaluation/ema.py

class pe.callback.image.dpimagebench_lib.ema.ExponentialMovingAverage(parameters, decay, use_num_updates=True)

Bases: object

Maintains an exponential moving average of a set of parameters.

__init__(parameters, decay, use_num_updates=True)

Parameters:
  • parameters – Iterable of torch.nn.Parameter; usually the result of model.parameters().

  • decay – The exponential decay factor of the moving average; values closer to 1 make the average change more slowly.

  • use_num_updates – Whether to take the number of updates performed so far into account when computing the averages.
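The parameter list above does not say how use_num_updates affects the averaging. One common convention, used for example by the EMA helper in the score_sde codebase, caps the decay at (1 + n) / (10 + n) after n updates, so that early averages follow the raw parameters closely. The sketch below assumes that rule. Whether this module applies exactly the same formula is an assumption, and effective_decay is a hypothetical helper rather than part of the module.

```python
from typing import Optional


def effective_decay(decay: float, num_updates: Optional[int]) -> float:
    """Decay applied at one update under the ASSUMED warm-up rule (hypothetical helper)."""
    if num_updates is None:
        # use_num_updates=False: the configured decay is used unchanged.
        return decay
    # use_num_updates=True: early on, (1 + n) / (10 + n) is smaller than decay,
    # so new parameter values get more weight; later, the configured decay takes over.
    return min(decay, (1 + num_updates) / (10 + num_updates))


for n in (1, 10, 100, 10_000):
    print(n, round(effective_decay(0.999, n), 4))
```

Under this assumed rule, with decay=0.999 the effective value is about 0.18 after the first update. It only reaches 0.999 after about 9,000 updates.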

copy_to(parameters)

Copy the current moving averages into the given collection of parameters.

Parameters:

parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored moving averages.

load_state_dict(state_dict)

Load a state previously returned by state_dict().

restore(parameters)

Restore the parameters saved by the store method. This is useful for validating the model with the EMA parameters without affecting the original optimization process. Call store before copy_to, and after validation (or model saving), call restore to bring back the former parameters.

Parameters:

parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored parameters.
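The call order described for restore can be sketched as follows. With the real class, the steps would be:

1. ema.store(model.parameters())
2. ema.copy_to(model.parameters())
3. Validate or save the model.
4. ema.restore(model.parameters())

To stay self-contained, the sketch replaces torch.nn.Parameter objects with a plain list of floats. ListEMA is a hypothetical stand-in that only mirrors the copy semantics of these three methods.

```python
from typing import List


class ListEMA:
    """Hypothetical stand-in mirroring copy_to/store/restore on a list of floats."""

    def __init__(self, params: List[float]):
        self.shadow = list(params)            # moving averages, initialized from params
        self.collected: List[float] = []     # snapshot filled by store()

    def copy_to(self, params: List[float]) -> None:
        params[:] = self.shadow               # overwrite live values with the averages

    def store(self, params: List[float]) -> None:
        self.collected = list(params)         # snapshot the live (training) values

    def restore(self, params: List[float]) -> None:
        params[:] = self.collected            # put the snapshot back in place


weights = [1.0, 2.0]
ema = ListEMA(weights)
ema.shadow = [0.5, 1.5]      # pretend some averaging has already happened

ema.store(weights)           # 1. save the training weights
ema.copy_to(weights)         # 2. swap in the EMA weights
evaluated = list(weights)    # 3. validate or save the model here
ema.restore(weights)         # 4. return to the training weights
```

Two practical points apply to the real API. First, nn.Module.parameters() returns a generator, so pass a fresh model.parameters() to each call instead of reusing one exhausted iterator. Second, wrapping the validation step in try/finally guarantees that restore runs even if validation raises.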

state_dict()

Return the state of the moving average as a dictionary that can be passed to load_state_dict().

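A checkpoint typically saves state_dict() next to the model's own state and feeds it back through load_state_dict() on resume. The sketch below assumes the state holds the decay, the update count and the shadow averages, as in the score_sde EMA helper. The exact keys used by this module are not documented here, so TinyEMAState is a hypothetical stand-in. pickle stands in for torch.save and torch.load.

```python
import pickle
from typing import Any, Dict, List, Optional


class TinyEMAState:
    """Hypothetical stand-in holding the fields an EMA checkpoint is assumed to need."""

    def __init__(self, decay: float, num_updates: Optional[int], shadow_params: List[float]):
        self.decay = decay
        self.num_updates = num_updates
        self.shadow_params = shadow_params

    def state_dict(self) -> Dict[str, Any]:
        # Assumed keys; copy the list so later updates do not mutate the checkpoint.
        return {"decay": self.decay, "num_updates": self.num_updates,
                "shadow_params": list(self.shadow_params)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow_params = list(state_dict["shadow_params"])


ema = TinyEMAState(0.999, 42, [0.5, 1.5])
blob = pickle.dumps({"ema": ema.state_dict()})       # stand-in for torch.save(...)

resumed = TinyEMAState(0.999, 0, [0.0, 0.0])         # freshly constructed object
resumed.load_state_dict(pickle.loads(blob)["ema"])   # stand-in for torch.load(...)
```

If the update count is part of the saved state, restoring it keeps any count-dependent averaging (use_num_updates=True) from starting over after a resume.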
store(parameters)

Save the current parameters for restoring later.

Parameters:

parameters – Iterable of torch.nn.Parameter; the parameters to be temporarily stored.

update(parameters)

Update the currently maintained moving averages.

Call this every time the parameters are updated, for example after each optimizer.step() call.

Parameters:

parameters – Iterable of torch.nn.Parameter; usually the same set of parameters used to initialize this object.
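Below is a minimal sketch of the per-step update. It assumes the common in-place form shadow -= (1 - decay) * (shadow - param), which equals shadow = decay * shadow + (1 - decay) * param. It uses a fixed decay (as with use_num_updates=False) and plain floats instead of tensors. ema_update is a hypothetical helper. With the real class, the corresponding call would be ema.update(model.parameters()) placed right after optimizer.step().

```python
from typing import List


def ema_update(shadow: List[float], params: List[float], decay: float) -> None:
    """One in-place EMA step: shadow = decay * shadow + (1 - decay) * params."""
    one_minus_decay = 1.0 - decay
    for i, p in enumerate(params):
        shadow[i] -= one_minus_decay * (shadow[i] - p)


params = [0.0]
shadow = list(params)          # averages start as a copy of the initial parameters
for step in range(1, 4):
    params[0] = float(step)    # stand-in for optimizer.step() changing the weights
    ema_update(shadow, params, decay=0.5)  # update the averages right after each step
```

With decay=0.5 the average trails the parameter: after the parameter moves through 1, 2 and 3, the shadow value is 2.125. A decay closer to 1 would make it lag further behind.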