Module flow.transformer
Implementations for Flow-transformers.
Of particular interest are:
* `Affine`: affine transformation.
* `DSF`: Deep Sigmoidal Flow.
* `NonAffine`: non-affine transformation.
"""
Implementations for Flow-transformers.
Of particular interest are:
* `Affine`: affine transformation.
* `DSF`: Deep Sigmoidal Flow.
* `NonAffine`: non-affine transformation.
"""
from functools import partial
import torch
from torch import nn, optim
import torch.nn.functional as F
from .flow import Transformer
from .modules import LogSigmoid, LeakyReLU, softplus_inv
from .utils import *
class Affine(Transformer):
"""Affine Transformer.
"""
def __init__(self, eps=1e-6, **kwargs):
"""
Args:
eps (float): lower-bound for scale parameter.
"""
_h_dim = 2
h_dim = kwargs.pop('h_dim', _h_dim)
assert h_dim == _h_dim, f'Received h_dim={h_dim} but expected {_h_dim}'
super().__init__(h_dim=h_dim, **kwargs)
self.eps = eps
def _log_det(self, scale):
return torch.log(scale).sum(dim=1)
def _activation(self, h):
"""Returns (loc, scale) parameters."""
assert not h.size(1) % self.h_dim
loc, scale = h[:, ::2], h[:, 1::2]
scale = F.softplus(scale) + self.eps
return loc, scale
def _transform(self, x, *h, log_det=False, **kwargs):
loc, scale = h
u = x * scale + loc
if log_det:
return u, self._log_det(scale)
else:
return u
def _invert(self, u, *h, log_det=False, **kwargs):
loc, scale = h
x = (u - loc) / scale
if log_det:
return x, -self._log_det(scale)
else:
return x
def _h_init(self):
h_init = torch.zeros(self.dim * self.h_dim, device=self.device)
h_init[1::2] = softplus_inv(torch.tensor(1. - self.eps)).item()
return h_init
class _IncreasingMonotonicTransformer(Transformer):
"""Abstract Transformer that inverts using Bijection Search,
specific for increasing monotonic transformers.
Note that using this method, inversion will not be differentiable.
Uses `flow.utils.monotonic_increasing_bijective_search`.
"""
def __init__(self, inv_eps=1e-3, inv_steps=1000, **kwargs):
"""
Args:
inv_eps (float): minimum difference between f(x) and u
allowed to stop the inversion.
inv_steps (int): maximum number of iterations
before halting execution. Defaults to 1000.
"""
super().__init__(**kwargs)
self.inv_eps = inv_eps
self.inv_steps = inv_steps
def _invert(self, u, *h, log_det=False, **kwargs):
x = monotonic_increasing_bijective_search(
# use _transform, but without log_det
self._transform, u, *h, **kwargs,
eps=self.inv_eps, max_steps=self.inv_steps
)
if log_det:
_, log_det = self._transform(u, *h, log_det=True, **kwargs)
return x, -log_det
else:
return x
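# Illustrative aside (not part of the original module): since the transform is
# increasing in x, the search that `_IncreasingMonotonicTransformer` relies on
# can be pictured as an element-wise bisection:
#
#   def bisect_invert(f, u, lo, hi, eps=1e-3, max_steps=1000):
#       for _ in range(max_steps):
#           mid = (lo + hi) / 2
#           fm = f(mid)
#           lo = torch.where(fm < u, mid, lo)   # root lies above mid
#           hi = torch.where(fm < u, hi, mid)   # root lies below mid
#           if (hi - lo).max() < eps:
#               break
#       return (lo + hi) / 2
#
# The actual `flow.utils.monotonic_increasing_bijective_search` may differ in
# bracketing and stopping criterion; treat this only as a sketch.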
class _AdamInvTransformer(Transformer):
"""Abstract Transformer that inverts using the Adam optimizer.
Note that using this method, inversion will not be differentiable.
**CAUTION**: for any inheriting Transformers,
if you need to pass tensors as **kwargs to _invert, don't pass them inside
lists or any other collection; pass them directly.
Otherwise, _invert would run through their graph multiple times
and raise an Exception. See _invert for more details.
"""
def __init__(
self,
inv_lr=1e-1, inv_eps=1e-3, inv_steps=1000,
inv_init=None, **kwargs
):
"""
Args:
inv_lr (float): learning rate for the Adam optimizer.
Quite high by default (1e-1) in order to make sampling fast.
For more precision, use inv_lr=1e-3 and inv_steps >= 10000.
inv_eps (float): minimum mean squared difference between f(x) and u
allowed to stop the inversion.
inv_steps (int): maximum number of iterations
before halting execution. Defaults to 1000.
inv_init (function): function used to initialize x.
If None, x = torch.randn_like(u).
"""
super().__init__(**kwargs)
self.inv_lr = inv_lr
self.inv_eps = inv_eps
self.inv_steps = inv_steps
self.inv_init = inv_init
def _invert(self, u, *h, log_det=False, **kwargs):
# _invert should be called inside a torch.no_grad(),
# since this operation will not be differentiable
with torch.no_grad():
# Avoid running twice through the graph
u = u.clone()
h = tuple(hi.clone() for hi in h)
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
kwargs[k] = v.clone()
if self.inv_init is None:
x = nn.Parameter(torch.randn_like(u))
else:
x = nn.Parameter(self.inv_init(u, *h, **kwargs))
# However, we do need to enable gradients here to use the optimizer.
with torch.enable_grad():
optimizer = optim.Adam([x], lr=self.inv_lr)
for _ in range(self.inv_steps):
loss = (
(u - self._transform(x, *h, **kwargs)) ** 2
).mean()
if loss.item() < self.inv_eps:
break
optimizer.zero_grad()
loss.backward()
optimizer.step()
x = x.data # get the data from the parameter
if log_det:
_, log_det = self._transform(
x, *h, **kwargs, log_det=True
)
log_det = -log_det # we're inverting
return x, log_det
else:
return x
class NonAffine(_AdamInvTransformer):
"""Non-affine transformer.
https://arxiv.org/abs/1912.02762
"""
def __init__(self, k=16, nl=LeakyReLU, eps=1e-6, **kwargs):
"""
Args:
k (int): number of components of the conic combination.
nl (class): non-linearity Flow to use in each component.
Defaults to `flow.modules.LeakyReLU`.
eps (float): lower-bound to strictly-positive h parameters.
"""
_h_dim = 3 * k + 1
h_dim = kwargs.pop('h_dim', _h_dim)
assert h_dim == _h_dim, f'Received h_dim={h_dim} but expected {_h_dim}'
super().__init__(h_dim=h_dim, **kwargs)
self.k = k
self.nl = nl()
self.eps = eps
def _activation(self, h):
"""Returns (weight, loc, scale, bias) parameters."""
assert not h.size(1) % self.h_dim
h = h.view(h.size(0), -1, self.h_dim)
loc, scale, log_weight = h[..., :-1:3], h[..., 1:-1:3], h[..., 2:-1:3]
bias = h[..., -1]
scale = F.softplus(scale) + self.eps
log_weight = F.log_softmax(log_weight, dim=2)
return log_weight, loc, scale, bias
def _transform(self, u, *h, log_det=False, **kwargs):
log_weight, loc, scale, bias = h
z = u.unsqueeze(2) * scale + loc
# We need the derivative of each dimension individually,
# so we need to reshape to (-1, 1) first.
shape = z.shape # save the original shape for later
z = z.view(-1, 1)
nl_res = self.nl(z, log_det=log_det)
if log_det:
nl_z, log_det_i = nl_res
log_det_i = log_det_i.view(*shape) # restore shape
log_det_i = log_sum_exp_trick(
log_weight + log_det_i + torch.log(scale)
).sum(dim=1)
else:
nl_z = nl_res
nl_z = nl_z.view(*shape) # restore shape
x = (nl_z * torch.exp(log_weight)).sum(dim=2) + bias
if log_det:
return x, log_det_i
else:
return x
def _h_init(self):
h_init = torch.zeros(self.dim, self.h_dim, device=self.device)
# loc and bias 0, scale 1
# weight can be random, since all components return the same result
# h_init[:, :-1:3] = 0 # loc
h_init[:, 1:-1:3] = softplus_inv(
torch.tensor(1. - self.eps)
).item() # scale
h_init[:, 2:-1:3] = torch.randn(
self.dim, self.h_dim // 3, device=self.device
) # log_weight
h_init[:, -1] = 0 # bias
return h_init.flatten()
class DSF(_AdamInvTransformer):
"""Deep Sigmoidal Flow.
https://arxiv.org/abs/1804.00779
"""
def __init__(self, k=16, eps=1e-6, alpha=1., **kwargs):
"""
Args:
- k (int): number of components of the conic combination.
- eps (float): lower-bound to strictly-positive h parameters.
- alpha (float): alpha parameter for the sigmoid. Defaults to 1.
"""
_h_dim = 3 * k
h_dim = kwargs.pop('h_dim', _h_dim)
assert h_dim == _h_dim, f'Received h_dim={h_dim} but expected {_h_dim}'
super().__init__(h_dim=h_dim, **kwargs)
self.k = k
self.eps = eps
self.ls = LogSigmoid(dim=self.dim, alpha=alpha, eps=eps)
def _activation(self, h):
"""Returns (loc, scale, w, loc_post, scale_post) parameters."""
assert not h.size(1) % self.h_dim, (h.size(1), self.h_dim)
h = h.view(h.size(0), -1, self.h_dim)
loc, scale, log_w = h[..., ::3], h[..., 1::3], h[..., 2::3]
scale = F.softplus(scale) + self.eps
log_w = F.log_softmax(log_w, dim=2)
return loc, scale, log_w
def _transform(self, x, *h, log_det=False, **kwargs):
# TODO: Avoid computing log_det if not requested
loc, scale, log_w = h
z = scale * x.unsqueeze(2) + loc
# We need the derivative of each dimension individually,
# so we need to reshape to (-1, 1) first.
shape = z.shape # save the original shape for later
z, log_det_z = self.ls(z.view(-1, 1), log_det=True)
# Restore shape
z = z.view(*shape)
log_det_z = log_det_z.view(*shape)
z2 = log_sum_exp_trick(log_w + z) # this removes the 3rd dimension
# Again, we need the derivative of each dimension
shape = z2.shape # save shape
u, log_det_u = self.ls(z2.view(-1, 1), invert=True, log_det=True)
# Restore shape
u = u.view(*shape)
log_det_u = log_det_u.view(*shape)
# Finally, compute log_det if required
if log_det:
log_det = (
log_det_u +
-z2 +
log_sum_exp_trick(
log_w +
z +
log_det_z +
torch.log(scale)
)
).sum(dim=1)
return u, log_det
else:
return u
def _h_init(self):
h_init = torch.zeros(self.dim, self.h_dim, device=self.device)
# loc 0, scale 1
# weight can be random, since all components return the same result
# h_init[:, ::3] = 0 # loc
h_init[:, 1::3] = softplus_inv(
torch.tensor(1. - self.eps)
).item() # scale
h_init[:, 2::3] = torch.randn(
self.dim, self.h_dim // 3, device=self.device
) # log_weight
return h_init.flatten()
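Both `NonAffine` and `DSF` have no closed-form inverse; they inherit inversion from `_AdamInvTransformer`, which minimises the squared residual with the Adam optimizer (while `_IncreasingMonotonicTransformer` instead delegates to `flow.utils.monotonic_increasing_bijective_search`). The following standalone sketch mirrors the loop in `_invert`; the helper name `invert_by_adam` is hypothetical and not part of the package:

import torch
from torch import nn, optim

def invert_by_adam(f, u, lr=1e-1, eps=1e-3, steps=1000):
    # Numerically solve f(x) = u for x, mirroring _AdamInvTransformer._invert.
    x = nn.Parameter(torch.randn_like(u))   # same init as inv_init=None
    optimizer = optim.Adam([x], lr=lr)
    for _ in range(steps):
        loss = ((u - f(x)) ** 2).mean()
        if loss.item() < eps:
            break
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return x.detach()

# Example: invert the affine map f(x) = 2 * x + 1
u = torch.tensor([3., 5., -1.])
x = invert_by_adam(lambda x: 2 * x + 1, u)  # approaches (u - 1) / 2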
Classes
class Affine (eps=1e-06, **kwargs)
Affine Transformer.
Args
eps (float): lower-bound for scale parameter.
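A minimal standalone sketch of the computation performed by `_activation`, `_transform` and `_invert` (plain tensors, no conditioner; the 1e-6 offset mirrors the default eps):

import torch
import torch.nn.functional as F

# Toy conditioner output h for a 3-dimensional input (h_dim = 2 per dimension):
# even columns parameterise loc, odd columns the pre-softplus scale.
x = torch.randn(5, 3)
h = torch.randn(5, 6)
loc, scale = h[:, ::2], h[:, 1::2]
scale = F.softplus(scale) + 1e-6          # strictly positive, as in _activation

u = x * scale + loc                       # forward (_transform)
log_det = torch.log(scale).sum(dim=1)     # log|det J| per sample

x_rec = (u - loc) / scale                 # inverse (_invert)
assert torch.allclose(x, x_rec, atol=1e-5)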
Ancestors
- Transformer
- Flow
- torch.nn.modules.module.Module
Inherited members
class DSF (k=16, eps=1e-06, alpha=1.0, **kwargs)
Deep Sigmoidal Flow.
https://arxiv.org/abs/1804.00779
Args
- k (int): number of components of the conic combination.
- eps (float): lower-bound to strictly-positive h parameters.
- alpha (float): alpha parameter for the sigmoid. Defaults to 1.
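Conceptually, each output dimension is sigmoid^-1(sum_k w_k * sigmoid(scale_k * x + loc_k)); the class evaluates this in log-space through `LogSigmoid` and `log_sum_exp_trick` for numerical stability. A minimal non-log-space sketch of that formula for a single dimension, assuming alpha=1 and ignoring the eps clamping inside `LogSigmoid`:

import torch
import torch.nn.functional as F

k = 16
x = torch.randn(5, 1)                         # one dimension, batch of 5

# Per-component parameters, shaped as _activation would produce them
loc = torch.randn(1, k)
scale = F.softplus(torch.randn(1, k)) + 1e-6
w = torch.softmax(torch.randn(1, k), dim=1)   # softmax weights

s = torch.sigmoid(scale * x + loc)            # (5, k), each value in (0, 1)
mix = (w * s).sum(dim=1, keepdim=True)        # weighted combination, still in (0, 1)
u = torch.logit(mix)                          # sigmoid^-1 maps back to the real line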
Ancestors
- flow.transformer._AdamInvTransformer
- Transformer
- Flow
- torch.nn.modules.module.Module
Inherited members
class NonAffine (k=16, nl=flow.modules.LeakyReLU, eps=1e-06, **kwargs)
Non-affine transformer.
https://arxiv.org/abs/1912.02762
Args
k (int): number of components of the conic combination.
nl (class): non-linearity Flow to use in each component.
Defaults to `flow.modules.LeakyReLU`.
eps (float): lower-bound to strictly-positive h parameters.
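Each dimension is transformed by a softmax-weighted combination of k monotone non-linearities, x = sum_k w_k * nl(scale_k * u + loc_k) + bias. A minimal sketch of that forward map for one dimension, using torch.nn.functional.leaky_relu in place of the `flow.modules.LeakyReLU` Flow (which additionally returns the log-determinant):

import torch
import torch.nn.functional as F

k = 16
u = torch.randn(5, 1)                         # one dimension, batch of 5

# Per-component parameters, shaped as _activation would produce them
loc = torch.randn(1, k)
scale = F.softplus(torch.randn(1, k)) + 1e-6
weight = torch.softmax(torch.randn(1, k), dim=1)
bias = torch.randn(1)

z = u * scale + loc                           # (5, k)
x = (weight * F.leaky_relu(z)).sum(dim=1, keepdim=True) + bias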
Ancestors
- flow.transformer._AdamInvTransformer
- Transformer
- Flow
- torch.nn.modules.module.Module
Inherited members