Source code for archai.supergraph.algos.divnas.wmr

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np


class Wmr:
    """Randomized Weighted Majority algorithm by Littlestone and Warmuth.

    Uses the version in Fig. 1 of "The Multiplicative Weight Update", in its
    gain form: after each round every weight is multiplied by
    ``1 + eta * reward`` and the weight vector is renormalized to sum to 1.

    Args:
        num_items: Number of items to keep weights over; must be positive.
        eta: Learning rate; must lie in ``[0.0, 0.5]``.

    """

    def __init__(self, num_items: int, eta: float):
        assert num_items > 0
        assert 0.0 <= eta <= 0.5

        self._num_items = num_items
        self._eta = eta
        # Start from the uniform distribution over all items.
        self._weights = self._normalize(np.ones(self._num_items))
        # Only needed by the (currently disabled) annealed learning rate,
        # see the note in `update`.
        self._round_counter = 0

    @property
    def weights(self) -> np.ndarray:
        """Current normalized weights, one entry per item."""
        return self._weights

    def _normalize(self, weights: np.ndarray) -> np.ndarray:
        """Return `weights` divided by their sum."""
        return weights / np.sum(weights)

    def update(self, rewards: np.ndarray) -> None:
        """Apply one multiplicative-weights step using the given rewards.

        Args:
            rewards: 1-D array with one reward per item, each in ``[-1, 1]``.

        Raises:
            AssertionError: If `rewards` is not 1-D, does not have
                ``num_items`` entries, or contains a value outside
                ``[-1, 1]`` (NaN included).

        """
        rewards = np.asarray(rewards)

        assert len(rewards.shape) == 1
        assert rewards.shape[0] == self._num_items
        # The comparison must happen element-wise *inside* np.all();
        # comparing the boolean result of np.all(rewards) against the bounds
        # is always true and validates nothing.
        assert np.all(rewards >= -1.0) and np.all(rewards <= 1.0)

        # NOTE: an annealed rate (self._eta / sqrt(round)) was considered
        # here but is disabled; the constant rate is used instead.
        eta = self._eta

        # With eta <= 0.5 and rewards >= -1 every factor is >= 0.5, so the
        # weights stay strictly positive and can be renormalized safely.
        self._weights = self._normalize(self._weights * (1.0 + eta * rewards))

    def sample(self) -> int:
        """Draw an item index at random with probability equal to its weight.

        Returns:
            Index in ``[0, num_items)``.

        """
        # Renormalize once more so the probabilities passed to
        # np.random.choice sum to 1 despite floating-point drift.
        probs = self._normalize(self._weights)
        return int(np.random.choice(self._num_items, p=probs))