# Source code for causica.distributions.noise.spline.rational_quadratic_transform
"""
Implementation of the Piecewise Rational Quadratic Transform
This is pretty much a copy-paste of
https://github.com/tonyduan/normalizing-flows/blob/master/nf/utils.py
We should consider using the Pyro implementation.
"""
import torch
import torch.distributions as td
from causica.distributions.noise.spline.bayesiains_nsf_rqs import unconstrained_rational_quadratic_spline
class PiecewiseRationalQuadraticTransform(td.Transform):
    """
    Layer that implements a spline-cdf (https://arxiv.org/abs/1906.04032) transformation.

    All dimensions of x are treated as independent, no coupling is used. This is needed
    to ensure invertibility in our additive noise SEM.

    This is pretty much a copy-paste of
    https://github.com/tonyduan/normalizing-flows/blob/master/nf/utils.py
    """

    def __init__(self, knot_locations: torch.Tensor, derivatives: torch.Tensor, tail_bound: float = 3.0):
        """
        Args:
            knot_locations: the x, y points of the knots, shape [dim, num_bins, 2]
            derivatives: the derivatives at the knots, shape [dim, num_bins - 1]
            tail_bound: distance of edgemost bins relative to 0
        """
        super().__init__()
        assert knot_locations.ndim == 3, "Only two dimensional params are supported"
        assert knot_locations.shape[-1] == 2, "Knot locations are 2-d"
        self.dim, self.num_bins, *_ = knot_locations.shape
        assert derivatives.shape == (self.dim, self.num_bins - 1)
        self.tail_bound = tail_bound
        # Each component is transformed independently, so the event dim is 0.
        self._event_dim = 0
        self.knot_locations = knot_locations
        self.derivatives = derivatives

    # NOTE: this is a regular method, not a property — it takes `inputs` and is
    # called as self._piecewise_cdf(x, inverse=...) by _call/_inverse below.
    def _piecewise_cdf(self, inputs: torch.Tensor, inverse: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate the Cumulative Density function at `inputs`.

        Args:
            inputs: the positions at which to evaluate the cdf, shape batch_shape + (input_dim)
            inverse: whether this is forwards or backwards transform

        Returns:
            input_evaluations and absolute log determinants, a tuple of tensors of shape batch_shape + (input_dim)
        """
        assert len(inputs.shape) > 1  # TODO(JJ) accept 1d inputs
        batch_shape = inputs.shape[:-1]
        # Broadcast the per-dimension spline parameters over the batch dimensions:
        # knots expand to batch_shape + (dim, num_bins, 2),
        # derivatives expand to batch_shape + (dim, num_bins - 1).
        expanded_knot_locations = self.knot_locations[None, ...].expand(*batch_shape, -1, -1, -1)
        expanded_derivatives = self.derivatives[None, ...].expand(*batch_shape, -1, -1)
        return unconstrained_rational_quadratic_spline(
            inputs=inputs,
            unnormalized_widths=expanded_knot_locations[..., 0],
            unnormalized_heights=expanded_knot_locations[..., 1],
            unnormalized_derivatives=expanded_derivatives,
            inverse=inverse,
            tail_bound=self.tail_bound,
        )

    @td.constraints.dependent_property(is_discrete=False)
    def domain(self):
        """Constraint on the transform's inputs: unconstrained reals (per event dim)."""
        if self.event_dim == 0:
            return td.constraints.real
        return td.constraints.independent(td.constraints.real, self.event_dim)

    @td.constraints.dependent_property(is_discrete=False)
    def codomain(self):
        """Constraint on the transform's outputs: unconstrained reals (per event dim)."""
        if self.event_dim == 0:
            return td.constraints.real
        return td.constraints.independent(td.constraints.real, self.event_dim)

    def _call(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward spline transform.

        Args:
            x: batch_shape + (input_dim)

        Returns:
            transformed_input, batch_shape + (input_dim)
        """
        return self._piecewise_cdf(x, inverse=False)[0]

    def _inverse(self, y: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse spline transform.

        Args:
            y: batch_shape + (input_dim)

        Returns:
            transformed_input, batch_shape + (input_dim)
        """
        return self._piecewise_cdf(y, inverse=True)[0]

    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Log absolute determinant of the Jacobian of the forward transform at `x`.

        `y` is unused (accepted only to satisfy the td.Transform interface).

        Args:
            x: batch_shape + (input_dim)
            y: batch_shape + (input_dim), ignored

        Returns:
            jacobian_log_determinant: batch_shape + (input_dim)
        """
        return self._piecewise_cdf(x, inverse=False)[1]