Tune - PyTorch

This example uses FLAML to tune a PyTorch model on CIFAR10.

Prepare for tuning

Requirements

pip install torchvision "flaml[blendsearch,ray]"

Before we start tuning, we first need to define the neural network that we would like to tune.

Network Specification

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms


class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
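
As a quick sanity check, the short sketch below (not part of the original example) passes a random batch of CIFAR10-sized images through the network; the two conv/pool stages reduce a 32x32 input to 16 feature maps of size 5x5, which is why fc1 expects 16 * 5 * 5 inputs.

net = Net()
dummy = torch.randn(4, 3, 32, 32)  # a batch of 4 CIFAR10-sized RGB images
out = net(dummy)
print(out.shape)  # torch.Size([4, 10]) -- one logit per CIFAR10 class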

Data

def load_data(data_dir="data"):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    return trainset, testset
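
For illustration, calling load_data downloads the standard torchvision CIFAR10 splits, which contain 50,000 training and 10,000 test images. This snippet is a usage sketch, not part of the original example.

trainset, testset = load_data("data")
print(len(trainset), len(testset))  # 50000 10000
print(trainset.classes[:3])  # ['airplane', 'automobile', 'bird']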

Training

import logging
import os

from ray import tune

logger = logging.getLogger(__name__)


def train_cifar(config, checkpoint_dir=None, data_dir=None):
    if "l1" not in config:
        logger.warning(config)
    net = Net(2 ** config["l1"], 2 ** config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint
    # should be restored.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(checkpoint)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]
    )

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(2 ** config["batch_size"]),
        shuffle=True,
        num_workers=4,
    )
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(2 ** config["batch_size"]),
        shuffle=True,
        num_workers=4,
    )

    for epoch in range(
        int(round(config["num_epochs"]))
    ):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be passed as the `checkpoint_dir`
        # parameter in future iterations.
        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")

Test Accuracy

def _test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

Hyperparameter Optimization

import numpy as np
import flaml
import os

data_dir = os.path.abspath("data")
load_data(data_dir) # Download data for all trials before starting the run

Search space

max_num_epoch = 100
config = {
    "l1": tune.randint(2, 9),  # log transformed with base 2
    "l2": tune.randint(2, 9),  # log transformed with base 2
    "lr": tune.loguniform(1e-4, 1e-1),
    "num_epochs": tune.loguniform(1, max_num_epoch),
    "batch_size": tune.randint(1, 5),  # log transformed with base 2
}
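
The integer hyperparameters are sampled in log2 space and exponentiated inside train_cifar, so a sampled config maps to the values actually used as in this sketch (the sample values are taken from the output shown at the end of this page):

sampled = {"l1": 8, "l2": 8, "lr": 0.00088, "num_epochs": 55.95, "batch_size": 3}
layer1_width = 2 ** sampled["l1"]  # 256 units in the first fully connected layer
layer2_width = 2 ** sampled["l2"]  # 256 units in the second fully connected layer
batch_size = int(2 ** sampled["batch_size"])  # mini-batches of 8 samples
epochs = int(round(sampled["num_epochs"]))  # train for 56 epochs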

Budget and resource constraints

time_budget_s = 600  # time budget in seconds
gpus_per_trial = 0.5  # number of gpus for each trial; 0.5 means two training jobs can share one gpu
num_samples = 500  # maximal number of trials
np.random.seed(7654321)
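
As a rough concurrency estimate (the 4-GPU, 16-CPU machine below is a hypothetical example, not part of the original tutorial), the fractional GPU allocation bounds how many trials can run at once:

total_gpus, total_cpus = 4, 16
cpus_per_trial = 1  # matches resources_per_trial below
max_concurrent = min(total_gpus / gpus_per_trial, total_cpus / cpus_per_trial)
print(max_concurrent)  # 8.0 -- at most 8 trials run simultaneously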

Launch the tuning

import time

start_time = time.time()
result = flaml.tune.run(
    tune.with_parameters(train_cifar, data_dir=data_dir),
    config=config,
    metric="loss",
    mode="min",
    low_cost_partial_config={"num_epochs": 1},
    max_resource=max_num_epoch,
    min_resource=1,
    scheduler="asha",  # use the asha scheduler for early stopping based on reported intermediate results
    resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
    local_dir="logs/",
    num_samples=num_samples,
    time_budget_s=time_budget_s,
    use_ray=True,
)

Check the result

print(f"#trials={len(result.trials)}")
print(f"time={time.time()-start_time}")
best_trial = result.get_best_trial("loss", "min", "all")
print("Best trial config: {}".format(best_trial.config))
print(
"Best trial final validation loss: {}".format(
best_trial.metric_analysis["loss"]["min"]
)
)
print(
"Best trial final validation accuracy: {}".format(
best_trial.metric_analysis["accuracy"]["max"]
)
)

best_trained_model = Net(2 ** best_trial.config["l1"], 2 ** best_trial.config["l2"])
device = "cpu"
if torch.cuda.is_available():
device = "cuda:0"
if gpus_per_trial > 1:
best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device)

checkpoint_value = (
getattr(best_trial.checkpoint, "dir_or_data", None) or best_trial.checkpoint.value
)
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")

model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)

test_acc = _test_accuracy(best_trained_model, device)
print("Best trial test set accuracy: {}".format(test_acc))

Sample of output

#trials=44
time=1193.913584947586
Best trial config: {'l1': 8, 'l2': 8, 'lr': 0.0008818671030627281, 'num_epochs': 55.9513429004283, 'batch_size': 3}
Best trial final validation loss: 1.0694482081472874
Best trial final validation accuracy: 0.6389
Files already downloaded and verified
Files already downloaded and verified
Best trial test set accuracy: 0.6294

Link to notebook | Open in Colab