pe.callback.image.dpimagebench_lib.ema module
Adapted from DPImageBench: https://github.com/fjxmlzn/DPImageBench/blob/main/evaluation/ema.py
- class pe.callback.image.dpimagebench_lib.ema.ExponentialMovingAverage(parameters, decay, use_num_updates=True)
Bases: object
Maintains an exponential moving average of a set of parameters.
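The standard definition of an exponential moving average can be written as follows. It is assumed here, not taken from the linked source. Let d be the decay, θ_t the parameters after the t-th update, and θ̄ the shadow copy, initialized to θ̄_0 = θ_0. The first line gives one update step. The second gives its unrolled form for a fixed d.

```latex
\bar{\theta}_t = d\,\bar{\theta}_{t-1} + (1 - d)\,\theta_t
               = \bar{\theta}_{t-1} - (1 - d)\,(\bar{\theta}_{t-1} - \theta_t)

\bar{\theta}_t = d^{t}\,\bar{\theta}_0 + (1 - d)\sum_{k=1}^{t} d^{\,t-k}\,\theta_k
```

The weight on older parameters shrinks geometrically, so a decay closer to 1 gives the average a longer memory.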
- __init__(parameters, decay, use_num_updates=True)
- Parameters:
parameters – Iterable of torch.nn.Parameter; usually the result of model.parameters().
decay – The exponential decay rate; values closer to 1 give a slower-moving, smoother average.
use_num_updates – Whether to use the number of updates performed so far when computing the averages.
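Below is a minimal sketch of one update step. It uses plain Python floats in place of tensors, so it runs without PyTorch. The function name `ema_step` is hypothetical. The warm-up cap `min(decay, (1 + n) / (10 + n))` for the `use_num_updates=True` case is an assumption based on common PyTorch EMA helpers of this shape, so check the linked source before relying on it.

```python
def ema_step(shadow, params, decay, num_updates=None):
    """One EMA update over lists of floats (stand-ins for tensors).

    Hypothetical helper. If num_updates is an int (the use_num_updates=True
    case), the decay is capped by an ASSUMED warm-up schedule
    (1 + n) / (10 + n). Returns the new shadow list and the updated counter.
    """
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be between 0 and 1")
    if num_updates is not None:
        num_updates += 1
        # Early on, this cap keeps the average close to the raw parameters.
        decay = min(decay, (1 + num_updates) / (10 + num_updates))
    # shadow <- shadow - (1 - decay) * (shadow - param), element by element
    new_shadow = [s - (1.0 - decay) * (s - p) for s, p in zip(shadow, params)]
    return new_shadow, num_updates


shadow, n = ema_step([0.0], [1.0], decay=0.999, num_updates=0)
```

Under this assumed schedule, the first update uses an effective decay of 2/11 instead of 0.999. The cap only stops binding once (1 + n) / (10 + n) ≥ decay, which for decay=0.999 happens after about 9,000 updates.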
- copy_to(parameters)
Copy the stored moving-average parameters into the given collection of parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored moving averages.
- restore(parameters)
Restore the parameters saved with the store method. This is useful for validating the model with EMA parameters without affecting the original optimization process: call store before copy_to, and after validation (or model saving) call restore to bring back the former parameters.
- Parameters:
parameters – Iterable of torch.nn.Parameter; the parameters to be updated with the stored parameters.
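The store, copy_to, validate, restore sequence can be sketched with a hypothetical `ToyEMA` class over a list of floats. Only the method names `copy_to`, `store` and `restore` come from this page. The `update` method, the attribute names, and the in-place list assignment (standing in for copying into tensor data) are assumptions for illustration.

```python
class ToyEMA:
    """Hypothetical stand-in for ExponentialMovingAverage over a list of floats."""

    def __init__(self, params, decay):
        self.decay = decay
        self.shadow = list(params)   # moving averages, initialized to a copy
        self.collected = []          # backup filled by store()

    def update(self, params):
        # ASSUMED update method: plain EMA without warm-up.
        self.shadow = [s - (1.0 - self.decay) * (s - p)
                       for s, p in zip(self.shadow, params)]

    def store(self, params):
        self.collected = list(params)

    def copy_to(self, params):
        params[:] = self.shadow      # overwrite in place with the averages

    def restore(self, params):
        params[:] = self.collected   # overwrite in place with the backup


params = [1.0, 2.0]
ema = ToyEMA(params, decay=0.5)
params[:] = [3.0, 4.0]   # pretend an optimizer step happened
ema.update(params)       # shadow becomes [2.0, 3.0]

ema.store(params)        # back up the raw training weights
ema.copy_to(params)      # evaluate (or save) with the EMA weights
eval_weights = list(params)
ema.restore(params)      # resume training from the raw weights
```

The lists are mutated in place rather than rebound. This way, any other holder of `params` (in the real class, the model owning the tensors) sees the swap.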