pe.callback.image.dpimagebench_classify_images module

class pe.callback.image.dpimagebench_classify_images.DPImageBenchClassifyImages(model_name, test_data, val_data, batch_size=256, num_epochs=50, n_splits=1, lr=0.01, lr_scheduler_step_size=20, lr_scheduler_gamma=0.2, ema_rate=0.9999, **model_params)

Bases: Callback

Callback that evaluates the synthetic data by training a classifier on it and measuring the classification accuracy, following DPImageBench (https://github.com/2019ChenGong/DPImageBench).
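DPImageBench reports the test accuracy of a classifier trained on synthetic data, and this callback receives both val_data and test_data. Below is a minimal sketch of one way a validation split can be used, assuming the checkpoint is selected by validation accuracy and the test accuracy at that epoch is reported. select_test_accuracy is a hypothetical helper, not part of pe, and this is not the callback's actual code.

```python
from typing import Sequence, Tuple


def select_test_accuracy(
    val_accuracies: Sequence[float], test_accuracies: Sequence[float]
) -> Tuple[int, float]:
    """Hypothetical helper: return (best_epoch, test accuracy at that epoch).

    Assumes the i-th entry of each sequence was measured after epoch i and
    that the epoch is chosen by validation accuracy only, so the test set
    plays no part in model selection.
    """
    if len(val_accuracies) != len(test_accuracies):
        raise ValueError("val and test accuracy lists must have the same length")
    if not val_accuracies:
        raise ValueError("at least one epoch is required")
    # max() keeps the first maximal index on ties, i.e. the earliest best epoch.
    best_epoch = max(range(len(val_accuracies)), key=lambda i: val_accuracies[i])
    return best_epoch, test_accuracies[best_epoch]


if __name__ == "__main__":
    val = [0.52, 0.61, 0.66, 0.64]
    test = [0.50, 0.60, 0.63, 0.65]
    print(select_test_accuracy(val, test))  # (2, 0.63)
```

In this example, the reported value is 0.63 rather than the highest test accuracy (0.65), because only the validation accuracies decide which epoch counts.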

__call__(syn_data)

This function is called after each PE iteration and computes the downstream classification metrics.

Parameters:

syn_data (pe.data.Data) – The synthetic data

Returns:

The classification accuracy metrics

Return type:

list[pe.metric_item.FloatListMetricItem]

__init__(model_name, test_data, val_data, batch_size=256, num_epochs=50, n_splits=1, lr=0.01, lr_scheduler_step_size=20, lr_scheduler_gamma=0.2, ema_rate=0.9999, **model_params)

Constructor.

Parameters:
  • model_name (str) – The name of the model to use (wrn, resnet, resnext)

  • test_data (pe.data.Data) – The test data

  • val_data (pe.data.Data) – The validation data

  • batch_size (int, optional) – The batch size, defaults to 256

  • num_epochs (int, optional) – The number of training epochs, defaults to 50

  • n_splits (int, optional) – The number of splits for gradient accumulation, defaults to 1

  • lr (float, optional) – The learning rate, defaults to 0.01

  • lr_scheduler_step_size (int, optional) – The step size for the learning rate scheduler, defaults to 20

  • lr_scheduler_gamma (float, optional) – The gamma for the learning rate scheduler, defaults to 0.2

  • ema_rate (float, optional) – The rate for the exponential moving average, defaults to 0.9999
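The lr_scheduler_step_size and lr_scheduler_gamma parameters mirror the step_size and gamma arguments of torch.optim.lr_scheduler.StepLR, and ema_rate is the decay of an exponential moving average. The pure-Python sketch below shows the resulting schedule and one EMA update, assuming the scheduler is stepped once per epoch and the usual update ema = ema_rate * ema + (1 - ema_rate) * weight. Neither detail is confirmed by this page; step_lr and ema_update are hypothetical names.

```python
from typing import List


def step_lr(base_lr: float, epoch: int, step_size: int = 20, gamma: float = 0.2) -> float:
    """Learning rate at `epoch` for a step-decay schedule.

    Matches the rule of torch.optim.lr_scheduler.StepLR stepped once per
    epoch: the rate is multiplied by `gamma` every `step_size` epochs.
    """
    return base_lr * gamma ** (epoch // step_size)


def ema_update(ema_weights: List[float], weights: List[float], ema_rate: float = 0.9999) -> List[float]:
    """One exponential-moving-average update of a flat list of weights."""
    return [ema_rate * e + (1.0 - ema_rate) * w for e, w in zip(ema_weights, weights)]


if __name__ == "__main__":
    # Constructor defaults: lr=0.01, step size 20, gamma 0.2, 50 epochs.
    schedule = [step_lr(0.01, epoch) for epoch in range(50)]
    print(schedule[0], schedule[20], schedule[40])  # approximately 0.01 0.002 0.0004
    # A rate close to 1 moves the averaged weights only slightly per update.
    print(ema_update([0.0], [1.0], ema_rate=0.9))  # approximately [0.1]
```

Under these assumptions, the defaults give a learning rate of 0.01 for epochs 0 to 19, 0.002 for epochs 20 to 39, and 0.0004 for the last ten epochs. An ema_rate of 0.9999 roughly averages over the last 1 / (1 - ema_rate) = 10,000 updates.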

_get_data_loader(data)

Builds a data loader over the given data.

Parameters:

data (pe.data.Data) – The data object

Returns:

The data loader

Return type:

torch.utils.data.DataLoader
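The data loader yields batches of batch_size samples, and n_splits controls gradient accumulation. Below is a minimal sketch, assuming each batch is split into n_splits equal micro-batches whose gradients are accumulated before a single optimizer step. iterate_batches is hypothetical and yields index lists rather than tensors, so it runs without PyTorch. The divisibility check is part of the sketch, not a documented constraint.

```python
import random
from typing import Iterator, List


def iterate_batches(
    num_samples: int,
    batch_size: int = 256,
    n_splits: int = 1,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[List[List[int]]]:
    """Yield each batch as a list of micro-batches of sample indices.

    Hypothetical stand-in for a shuffled torch.utils.data.DataLoader combined
    with gradient accumulation: only batch_size / n_splits samples need to be
    processed at once, while the optimizer still steps once per full batch.
    """
    if batch_size % n_splits != 0:
        raise ValueError("batch_size must be divisible by n_splits")
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    micro_size = batch_size // n_splits
    for start in range(0, num_samples, batch_size):
        batch = indices[start:start + batch_size]
        # The last batch may be smaller and then yields fewer micro-batches.
        yield [batch[i:i + micro_size] for i in range(0, len(batch), micro_size)]


if __name__ == "__main__":
    print(list(iterate_batches(10, batch_size=4, n_splits=2, shuffle=False)))
    # [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9]]]
```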

_get_images_and_label_from_data(data)

Extracts the images and labels from the data.

Parameters:

data (pe.data.Data) – The data object

Returns:

The images and labels

Return type:

tuple[np.ndarray, np.ndarray]
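Below is a minimal sketch of the extraction step, assuming each sample is a record that holds an image and an integer class label. The record layout and the keys "image" and "label" are hypothetical placeholders, not the column names pe.data.Data uses. The sketch also returns plain lists, whereas the documented method returns NumPy arrays.

```python
from typing import Any, Dict, List, Sequence, Tuple


def get_images_and_labels(
    rows: Sequence[Dict[str, Any]], image_key: str = "image", label_key: str = "label"
) -> Tuple[List[Any], List[int]]:
    """Split hypothetical per-sample records into parallel image and label lists.

    `image_key` and `label_key` are placeholder names. Labels are cast to int
    so they can serve directly as class indices.
    """
    images = [row[image_key] for row in rows]
    labels = [int(row[label_key]) for row in rows]
    return images, labels


if __name__ == "__main__":
    rows = [{"image": [[0]], "label": 1}, {"image": [[255]], "label": "0"}]
    print(get_images_and_labels(rows))  # ([[[0]], [[255]]], [1, 0])
```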

_get_model()

Creates the classification model specified by model_name.

Raises:

ValueError – If the model name is unknown

Returns:

The model

Return type:

torch.nn.Module
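_get_model maps model_name to an architecture and raises ValueError for any other name. Below is a minimal sketch of that dispatch, assuming model_params are forwarded to the model constructor. TinyClassifier, MODEL_REGISTRY and get_model are hypothetical stand-ins for the real networks. wrn is the common abbreviation for Wide ResNet.

```python
from typing import Callable, Dict


class TinyClassifier:
    """Placeholder standing in for a real network such as a Wide ResNet."""

    def __init__(self, arch: str, num_classes: int = 10, **model_params: object) -> None:
        self.arch = arch
        self.num_classes = num_classes
        self.model_params = model_params


# Hypothetical registry mapping the documented model names to constructors.
MODEL_REGISTRY: Dict[str, Callable[..., TinyClassifier]] = {
    "wrn": lambda **kw: TinyClassifier("wrn", **kw),
    "resnet": lambda **kw: TinyClassifier("resnet", **kw),
    "resnext": lambda **kw: TinyClassifier("resnext", **kw),
}


def get_model(model_name: str, **model_params: object) -> TinyClassifier:
    """Build the model for `model_name`, raising ValueError for unknown names."""
    try:
        factory = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name: {model_name}") from None
    return factory(**model_params)


if __name__ == "__main__":
    model = get_model("wrn", num_classes=10, depth=16)
    print(model.arch, model.num_classes, model.model_params)  # wrn 10 {'depth': 16}
```

A dictionary registry keeps the list of supported names in one place, so an unknown name fails early with a clear error instead of reaching training.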

evaluate(model, ema, data_loader, criterion)

Evaluates the model on the given data loader.

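Parameters:
  • model (torch.nn.Module) – The model to evaluate

  • ema – The exponential moving average (EMA) of the model

  • data_loader (torch.utils.data.DataLoader) – The data loader of the evaluation data

  • criterion – The loss function

Returns: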

The accuracy and loss

Return type:

tuple[float, float]
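evaluate returns an accuracy and a loss over data_loader. Below is a minimal sketch of that aggregation, assuming criterion is softmax cross-entropy (the usual choice for classification) and representing each batch as per-sample logits plus integer labels. The ema argument is not modeled, and accuracy_and_loss is a hypothetical name.

```python
import math
from typing import Iterable, List, Sequence, Tuple


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Softmax cross-entropy for one sample, computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def accuracy_and_loss(batches: Iterable[Tuple[List[List[float]], List[int]]]) -> Tuple[float, float]:
    """Return (accuracy, mean loss) over every sample in `batches`.

    Each batch is (logits per sample, integer labels), standing in for model
    outputs over a DataLoader. The loss is averaged per sample, so a smaller
    final batch is weighted correctly.
    """
    correct = 0
    total = 0
    loss_sum = 0.0
    for logits_batch, labels in batches:
        for logits, label in zip(logits_batch, labels):
            pred = max(range(len(logits)), key=lambda k: logits[k])  # argmax
            correct += int(pred == label)
            loss_sum += cross_entropy(logits, label)
            total += 1
    if total == 0:
        raise ValueError("no samples to evaluate")
    return correct / total, loss_sum / total


if __name__ == "__main__":
    batches = [([[2.0, 0.0]], [0]), ([[0.0, 2.0]], [0])]
    print(accuracy_and_loss(batches))
```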