# Source code for archai.supergraph.models.shakeshake.shakeshake

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class ShakeShake(torch.autograd.Function):
    """Randomly blend two tensors, with independent noise in forward and backward.

    Intended to be invoked positionally through
    ``ShakeShake.apply(x1, x2, training)``.

    The per-sample coefficients are reshaped to ``(batch, 1, 1, 1)`` before
    being expanded, so the inputs are expected to be 4-D tensors of identical
    shape whose first dimension is the batch.
    """

    @staticmethod
    def forward(ctx, x1, x2, training=True):
        """Return ``alpha * x1 + (1 - alpha) * x2``.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            x1: First branch output.
            x2: Second branch output, same shape as ``x1``.
            training: When true, ``alpha`` is drawn uniformly from ``[0, 1)``
                once per sample; otherwise ``alpha`` is the constant ``0.5``.

        Returns:
            The blended tensor, same shape as ``x1``.
        """
        if training:
            # Sample on the input's own device rather than hard-coding CUDA,
            # so the function also works on CPU-only machines.
            alpha = torch.rand(x1.size(0), dtype=torch.float32, device=x1.device)
            alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1)
        else:
            alpha = 0.5
        return alpha * x1 + (1 - alpha) * x2

    @staticmethod
    def backward(ctx, grad_output):
        """Split the incoming gradient between the two branches.

        A fresh per-sample ``beta`` is drawn uniformly from ``[0, 1)``; it is
        independent of the ``alpha`` used in ``forward`` and is sampled
        regardless of the ``training`` flag passed to ``forward``.

        Args:
            ctx: Autograd context (unused).
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            A tuple ``(beta * grad_output, (1 - beta) * grad_output, None)``;
            the trailing ``None`` is the gradient for the non-tensor
            ``training`` argument.
        """
        beta = torch.rand(
            grad_output.size(0), dtype=torch.float32, device=grad_output.device
        )
        beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output)
        return beta * grad_output, (1 - beta) * grad_output, None
class Shortcut(nn.Module):
    """Two-path 1x1-convolution shortcut followed by batch normalisation.

    Each path yields ``out_ch // 2`` channels; the two results are
    concatenated along the channel axis and passed through
    ``BatchNorm2d(out_ch)``, so ``out_ch`` needs to be even for the channel
    counts to agree.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels after concatenation.
        stride: Spatial stride applied by the kernel-size-1 average pooling
            on both paths.
    """

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.stride = stride
        half_ch = out_ch // 2
        self.conv1 = nn.Conv2d(
            in_ch, half_ch, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.conv2 = nn.Conv2d(
            in_ch, half_ch, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        """Apply ReLU, run both subsampled paths, concatenate, and normalise."""
        activated = F.relu(x)

        # Path 1: subsample the activation as-is (kernel size 1, so no averaging).
        first = F.avg_pool2d(activated, kernel_size=1, stride=self.stride)
        first = self.conv1(first)

        # Path 2: negative left/top padding crops one pixel there, while the
        # positive right/bottom padding appends one, shifting the map by one
        # pixel before the same subsampling.
        shifted = F.pad(activated, (-1, 1, -1, 1))
        second = F.avg_pool2d(shifted, kernel_size=1, stride=self.stride)
        second = self.conv2(second)

        merged = torch.cat((first, second), dim=1)
        return self.bn(merged)