# Source code for pe.callback.image.dpimagebench_classify_images

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from pe.callback.callback import Callback
from pe.constant.data import IMAGE_DATA_COLUMN_NAME
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.metric_item import FloatListMetricItem
from pe.logging import execution_logger

from .dpimagebench_lib.wrn import WideResNet
from .dpimagebench_lib.resnet import ResNet
from .dpimagebench_lib.resnext import ResNeXt
from .dpimagebench_lib.ema import ExponentialMovingAverage


class DPImageBenchClassifyImages(Callback):
    """The callback that evaluates the classification accuracy of the synthetic data following
    DPImageBench (https://github.com/2019ChenGong/DPImageBench)."""

    def __init__(
        self,
        model_name,
        test_data,
        val_data,
        batch_size=256,
        num_epochs=50,
        n_splits=1,
        lr=0.01,
        lr_scheduler_step_size=20,
        lr_scheduler_gamma=0.2,
        ema_rate=0.9999,
        **model_params,
    ):
        """Constructor.

        :param model_name: The name of the model to use (wrn, resnet, resnext)
        :type model_name: str
        :param test_data: The test data
        :type test_data: :py:class:`pe.data.Data`
        :param val_data: The validation data
        :type val_data: :py:class:`pe.data.Data`
        :param batch_size: The batch size, defaults to 256
        :type batch_size: int, optional
        :param num_epochs: The number of training epochs, defaults to 50
        :type num_epochs: int, optional
        :param n_splits: The number of splits for gradient accumulation, defaults to 1
        :type n_splits: int, optional
        :param lr: The learning rate, defaults to 0.01
        :type lr: float, optional
        :param lr_scheduler_step_size: The step size for the learning rate scheduler, defaults to 20
        :type lr_scheduler_step_size: int, optional
        :param lr_scheduler_gamma: The gamma for the learning rate scheduler, defaults to 0.2
        :type lr_scheduler_gamma: float, optional
        :param ema_rate: The rate for the exponential moving average, defaults to 0.9999
        :type ema_rate: float, optional
        """
        self._model_name = model_name
        self._test_data = test_data
        self._val_data = val_data
        self._num_classes = len(self._test_data.metadata.label_info)
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        self._n_splits = n_splits
        self._lr = lr
        self._lr_scheduler_step_size = lr_scheduler_step_size
        self._lr_scheduler_gamma = lr_scheduler_gamma
        self._ema_rate = ema_rate
        self._model_params = model_params
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Images are stored HWC in the data frame, so shape[2] is the channel count and
        # shape[0] the (square) spatial size.
        self._num_channels = self._test_data.data_frame[IMAGE_DATA_COLUMN_NAME].values[0].shape[2]
        self._image_size = self._test_data.data_frame[IMAGE_DATA_COLUMN_NAME].values[0].shape[0]

    def _get_images_and_label_from_data(self, data):
        """Getting images and labels from the data.

        :param data: The data object
        :type data: :py:class:`pe.data.Data`
        :return: The images (NCHW, scaled to [0, 1]) and labels, or ``(None, None)`` when
            ``data`` is ``None``
        :rtype: tuple[np.ndarray, np.ndarray]
        """
        if data is None:
            return None, None
        images = np.stack(data.data_frame[IMAGE_DATA_COLUMN_NAME].values)
        # HWC uint8 -> CHW float in [0, 1].
        images = images.transpose((0, 3, 1, 2)) / 255.0
        labels = np.array(data.data_frame[LABEL_ID_COLUMN_NAME].values)
        return images, labels

    def _get_model(self):
        """Getting the model.

        :raises ValueError: If the model name is unknown
        :return: The model
        :rtype: torch.nn.Module
        """
        if self._model_name == "wrn":
            model = WideResNet(
                in_c=self._num_channels,
                img_size=self._image_size,
                num_classes=self._num_classes,
                depth=28,
                widen_factor=10,
                dropRate=0.3,
                **self._model_params,
            )
        elif self._model_name == "resnet":
            model = ResNet(
                in_c=self._num_channels,
                img_size=self._image_size,
                num_classes=self._num_classes,
                depth=164,
                block_name="BasicBlock",
                **self._model_params,
            )
        elif self._model_name == "resnext":
            model = ResNeXt(
                in_c=self._num_channels,
                img_size=self._image_size,
                cardinality=8,
                depth=28,
                num_classes=self._num_classes,
                widen_factor=10,
                dropRate=0.3,
                **self._model_params,
            )
        else:
            raise ValueError(f"Unknown model name: {self._model_name}")
        return model

    def _get_data_loader(self, data):
        """Getting the data loader.

        :param data: The data object
        :type data: :py:class:`pe.data.Data`
        :return: The data loader, or ``None`` when ``data`` is ``None``
        :rtype: torch.utils.data.DataLoader
        """
        images, labels = self._get_images_and_label_from_data(data)
        if images is None:
            return None
        # The per-step batch is batch_size // n_splits; gradients are accumulated over
        # n_splits steps so the effective batch size stays batch_size.
        return DataLoader(
            TensorDataset(torch.from_numpy(images).float(), torch.from_numpy(labels).long()),
            shuffle=True,
            batch_size=self._batch_size // self._n_splits,
        )

    @torch.no_grad()
    def evaluate(self, model, ema, data_loader, criterion):
        """Evaluating the model.

        :param model: The model
        :type model: torch.nn.Module
        :param ema: The exponential moving average object
        :type ema: :py:class:`pe.callback.image.dpimagebench_lib.ema.ExponentialMovingAverage`
        :param data_loader: The data loader
        :type data_loader: torch.utils.data.DataLoader
        :param criterion: The criterion
        :type criterion: torch.nn.Module
        :return: The accuracy (in percent) and loss; ``nan`` values when the data loader is
            ``None`` or yields no batches
        :rtype: tuple[float, float]
        """
        # _get_data_loader returns None when the corresponding data object is None; report
        # NaNs instead of crashing on iteration.
        if data_loader is None:
            return float("nan"), float("nan")
        model.eval()
        # Evaluate with the EMA weights, then restore the raw training weights.
        ema.store(model.parameters())
        ema.copy_to(model.parameters())
        total = 0
        correct = 0
        loss = 0.0
        num_batches = 0
        for inputs, targets in data_loader:
            # Inputs are in [0, 1]; rescale to [-1, 1] to match the training convention.
            inputs = inputs.to(self._device) * 2.0 - 1.0
            targets = targets.to(self._device)
            outputs = model(inputs)
            loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1
        ema.restore(model.parameters())
        if num_batches == 0:
            # Empty loader: avoid ZeroDivisionError.
            return float("nan"), float("nan")
        return correct / total * 100, loss / num_batches

    def __call__(self, syn_data):
        """This function is called after each PE iteration that computes the downstream
        classification metrics.

        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :return: The classification accuracy metrics
        :rtype: list[:py:class:`pe.metric_item.FloatListMetricItem`]
        """
        execution_logger.info(f"Evaluating DPImageBench classification accuracy using {self._model_name}")
        model = self._get_model()
        optimizer = optim.Adam(model.parameters(), lr=self._lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self._lr_scheduler_step_size, gamma=self._lr_scheduler_gamma
        )
        criterion = nn.CrossEntropyLoss()
        model = model.to(self._device)
        ema = ExponentialMovingAverage(model.parameters(), self._ema_rate)

        train_loader = self._get_data_loader(syn_data)
        val_loader = self._get_data_loader(self._val_data)
        test_loader = self._get_data_loader(self._test_data)

        # Gradient-accumulation step counter; intentionally NOT reset per epoch so partial
        # accumulation windows carry across epoch boundaries.
        grad_accu_step = 0
        train_acc_list = []
        train_loss_list = []
        test_acc_list = []
        test_loss_list = []
        val_acc_list = []
        val_loss_list = []
        for epoch in range(self._num_epochs):
            model.train()
            train_loss = 0
            train_total = 0
            train_correct = 0
            train_num_batches = 0
            for inputs, targets in train_loader:
                # Rescale inputs from [0, 1] to [-1, 1].
                inputs = inputs.to(self._device) * 2.0 - 1.0
                targets = targets.to(self._device)
                if grad_accu_step == 0:
                    optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                train_loss += loss.item()
                # Scale by n_splits so the accumulated gradient matches a full-batch step.
                (loss / self._n_splits).backward()
                grad_accu_step += 1
                if grad_accu_step == self._n_splits:
                    optimizer.step()
                    grad_accu_step = 0
                    ema.update(model.parameters())
                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()
                train_num_batches += 1
            scheduler.step()
            if train_num_batches == 0:
                # Empty synthetic data: record NaNs instead of raising ZeroDivisionError.
                train_acc = float("nan")
                train_loss = float("nan")
            else:
                train_acc = train_correct / train_total * 100
                train_loss = train_loss / train_num_batches
            train_acc_list.append(train_acc)
            train_loss_list.append(train_loss)

            val_acc, val_loss = self.evaluate(model=model, ema=ema, data_loader=val_loader, criterion=criterion)
            val_acc_list.append(val_acc)
            val_loss_list.append(val_loss)
            test_acc, test_loss = self.evaluate(model=model, ema=ema, data_loader=test_loader, criterion=criterion)
            test_acc_list.append(test_acc)
            test_loss_list.append(test_loss)
            execution_logger.info(
                f"Epoch {epoch + 1}/{self._num_epochs}, "
                f"Train Acc: {train_acc:.2f}%, Train Loss: {train_loss:.4f}, "
                f"Val Acc: {val_acc:.2f}%, Val Loss: {val_loss:.4f}, "
                f"Test Acc: {test_acc:.2f}%, Test Loss: {test_loss:.4f}"
            )
        metric_items = [
            FloatListMetricItem(name=f"{self._model_name}_train_acc", value=train_acc_list),
            FloatListMetricItem(name=f"{self._model_name}_train_loss", value=train_loss_list),
            FloatListMetricItem(name=f"{self._model_name}_val_acc", value=val_acc_list),
            FloatListMetricItem(name=f"{self._model_name}_val_loss", value=val_loss_list),
            FloatListMetricItem(name=f"{self._model_name}_test_acc", value=test_acc_list),
            FloatListMetricItem(name=f"{self._model_name}_test_loss", value=test_loss_list),
        ]
        return metric_items