ReinMax documentation
ReinMax achieves second-order accuracy and is as fast as the original Straight-Through, which has first-order accuracy.
We reveal that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which approximates gradients with second-order accuracy at negligible computation overhead. For more details, please check our paper: https://arxiv.org/abs/2304.08612
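To illustrate the numerical-method analogy (this is not the ReinMax estimator itself), the sketch below compares the forward Euler method with Heun's method on a simple ODE. Heun's two-stage slope average is what yields second-order accuracy without any second derivatives; the ODE, step size, and step count here are illustrative assumptions.

```python
import math

def forward_euler_step(f, y, h):
    # One forward Euler step: first-order accurate (local error O(h^2)).
    return y + h * f(y)

def heun_step(f, y, h):
    # Heun's method: average the slope at the current point and at a trial
    # Euler point. Second-order accurate (local error O(h^3)), yet it never
    # needs second derivatives of f.
    y_trial = y + h * f(y)
    return y + 0.5 * h * (f(y) + f(y_trial))

# Illustrative ODE: dy/dt = y with y(0) = 1, whose exact solution is y(t) = e^t.
f = lambda y: y
h, y_euler, y_heun = 0.1, 1.0, 1.0
for _ in range(10):
    y_euler = forward_euler_step(f, y_euler, h)
    y_heun = heun_step(f, y_heun, h)

# Heun's error at t = 1 is far smaller than forward Euler's, at roughly 2x the cost.
print(abs(y_euler - math.e), abs(y_heun - math.e))
```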
reinmax.reinmax(logits: torch.Tensor, tau: float)
- Parameters
  - logits (torch.Tensor, required) – The input tensor for the softmax. Note that the softmax operation is conducted along the last dimension.
  - tau (float, required) – The temperature hyper-parameter. Note that ReinMax prefers tau >= 1, while Gumbel-Softmax prefers tau < 1. For more details, please refer to our paper.
- Returns
  - y_hard (torch.Tensor) – The one-hot sample generated from multinomial(softmax(logits)).
  - y_soft (torch.Tensor) – The output of the softmax function, i.e., softmax(logits).
Example
Below is an example replacing Straight-Through Gumbel-Softmax with ReinMax
```python
# Before: Straight-Through Gumbel-Softmax
# y_hard = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
# After: ReinMax
y_hard, _ = reinmax.reinmax(logits, tau)
```

Below is an example replacing Straight-Through with ReinMax
```python
# Before: Straight-Through
# y_hard = one_hot_multinomial(logits.softmax())
# y_soft_tau = (logits / tau).softmax()
# y_hard = y_soft_tau - y_soft_tau.detach() + y_hard
# After: ReinMax
y_hard, y_soft = reinmax.reinmax(logits, tau)
```
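For completeness, below is a minimal end-to-end sketch showing how the returned y_hard carries gradients back to the logits during training. The toy encoder, vocabulary size, and downstream loss are illustrative assumptions; only reinmax.reinmax(logits, tau) and its two return values come from the API above.

```python
import torch
import reinmax

torch.manual_seed(0)

# Illustrative setup: a tiny linear layer producing logits over 8 categories.
num_categories = 8
encoder = torch.nn.Linear(16, num_categories)
x = torch.randn(4, 16)

logits = encoder(x)                                 # shape: (4, 8)
y_hard, y_soft = reinmax.reinmax(logits, tau=1.0)   # one-hot samples + softmax probs

# y_hard is exactly one-hot in the forward pass but differentiable in the
# backward pass, so a downstream loss can still update the encoder.
decoder = torch.nn.Linear(num_categories, 16, bias=False)
loss = decoder(y_hard).pow(2).mean()                # illustrative downstream loss
loss.backward()

print(y_hard.sum(dim=-1))                           # each row sums to 1 (one-hot)
print(encoder.weight.grad is not None)              # True: gradients flow through y_hard
```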