pe.callback.image.dpimagebench_classify_images module
- class pe.callback.image.dpimagebench_classify_images.DPImageBenchClassifyImages(model_name, test_data, val_data, batch_size=256, num_epochs=50, n_splits=1, lr=0.01, lr_scheduler_step_size=20, lr_scheduler_gamma=0.2, ema_rate=0.9999, **model_params)
Bases: Callback
The callback that evaluates the classification accuracy of the synthetic data following DPImageBench (https://github.com/2019ChenGong/DPImageBench).
- __call__(syn_data)
This function is called after each PE iteration and computes the downstream classification metrics.
- Parameters:
syn_data (pe.data.Data) – The synthetic data
- Returns:
The classification accuracy metrics
- Return type:
- __init__(model_name, test_data, val_data, batch_size=256, num_epochs=50, n_splits=1, lr=0.01, lr_scheduler_step_size=20, lr_scheduler_gamma=0.2, ema_rate=0.9999, **model_params)
Constructor.
- Parameters:
model_name (str) – The name of the model to use (wrn, resnet, resnext)
test_data (pe.data.Data) – The test data
val_data (pe.data.Data) – The validation data
batch_size (int, optional) – The batch size, defaults to 256
num_epochs (int, optional) – The number of training epochs, defaults to 50
n_splits (int, optional) – The number of splits for gradient accumulation, defaults to 1
lr (float, optional) – The learning rate, defaults to 0.01
lr_scheduler_step_size (int, optional) – The step size for the learning rate scheduler, defaults to 20
lr_scheduler_gamma (float, optional) – The gamma for the learning rate scheduler, defaults to 0.2
ema_rate (float, optional) – The rate for the exponential moving average, defaults to 0.9999
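The scheduler and accumulation parameters above can be illustrated with a small pure-Python sketch. It assumes a StepLR-style schedule (the learning rate is multiplied by lr_scheduler_gamma once every lr_scheduler_step_size epochs) and that n_splits divides each batch into near-equal micro-batches whose gradients are accumulated before a single optimizer step. Neither detail is confirmed by this page, and the helper names step_lr and split_batch are hypothetical.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 20, gamma: float = 0.2) -> float:
    """Learning rate at `epoch` under a step-decay (StepLR-style) schedule.

    Hypothetical helper: the rate is multiplied by `gamma` once every
    `step_size` epochs.
    """
    return base_lr * gamma ** (epoch // step_size)


def split_batch(batch_size: int, n_splits: int) -> list[int]:
    """Sizes of the micro-batches used for gradient accumulation.

    Hypothetical helper: a batch of `batch_size` samples is processed in
    `n_splits` chunks whose gradients are summed before one optimizer step.
    """
    if n_splits < 1:
        raise ValueError("n_splits must be >= 1")
    base, rem = divmod(batch_size, n_splits)
    # Spread any remainder over the first chunks so every sample is used once.
    return [base + (1 if i < rem else 0) for i in range(n_splits)]
```

With the defaults (lr=0.01, step size 20, gamma 0.2, 50 epochs), this schedule gives 0.01 for epochs 0 to 19, 0.002 for epochs 20 to 39, and 0.0004 for epochs 40 to 49. With batch_size=256, n_splits=1 processes the whole batch at once, while n_splits=3 would yield micro-batches of 86, 85, and 85 samples under this sketch.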
- _get_data_loader(data)
Getting the data loader.
- Parameters:
data (pe.data.Data) – The data object
- Returns:
The data loader
- Return type:
torch.utils.data.DataLoader
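A DataLoader over in-memory arrays essentially shuffles indices and slices fixed-size batches. The NumPy sketch below mimics that behavior (a batch size, optional shuffling, and a final partial batch that is kept) without depending on PyTorch. The function name iterate_minibatches and its seed parameter are hypothetical, and how _get_data_loader actually wraps the images and labels is not shown on this page.

```python
import numpy as np


def iterate_minibatches(images: np.ndarray, labels: np.ndarray,
                        batch_size: int = 256, shuffle: bool = True, seed: int = 0):
    """Yield (images, labels) minibatches, mimicking a DataLoader over arrays.

    Hypothetical helper: the last batch may be smaller than `batch_size`
    (no samples are dropped).
    """
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    indices = np.arange(len(images))
    if shuffle:
        # Shuffle once per pass with a seeded generator for reproducibility.
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        yield images[batch], labels[batch]
```

For example, 10 images of shape (3, 32, 32) with batch_size=4 produce batches of 4, 4, and 2 samples, and every label appears exactly once across the pass.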
- _get_images_and_label_from_data(data)
Getting images and labels from the data.
- Parameters:
data (pe.data.Data) – The data object
- Returns:
The images and labels
- Return type:
tuple[np.ndarray, np.ndarray]
- _get_model()
Getting the model.
- Raises:
ValueError – If the model name is unknown
- Returns:
The model
- Return type:
torch.nn.Module
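The ValueError documented above suggests a name-to-constructor dispatch over the supported model names (wrn, resnet, resnext). The sketch below is a hypothetical illustration: the registry, the stub builders (which return plain dicts instead of torch.nn.Module instances), and the assumption that **model_params are forwarded to the chosen constructor are all illustrative, not taken from the library.

```python
from typing import Any, Callable, Dict


def _make_stub(arch: str) -> Callable[..., dict]:
    """Return a hypothetical builder that records its architecture and kwargs."""
    def build(**params: Any) -> dict:
        return {"arch": arch, **params}
    return build


# Hypothetical registry; the real method returns a torch.nn.Module.
MODEL_BUILDERS: Dict[str, Callable[..., dict]] = {
    name: _make_stub(name) for name in ("wrn", "resnet", "resnext")
}


def get_model(model_name: str, **model_params: Any) -> dict:
    """Look up `model_name` and build it, forwarding extra keyword arguments."""
    builder = MODEL_BUILDERS.get(model_name)
    if builder is None:
        raise ValueError(f"Unknown model name: {model_name!r}")
    return builder(**model_params)
```

In this sketch, get_model("wrn", depth=28) returns {"arch": "wrn", "depth": 28}, while get_model("vit") raises ValueError. A dictionary registry keeps the error path in one place and makes adding an architecture a one-line change.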
- evaluate(model, ema, data_loader, criterion)
Evaluating the model.
- Parameters:
model (torch.nn.Module) – The model
ema (pe.callback.image.dpimagebench_lib.ema.ExponentialMovingAverage) – The exponential moving average object
data_loader (torch.utils.data.DataLoader) – The data loader
criterion (torch.nn.Module) – The criterion
- Returns:
The accuracy and loss
- Return type:
tuple[float, float]
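The ema_rate parameter and the (accuracy, loss) return value can be illustrated with a NumPy sketch. It assumes the standard EMA update shadow ← ema_rate · shadow + (1 − ema_rate) · params and a cross-entropy criterion. The class EMA and the function accuracy_and_loss are hypothetical stand-ins, not the pe.callback.image.dpimagebench_lib.ema.ExponentialMovingAverage API.

```python
import numpy as np


class EMA:
    """Exponential moving average of a parameter vector (hypothetical stand-in)."""

    def __init__(self, params: np.ndarray, rate: float = 0.9999):
        self.rate = rate
        self.shadow = params.astype(np.float64).copy()

    def update(self, params: np.ndarray) -> None:
        # shadow <- rate * shadow + (1 - rate) * params
        self.shadow = self.rate * self.shadow + (1.0 - self.rate) * params


def accuracy_and_loss(logits: np.ndarray, labels: np.ndarray) -> tuple[float, float]:
    """Top-1 accuracy and mean cross-entropy loss for a batch of logits."""
    # Numerically stable log-softmax: subtract the row-wise max first.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross-entropy: negative log-probability of the true class, averaged.
    loss = -log_probs[np.arange(len(labels)), labels].mean()
    acc = (logits.argmax(axis=1) == labels).mean()
    return float(acc), float(loss)
```

With rate=0.9, one update from zeros toward ones moves the shadow to 0.1. Evaluating with averaged weights rather than the raw ones is a common way to smooth out late-training noise; whether evaluate swaps in the EMA weights exactly this way is an assumption here.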