# Tune - PyTorch

This example uses FLAML to tune a PyTorch model on CIFAR10.

## Prepare for tuning

### Requirements

```bash
pip install torchvision "flaml[blendsearch,ray]"
```
Before we start tuning, we first need to define the neural network that we would like to tune.

### Network Specification
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms


class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
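As a quick sanity check (not part of the original example), you can pass a dummy batch of CIFAR10-sized images through the network to confirm the layer dimensions line up:

```python
# A random batch of 4 RGB images of size 32x32, the CIFAR10 input shape.
net = Net()
dummy = torch.randn(4, 3, 32, 32)
out = net(dummy)
assert out.shape == (4, 10)  # one logit per CIFAR10 class
```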
### Data

```python
def load_data(data_dir="data"):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    return trainset, testset
```
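For reference, CIFAR10 ships with the standard 50,000/10,000 train/test split; a minimal check (not in the original example):

```python
trainset, testset = load_data()  # downloads to ./data on the first call
print(len(trainset), len(testset))  # 50000 10000
```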
### Training

```python
import logging
import os

from ray import tune

logger = logging.getLogger(__name__)


def train_cifar(config, checkpoint_dir=None, data_dir=None):
    # Warn if the config is missing expected keys.
    if "l1" not in config:
        logger.warning(config)
    net = Net(2 ** config["l1"], 2 ** config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint
    # should be restored.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(checkpoint)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainset, testset = load_data(data_dir)

    # Hold out 20% of the training set for validation.
    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]
    )

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(2 ** config["batch_size"]),
        shuffle=True,
        num_workers=4,
    )
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(2 ** config["batch_size"]),
        shuffle=True,
        num_workers=4,
    )

    for epoch in range(
        int(round(config["num_epochs"]))
    ):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be passed as the `checkpoint_dir`
        # parameter in future iterations.
        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")
```
### Test Accuracy

```python
def _test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total
```
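As a rough baseline (an illustration, not part of the original example), an untrained network should score near chance level on CIFAR10's ten classes:

```python
# Expect roughly 0.1 (random guessing) before any training.
print(_test_accuracy(Net()))
```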
## Hyperparameter Optimization

```python
import numpy as np
import flaml
import os

data_dir = os.path.abspath("data")
load_data(data_dir)  # Download data for all trials before starting the run
```
### Search space

```python
max_num_epoch = 100
config = {
    "l1": tune.randint(2, 9),  # log transformed with base 2
    "l2": tune.randint(2, 9),  # log transformed with base 2
    "lr": tune.loguniform(1e-4, 1e-1),
    "num_epochs": tune.loguniform(1, max_num_epoch),
    "batch_size": tune.randint(1, 5),  # log transformed with base 2
}
```
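To make the log transformation concrete, here is how `train_cifar` decodes a sampled configuration (the specific values below are illustrative):

```python
sample = {"l1": 8, "l2": 8, "lr": 1e-3, "num_epochs": 10.0, "batch_size": 3}
# train_cifar builds Net(2 ** 8, 2 ** 8), i.e. fully connected layers of
# 256 and 256 units, and uses a batch size of 2 ** 3 = 8 examples.
```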
### Budget and resource constraints

```python
time_budget_s = 600  # time budget in seconds
gpus_per_trial = 0.5  # number of gpus for each trial; 0.5 means two training jobs can share one gpu
num_samples = 500  # maximal number of trials
np.random.seed(7654321)
```
### Launch the tuning

```python
import time

start_time = time.time()
result = flaml.tune.run(
    tune.with_parameters(train_cifar, data_dir=data_dir),
    config=config,
    metric="loss",
    mode="min",
    low_cost_partial_config={"num_epochs": 1},
    max_resource=max_num_epoch,
    min_resource=1,
    scheduler="asha",  # use the asha scheduler to perform early stopping based on the intermediate results reported
    resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
    local_dir="logs/",
    num_samples=num_samples,
    time_budget_s=time_budget_s,
    use_ray=True,
)
```
### Check the result

```python
print(f"#trials={len(result.trials)}")
print(f"time={time.time() - start_time}")
best_trial = result.get_best_trial("loss", "min", "all")
print("Best trial config: {}".format(best_trial.config))
print(
    "Best trial final validation loss: {}".format(
        best_trial.metric_analysis["loss"]["min"]
    )
)
print(
    "Best trial final validation accuracy: {}".format(
        best_trial.metric_analysis["accuracy"]["max"]
    )
)

best_trained_model = Net(2 ** best_trial.config["l1"], 2 ** best_trial.config["l2"])
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
    if gpus_per_trial > 1:
        best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device)

checkpoint_value = (
    getattr(best_trial.checkpoint, "dir_or_data", None) or best_trial.checkpoint.value
)
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")

model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)

test_acc = _test_accuracy(best_trained_model, device)
print("Best trial test set accuracy: {}".format(test_acc))
```
### Sample of output

```
#trials=44
time=1193.913584947586
Best trial config: {'l1': 8, 'l2': 8, 'lr': 0.0008818671030627281, 'num_epochs': 55.9513429004283, 'batch_size': 3}
Best trial final validation loss: 1.0694482081472874
Best trial final validation accuracy: 0.6389
Files already downloaded and verified
Files already downloaded and verified
Best trial test set accuracy: 0.6294
```