[1]:
from random import Random
import torch
from torch import nn

Config Search Spaces#

As seen before, discrete search spaces in Archai are defined using the DiscreteSearchSpace abstract class. This tutorial shows how to use the Config Search Space API, which builds search spaces automatically, without requiring a subclass of DiscreteSearchSpace.

Let’s start with a simple PyTorch model:

[2]:
class MyConvBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, kernel_size=3):
        super().__init__()

        self.op = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding='same'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        return self.op(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.stem_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=4, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        self.layers = nn.Sequential(*[
            MyConvBlock(32, 32)
            for i in range(5)
        ])

    def forward(self, x):
        return self.layers(self.stem_conv(x))
[3]:
model = MyModel()
[4]:
x = torch.randn(2, 3, 64, 64)
model(x).shape
[4]:
torch.Size([2, 32, 16, 16])

Creating an ArchParamTree#

To turn this model into a search space, we first need to define an ArchParamTree with the architecture parameters we want to search over:

[5]:
from archai.discrete_search.search_spaces.config import ArchParamTree, ArchConfig, DiscreteChoice


arch_param_tree = {
    'conv_kernel_size': DiscreteChoice([3, 5, 7]),
    'num_ch': DiscreteChoice([8, 16, 32]),
    'num_layers': DiscreteChoice(range(1, 6))
}

arch_param_tree = ArchParamTree(arch_param_tree)

An ArchParamTree is used to generate ArchConfig objects, which specify the chosen architecture configuration. We can sample a configuration using arch_param_tree.sample_config():

[6]:
arch_config = arch_param_tree.sample_config()
arch_config
[6]:
ArchConfig({
    "conv_kernel_size": 7,
    "num_ch": 16,
    "num_layers": 4
})

ArchConfig objects behave like dictionaries. To get the value of an architecture parameter, just call arch_config.pick(parameter_name):

[7]:
arch_config.pick('conv_kernel_size')
[7]:
3
[8]:
arch_config.pick('num_ch')
[8]:
8
[9]:
arch_config.to_dict()
[9]:
OrderedDict([('conv_kernel_size', 3), ('num_ch', 8), ('num_layers', 4)])
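
sample_config also accepts a random number generator for reproducible sampling (we will use this later in the tutorial). A quick sketch, assuming only that same-seeded generators yield the same draws:

from random import Random

# Same seed => same sampled configuration
c1 = arch_param_tree.sample_config(rng=Random(42))
c2 = arch_param_tree.sample_config(rng=Random(42))
assert c1.to_dict() == c2.to_dict()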

Let’s use this in our PyTorch model definition:

[10]:
class MyModel(nn.Module):

    # **We add arch_config as the first parameter of the module**
    def __init__(self, arch_config: ArchConfig):
        super().__init__()

        # **We call arch_config.pick('num_ch')**
        num_ch = arch_config.pick('num_ch')

        self.stem_conv = nn.Sequential(
            nn.Conv2d(3, num_ch, kernel_size=3, stride=4, padding=1),
            nn.BatchNorm2d(num_ch),
            nn.ReLU()
        )

        self.layers = nn.Sequential(*[
            # **We pick the kernel size and number of layers**
            MyConvBlock(num_ch, num_ch, kernel_size=arch_config.pick('conv_kernel_size'))
            for i in range(arch_config.pick('num_layers'))
        ])

    def forward(self, x):
        return self.layers(self.stem_conv(x))

[11]:
model = MyModel(arch_config)
model
[11]:
MyModel(
  (stem_conv): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layers): Sequential(
    (0): MyConvBlock(
      (op): Sequential(
        (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): MyConvBlock(
      (op): Sequential(
        (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (2): MyConvBlock(
      (op): Sequential(
        (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (3): MyConvBlock(
      (op): Sequential(
        (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
  )
)

To get an Archai DiscreteSearchSpace, we just pass MyModel and arch_param_tree to ConfigSearchSpace:

[12]:
from archai.discrete_search.search_spaces.config import ConfigSearchSpace

search_space = ConfigSearchSpace(MyModel, arch_param_tree, mutation_prob=0.3)

All the methods from DiscreteSearchSpace, EvolutionarySearchSpace and BayesOptSearchSpace are automatically implemented.

[13]:
# Randomly samples a model
m = search_space.random_sample()
print(m.archid)

# Mutates a model
m2 = search_space.mutate(m)
print(m2.archid)

# Crossover
m3 = search_space.crossover([search_space.random_sample(), search_space.random_sample()])
print(m3.archid)

# Encode
print(search_space.encode(m3))
307525215b21f510fb6ba1570c71126274e60167
307525215b21f510fb6ba1570c71126274e60167
cc3aba2e903b62619035a871ff3bcdc65dc151de
[3. 8. 1.]

Saving and loading#

[14]:
search_space.save_arch(m3, 'arch.json')
m = search_space.load_arch('arch.json')
[15]:
!cat arch.json
{
    "conv_kernel_size": 3,
    "num_ch": 8,
    "num_layers": 1
}

We can now use this with any Archai search algorithm and objective!
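
For instance, a minimal search loop might look like the sketch below. It assumes the EvolutionParetoSearch algorithm, the SearchObjectives container, and the TorchNumParameters evaluator from archai.discrete_search; check your Archai version for the exact imports and signatures:

from archai.discrete_search.api import SearchObjectives
from archai.discrete_search.algos import EvolutionParetoSearch
from archai.discrete_search.evaluators import TorchNumParameters

# Cheap proxy objective: minimize the number of model parameters
objectives = SearchObjectives()
objectives.add_objective('num_params', TorchNumParameters(), higher_is_better=False)

algo = EvolutionParetoSearch(search_space, objectives, output_dir='./search_out', num_iters=5)
results = algo.search()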

More features of ArchParamTrees#

Nesting dictionaries inside an ArchParamTree#

[16]:
arch_param_tree = {
    # Stem convolution architecture
    'stem_config': {
        'kernel_size': DiscreteChoice([3, 5, 7])
    },

    'conv_kernel_size': DiscreteChoice([3, 5, 7]),
    'num_ch': DiscreteChoice([8, 16, 32])
}

arch_param_tree = ArchParamTree(arch_param_tree)
[17]:
c = arch_param_tree.sample_config()
c
[17]:
ArchConfig({
    "stem_config": {
        "kernel_size": 3
    },
    "conv_kernel_size": 3,
    "num_ch": 8
})

Calling c.pick on a parameter that contains a dictionary returns a new ArchConfig object for that dictionary:

[18]:
c.pick('stem_config')
[18]:
ArchConfig({
    "kernel_size": 3
})
[19]:
c.pick('stem_config').pick('kernel_size')
[19]:
3
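
This makes it natural to hand each submodule its own sub-config. A minimal sketch (MyStem is a hypothetical module written for illustration, not part of Archai):

class MyStem(nn.Module):
    def __init__(self, stem_config: ArchConfig):
        super().__init__()

        # The submodule only sees its own slice of the parameter tree
        self.op = nn.Conv2d(3, 32, kernel_size=stem_config.pick('kernel_size'), padding=1)

    def forward(self, x):
        return self.op(x)

stem = MyStem(c.pick('stem_config'))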

Sharing architecture parameters#

We can share the configuration of different parts of the architecture by re-using references:

[20]:
kernel_size_choice = DiscreteChoice([3, 5, 7])

arch_param_tree = {
    'stem_config': {
        'kernel_size': kernel_size_choice
    },

    'conv_kernel_size': kernel_size_choice,
    'num_ch': DiscreteChoice([8, 16, 32])
}

arch_param_tree = ArchParamTree(arch_param_tree)

conv_kernel_size is now always equal to stem_config.kernel_size

[21]:
arch_param_tree.sample_config()
[21]:
ArchConfig({
    "stem_config": {
        "kernel_size": 3
    },
    "conv_kernel_size": 5,
    "num_ch": 32
})
[22]:
arch_param_tree.sample_config()
[22]:
ArchConfig({
    "stem_config": {
        "kernel_size": 3
    },
    "conv_kernel_size": 5,
    "num_ch": 32
})

Re-using references to entire dictionaries also works:

[23]:
stem_config = {
    'kernel_size': DiscreteChoice([3, 5, 7]),
    'stride': DiscreteChoice([2, 4])
}

arch_param_tree = {
    'block1': stem_config,
    'block2': stem_config,
    'block3': stem_config
}

arch_param_tree = ArchParamTree(arch_param_tree)
[24]:
arch_param_tree.sample_config()
[24]:
ArchConfig({
    "block1": {
        "kernel_size": 7,
        "stride": 2
    },
    "block2": {
        "kernel_size": 7,
        "stride": 2
    },
    "block3": {
        "kernel_size": 7,
        "stride": 2
    }
})

Repeating configs a variable number of times#

We can repeat a block of arch parameters using the repeat_config function

[25]:
from archai.discrete_search.search_spaces.config import repeat_config

arch_param_tree = ArchParamTree({
    'layers': repeat_config({
        'kernel_size': DiscreteChoice([1, 3, 5, 7]),
        'residual': DiscreteChoice([False, True]),
        'act_fn': DiscreteChoice(['relu', 'gelu'])
    }, repeat_times=[0, 1, 2], share_arch=False)
})

ArchParamTree will stack 0, 1 or 2 configs inside layers in an ArchConfigList object:

[26]:
c = arch_param_tree.sample_config(rng=Random(1))
c
[26]:
ArchConfig({
    "layers": []
})
[27]:
print(len(c.pick('layers')))
0
[28]:
c = arch_param_tree.sample_config(rng=Random(2))
c.pick('layers')
[28]:
ArchConfigList([
    {
        "kernel_size": 7,
        "residual": false,
        "act_fn": "relu"
    },
    {
        "kernel_size": 7,
        "residual": true,
        "act_fn": "gelu"
    }
])
[29]:
print(len(c.pick('layers')))
2

We can select a config from an ArchConfigList by indexing the layer we want:

[30]:
# Picks the config of the second layer
print(c.pick('layers')[1])

# Picks the kernel size of the second layer
kernel_size = c.pick('layers')[1].pick('kernel_size')
print(f'kernel_size = {kernel_size}')
ArchConfig({
    "kernel_size": 7,
    "residual": true,
    "act_fn": "gelu"
})
kernel_size = 7

We can also iterate over an ArchConfigList object:

[31]:
config = arch_param_tree.sample_config(rng=Random(5))

modules = [
    nn.Conv2d(16, 16, kernel_size=layer_conf.pick('kernel_size'))
    for layer_conf in config.pick('layers')
]
[32]:
modules
[32]:
[Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))]

We can make the architecture parameters the same for each layer by setting share_arch=True:

[33]:
arch_param_tree = ArchParamTree({
    'layers': repeat_config({
        'kernel_size': DiscreteChoice([1, 3, 5, 7]),
        'residual': DiscreteChoice([False, True]),
        'act_fn': DiscreteChoice(['relu', 'gelu'])
    }, repeat_times=[2, 3], share_arch=True)
})

arch_param_tree.sample_config()
[33]:
ArchConfig({
    "layers": [
        {
            "kernel_size": 5,
            "residual": true,
            "act_fn": "gelu"
        },
        {
            "kernel_size": 5,
            "residual": true,
            "act_fn": "gelu"
        },
        {
            "kernel_size": 5,
            "residual": true,
            "act_fn": "gelu"
        }
    ]
})

Example: Building an Image Classification Search Space#

Let’s use the features described above to build the following search space for image classification

[Figure: diagram of the image classification search space]

We can build this succinctly using the repeat_config function

[34]:
arch_param_tree = ArchParamTree({
    'base_num_channels': DiscreteChoice([8, 16, 32, 64]),

    'downsample_blocks': repeat_config({
        'max_pool_kernel_size': DiscreteChoice([2, 3]),

        'channel_multiplier': DiscreteChoice([1.0, 1.2, 1.4, 1.6, 1.8, 2.0]),

        'convs': repeat_config({
            'kernel_size': DiscreteChoice([3, 5, 7]),
            'act_fn': DiscreteChoice(['relu', 'gelu']),
        }, repeat_times=[1, 2, 3, 4, 5], share_arch=False)
    }, repeat_times=[1, 2, 3], share_arch=False)
})

# We may want to reduce the search space size by sharing some of the architecture params
# using share_arch=True.
[35]:
class MyConvBlock(nn.Module):
    def __init__(self, arch_config: ArchConfig, in_ch: int, out_ch: int):
        super().__init__()

        self.op = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=arch_config.pick('kernel_size'),
                      padding='same'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU() if arch_config.pick('act_fn') == 'relu' else nn.GELU()
        )

    def forward(self, x):
        return self.op(x)


class MyModel(nn.Module):
    def __init__(self, arch_config: ArchConfig, stem_stride: int = 2):
        super().__init__()

        self.base_ch = arch_config.pick('base_num_channels')

        self.stem_conv = nn.Sequential(
            nn.Conv2d(3, self.base_ch, kernel_size=3, stride=stem_stride, padding=1),
            nn.BatchNorm2d(self.base_ch),
            nn.ReLU()
        )

        self.layers = []
        current_ch = self.base_ch

        for block_cfg in arch_config.pick('downsample_blocks'):
            next_ch = int(block_cfg.pick('channel_multiplier') * current_ch)

            for i, conv_cfg in enumerate(block_cfg.pick('convs')):
                self.layers.append(
                    MyConvBlock(
                        conv_cfg,
                        in_ch=(current_ch if i == 0 else next_ch),
                        out_ch=next_ch
                    )
                )

            self.layers.append(
                nn.MaxPool2d(kernel_size=block_cfg.pick('max_pool_kernel_size'))
            )

            current_ch = next_ch

        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.layers(self.stem_conv(x))
[36]:
config = arch_param_tree.sample_config()
[37]:
model = MyModel(config, stem_stride=2)
model(torch.randn(10, 3, 240, 240)).shape
[37]:
torch.Size([10, 128, 60, 60])

We can check the search space size using the arch_param_tree.num_archs property:

[38]:
arch_param_tree.num_archs
[38]:
1.218719480020992e+18
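
Following the comment in the search space definition above, we can shrink this count by sharing parameters. A sketch re-using the same tree, but with share_arch=True on the convs block:

shared_param_tree = ArchParamTree({
    'base_num_channels': DiscreteChoice([8, 16, 32, 64]),

    'downsample_blocks': repeat_config({
        'max_pool_kernel_size': DiscreteChoice([2, 3]),
        'channel_multiplier': DiscreteChoice([1.0, 1.2, 1.4, 1.6, 1.8, 2.0]),

        # Conv parameters are now shared across repetitions within a block
        'convs': repeat_config({
            'kernel_size': DiscreteChoice([3, 5, 7]),
            'act_fn': DiscreteChoice(['relu', 'gelu']),
        }, repeat_times=[1, 2, 3, 4, 5], share_arch=True)
    }, repeat_times=[1, 2, 3], share_arch=False)
})

# Should be far smaller than the unshared version above
print(shared_param_tree.num_archs)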

Now let’s turn MyModel into a search space object that can be used in Archai

[39]:
ss = ConfigSearchSpace(
    MyModel, arch_param_tree,
    model_kwargs={"stem_stride": 2} # additional kwargs will be passed to MyModel.__init__()
)
[40]:
m = ss.random_sample()
m2 = ss.mutate(m)

# now we can use this search space with any Archai search algorithm
[41]:
print(m2.archid)
d56a2b2d01f75d3f21824f89e5761b4608e6f18e

Tracking used architecture parameters for model de-duplication#

Consider the following example:

[42]:
arch_param_tree = ArchParamTree({
    'op_type': DiscreteChoice(['identity', 'conv']),
    'conv_kernel_size': DiscreteChoice([1, 3, 5, 7])
})
[43]:
class MyOperation(nn.Module):
    def __init__(self, arch_config: ArchConfig, in_ch):
        super().__init__()

        self.op_type = arch_config.pick('op_type')

        if arch_config.pick('op_type') == 'conv':
            self.op = nn.Sequential(
                nn.Conv2d(
                    in_ch, in_ch,
                    kernel_size=arch_config.pick('conv_kernel_size'),
                    padding='same',
                ),
                nn.BatchNorm2d(in_ch),
                nn.ReLU(),
            )

    def forward(self, x):
        if self.op_type == 'identity':
            return x

        return self.op(x)

Notice that when op_type="identity", the value of conv_kernel_size is not used at all.

This means our search space might not know that the configurations ("identity", 3) and ("identity", 7) encode exactly the same architecture! That can become a serious problem, given that each architecture evaluation can be expensive.

To avoid that, each ArchConfig object automatically tracks which architecture parameters have been used via the .pick method.

For instance:

[44]:
c = arch_param_tree.sample_config()
c
[44]:
ArchConfig({
    "op_type": "identity",
    "conv_kernel_size": 3
})

ArchConfig.get_used_params() returns the usage dictionary of this ArchConfig object.

[45]:
c.get_used_params()
[45]:
OrderedDict([('op_type', False), ('conv_kernel_size', False)])

Let’s pick a parameter now

[46]:
c.pick('op_type')
[46]:
'identity'
[47]:
c.get_used_params()
[47]:
OrderedDict([('op_type', True), ('conv_kernel_size', False)])

This is automatically handled by the ConfigSearchSpace object when generating architecture ids, which allows deduplicating architectures:

[48]:
ss = ConfigSearchSpace(
    MyOperation, arch_param_tree, model_kwargs={"in_ch": 16}, seed=8
)

Unused architecture parameters are encoded using the value of the unused_param_value parameter (-1, in our case):

[49]:
m1 = ss.random_sample()
print(f'm1 config = {m1.metadata["config"]}')
print(f'm1 archid = {m1.archid}')
m1 config = ArchConfig({
    "op_type": "identity",
    "conv_kernel_size": 7
})
m1 archid = 260c332c6fc8c6c976736a379f3ae1ac439afd74
[50]:
m2 = ss.random_sample()
print(f'm2 config = {m2.metadata["config"]}')
print(f'm2 archid = {m2.archid}')
m2 config = ArchConfig({
    "op_type": "identity",
    "conv_kernel_size": 5
})
m2 archid = 260c332c6fc8c6c976736a379f3ae1ac439afd74

Notice how m1 and m2 have different values for conv_kernel_size, but since op_type='identity', both are mapped to the same architecture id.

To turn this feature off, you can either

  • Selectively call config.pick(param_name, record_usage=False), as sketched after this list

  • or set ConfigSearchSpace(..., track_unused_params=False)
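
For instance, the first option looks like this (a quick sketch re-using arch_param_tree from above):

c = arch_param_tree.sample_config()

# Peek at the value without marking the parameter as used
c.pick('conv_kernel_size', record_usage=False)
print(c.get_used_params())  # conv_kernel_size stays False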

This feature is also automatically used when generating architecture encodings for surrogate models, to make sure equivalent architectures are correctly mapped to the same representation:

[51]:
ss.encode(m1)
[51]:
array([ 1.,  0., -1.])
[52]:
ss.encode(m2)
[52]:
array([ 1.,  0., -1.])