ReinMax documentation

ReinMax achieves second-order accuracy while remaining as fast as the original Straight-Through estimator, which only has first-order accuracy.

We reveal that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. Inspired by Heun's method, a numerical method achieving second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which approximates gradients with second-order accuracy at negligible computational overhead. For more details, please check our paper: https://arxiv.org/abs/2304.08612
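
To make the first-order versus second-order intuition concrete, below is a minimal sketch in plain Python (not part of the reinmax API) comparing forward Euler with Heun's method on a toy ODE; doubling the number of steps roughly halves the Euler error but shrinks the Heun error by about a factor of four.

```python
# Toy illustration: integrate dy/dt = y, y(0) = 1, whose exact value at t = 1 is e.
import math

def forward_euler(f, y0, t0, t1, steps):
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        y = y + h * f(t, y)              # single slope: first-order accurate
        t += h
    return y

def heun(f, y0, t0, t1, steps):
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        k1 = f(t, y)                     # slope at the current point
        k2 = f(t + h, y + h * k1)        # slope at the Euler-predicted point
        y = y + 0.5 * h * (k1 + k2)      # averaged slope: second-order accurate
        t += h
    return y

f = lambda t, y: y
exact = math.e
for steps in (10, 20):
    print(steps,
          abs(forward_euler(f, 1.0, 0.0, 1.0, steps) - exact),  # shrinks ~2x as steps double
          abs(heun(f, 1.0, 0.0, 1.0, steps) - exact))           # shrinks ~4x as steps double
```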

reinmax.reinmax(logits: torch.Tensor, tau: float)
Parameters
  • logits (torch.Tensor, required) – The input Tensor for the softmax. Note that the softmax operation is conducted along the last dimension.

  • tau (float, required) – The temperature hyper-parameter. Note that reinmax prefers tau >= 1, while gumbel-softmax prefers tau < 1. For more details, please refer to our paper.

Returns

  • y_hard (torch.Tensor) – The one-hot sample generated from multinomial(softmax(logits)).

  • y_soft (torch.Tensor) – The output of the softmax function, i.e., softmax(logits).
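
Putting the signature and return values together, below is a minimal usage sketch. It assumes the package is installed and importable as reinmax; the batch size, number of categories, and the toy loss are illustrative only.

```python
import torch
import reinmax  # assumes the reinmax package described above is installed

logits = torch.randn(4, 10, requires_grad=True)  # (batch, num_categories)
y_hard, y_soft = reinmax.reinmax(logits, 1.0)    # tau = 1.0

print(y_hard.shape, y_soft.shape)  # both (4, 10); each row of y_hard is one-hot
print(y_hard.sum(dim=-1))          # tensor of ones

# y_hard equals the one-hot sample in the forward pass but carries the
# ReinMax gradient estimate in the backward pass, so a downstream loss
# backpropagates to logits.
loss = (y_hard * torch.arange(10.0)).sum()  # illustrative toy loss
loss.backward()
print(logits.grad.shape)           # torch.Size([4, 10])
```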

Example

Below is an example of replacing Straight-Through Gumbel-Softmax with ReinMax:

```python
# before: Straight-Through Gumbel-Softmax
y_hard = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
# after: ReinMax
y_hard, _ = reinmax.reinmax(logits, tau)
```
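
Note that reinmax.reinmax draws its one-hot sample from multinomial(softmax(logits)) (see Returns above) and additionally returns the soft probabilities softmax(logits), which the snippet discards with the underscore.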

Below is an example of replacing Straight-Through with ReinMax:

```python
# before: Straight-Through, where one_hot_multinomial stands for drawing a
# one-hot sample from the categorical distribution softmax(logits)
y_hard = one_hot_multinomial(logits.softmax(dim=-1))
y_soft_tau = (logits / tau).softmax(dim=-1)
y_hard = y_soft_tau - y_soft_tau.detach() + y_hard
# after: ReinMax
y_hard, y_soft = reinmax.reinmax(logits, tau)
```
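
In the Straight-Through snippet, y_soft_tau - y_soft_tau.detach() + y_hard equals the one-hot sample in the forward pass while letting the backward pass use the gradient of the tempered softmax. The single reinmax.reinmax call keeps this drop-in, one-line interface but replaces the first-order Straight-Through gradient with ReinMax's second-order-accurate estimate.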