Module flow.modules

Miscellaneous Flows.

Expand source code
"""
Miscellaneous Flows.
"""

import numpy as np
import torch
from torch import nn

from .flow import Flow
from .utils import softplus, softplus_inv, logsigmoid


class Affine(Flow):
    """Learnable Affine Flow.

    Applies weight[i] * x[i] + bias[i], 
    where weight and bias are learnable parameters.
    """

    def __init__(self, weight=None, bias=None, **kwargs):
        """
        Args:
            weight (torch.Tensor): initial value for the weight parameter. 
                If None, initialized to torch.ones(1, self.dim).
            bias (torch.Tensor): initial value for the bias parameter. 
                If None, initialized to torch.zeros(1, self.dim).
        """
        super().__init__(**kwargs)

        if weight is None:
            weight = torch.ones(1, self.dim)

        assert (weight > 0).all()
        self.log_weight = nn.Parameter(torch.log(weight))

        if bias is None:
            bias = torch.zeros(1, self.dim)

        self.bias = nn.Parameter(bias)

    def _log_det(self):
        """Used to compute _log_det for _transform."""
        return self.log_weight.sum(dim=1)

    def _h(self):
        """Compute the parameters for this flow."""
        return torch.exp(self.log_weight), self.bias

    def _transform(self, x, log_det=False, **kwargs):
        weight, bias = self._h()

        u = weight * x + bias

        if log_det:
            return u, self._log_det()
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        weight, bias = self._h()

        x = (u - bias) / weight

        if log_det:
            return x, -self._log_det()
        else:
            return x


class Sigmoid(Flow):
    """Sigmoid Flow."""

    def __init__(self, alpha=1., eps=1e-2, **kwargs):
        r"""
        Args:
            alpha (float): alpha parameter for the sigmoid function: 
                \(s(x, \alpha) = \frac{1}{1 + e^{-\alpha x}}\).
                Must be bigger than 0.
            eps (float): transformed values will be clamped to (eps, 1 - eps) 
                on both _transform and _invert.
        """
        super().__init__(**kwargs)

        self.alpha = alpha
        self.eps = eps

    def _log_det(self, x):
        """Return log|det J_T|, where T: x -> u."""
        return (
            np.log(self.alpha) + 
            2 * logsigmoid(x, alpha=self.alpha) +
            -self.alpha * x
        ).sum(dim=1)

    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = torch.sigmoid(self.alpha * x)
        u = u.clamp(self.eps, 1 - self.eps)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        u = u.clamp(self.eps, 1 - self.eps)
        x = -torch.log(1 / self.alpha / u - 1)

        if log_det:
            return x, -self._log_det(x)
        else:
            return x


class Softplus(Flow):
    """Softplus Flow."""

    def __init__(self, threshold=20., eps=1e-6, **kwargs):
        """
        Args:
            threshold (float): values above this revert to a linear function. 
                Default: 20.
            eps (float): lower-bound to the softplus output.
        """
        super().__init__(**kwargs)

        assert threshold > 0 and eps > 0
        self.threshold = threshold
        self.eps = eps

    def _log_det(self, x):
        return logsigmoid(x).sum(dim=1)

    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = softplus(x, threshold=self.threshold, eps=self.eps)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        x = softplus_inv(u, threshold=self.threshold, eps=self.eps)

        if log_det:
            return x, -self._log_det(x)
        else:
            return x 


class LogSigmoid(Flow):
    """LogSigmoid Flow, defined for numerical stability."""

    def __init__(self, alpha=1., **kwargs):
        """
        Args:
            alpha (float): alpha parameter used by the `Sigmoid`.
        """
        super().__init__(**kwargs)

        self.alpha = alpha

    def _log_det(self, x):
        """Return log|det J_T|, where T: x -> u."""

        return logsigmoid(-self.alpha * x).sum(dim=1) + np.log(self.alpha)

    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = logsigmoid(x, alpha=self.alpha)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        x = -softplus_inv(-u) / self.alpha

        if log_det:
            return x, -self._log_det(x)
        else:
            return x
    

class LeakyReLU(Flow):
    """LeakyReLU Flow."""

    def __init__(self, negative_slope=0.01, **kwargs):
        """
        Args:
            negative_slope (float): slope used for those x < 0,
        """
        super().__init__(**kwargs)

        self.negative_slope = negative_slope

    def _log_det(self, x):
        return torch.where(
            x >= 0, 
            torch.zeros_like(x), 
            torch.ones_like(x) * np.log(self.negative_slope)
        ).sum(dim=1)


    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = torch.where(x >= 0, x, x * self.negative_slope)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        x = torch.where(u >= 0, u, u / self.negative_slope)

        if log_det:
            return x, -self._log_det(x)
        else:
            return x


class BatchNorm(Flow):
    """Perform BatchNormalization as a Flow class.

    If not affine, just learns batch statistics to normalize the input.
    """

    @property
    def affine(self):
        return self._affine.item()
    

    def __init__(self, affine=True, momentum=.1, eps=1e-5, **kwargs):
        """
        Args:
            affine (bool): whether to learn parameters loc/scale.
            momentum (float): value used for the moving average
                of batch statistics. Must be between 0 and 1.
            eps (float): lower-bound for the scale tensor.
        """
        super().__init__(**kwargs)

        assert 0 <= momentum and momentum <= 1

        self.register_buffer('eps', torch.tensor(eps))
        self.register_buffer('momentum', torch.tensor(momentum))

        self.register_buffer('updates', torch.tensor(0))

        self.register_buffer('batch_loc', torch.zeros(1, self.dim))
        self.register_buffer('batch_scale', torch.ones(1, self.dim))

        assert isinstance(affine, bool)
        self.register_buffer('_affine', torch.tensor(affine))

        # We'll save these two parameters even if _affine is not True
        # because, otherwise, when we load the flow,
        # if affine has not the same value as the state_dict, 
        # it will raise an Exception.
        self.loc = nn.Parameter(torch.zeros(1, self.dim))
        self.log_scale = nn.Parameter(torch.zeros(1, self.dim))

    def warm_start(self, x):
        with torch.no_grad():
            self.batch_loc = x.mean(0, keepdim=True)
            self.batch_scale = x.std(0, keepdim=True) + self.eps

            self.updates.data = torch.tensor(1).to(self.device)

        return self

    def _activation(self, x=None, update=None):
        if self.training:
            assert x is not None and x.size(0) >= 2, \
                'If training BatchNorm, pass more than 1 sample.'

            bloc = x.mean(0, keepdim=True)
            bscale = x.std(0, keepdim=True) + self.eps

            # Update self.batch_loc, self.batch_scale
            with torch.no_grad():
                if self.updates.data == 0:
                    self.batch_loc.data = bloc
                    self.batch_scale.data = bscale
                else:
                    m = self.momentum
                    self.batch_loc.data = (1 - m) * self.batch_loc + m * bloc
                    self.batch_scale.data = \
                        (1 - m) * self.batch_scale + m * bscale

                self.updates += 1
        else:
            bloc, bscale = self.batch_loc, self.batch_scale

        loc, scale = self.loc, self.log_scale

        scale = torch.exp(scale) + self.eps
        # Note that batch_scale does not use activation,
        # since it is already in scale units.

        return bloc, bscale, loc, scale

    def _log_det(self, bscale):
        if self.affine:
            return (self.log_scale - torch.log(bscale)).sum(dim=1)
        else:
            return -torch.log(bscale).sum(dim=1)

    def _transform(self, x, log_det=False, **kwargs):
        bloc, bscale, loc, scale = self._activation(x)
        u = (x - bloc) / bscale 
        if self.affine:
            u = u * scale + loc
        
        if log_det:
            log_det = self._log_det(bscale)
            return u, log_det
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        assert not self.training, (
            'If using BatchNorm in reverse training mode, '
            'remember to call it reversed: inv_flow(BatchNorm)(dim=dim)'
        )

        bloc, bscale, loc, scale = self._activation()
        if self.affine:
            x = (u - loc) / scale * bscale + bloc
        else:
            x = u * bscale + bloc
        
        if log_det:
            log_det = -self._log_det(bscale)
            return x, log_det
        else:
            return x


class ActNorm(Affine):
    """Implementation of Activation Normalization.
    https://arxiv.org/pdf/1807.03039.pdf

    Uses Affine implementation and provides the warm_start method
    to initialize Affine so that the transformed distribution
    has location 0 and variance 1.

    Note that ActNorm expects to call warm_start.
    An assert blocks using it in any way before warm_start has been called.
    """

    def __init__(self, eps=1e-6, **kwargs):
        """
        Args:
            eps (float): lower-bound for the weight tensor.
        """
        super().__init__(**kwargs)

        self.register_buffer('eps', torch.tensor(eps))
        self.register_buffer('initialized', torch.tensor(False))

    def warm_start(self, x):
        """Warm start for ActNorm.

        Set loc and weight so that the transformed distribution
        has location 0 and variance 1.
        """

        self.log_weight.data = -torch.log(x.std(0, keepdim=True) + self.eps)
        self.bias.data = -(x * torch.exp(self.log_weight)).mean(0, keepdim=True)

        self.initialized.data = torch.tensor(True).to(self.device)

        return self

    def _h(self):
        assert self.initialized.item()

        return super()._h()


class Shuffle(Flow):
    """Perform a dimension-wise permutation."""
    
    def __init__(self, perm=None, **kwargs):
        """
        Args:
            perm (torch.Tensor): permutation to apply.
        """
        
        super().__init__(**kwargs)
        
        if perm is None:
            perm = torch.randperm(self.dim)
                
        assert perm.shape == (self.dim,)
        self.register_buffer('perm', perm)
        
    def _log_det(self, x):
        # By doing a permutation, det is always 1 or -1. 
        # Hence, log|det| is always 0.
        return torch.zeros_like(x[:, 0])
        
    def _transform(self, x, log_det=False, **kwargs):
        u = x[:, self.perm]
        
        if log_det:
            return u, self._log_det(x)
        else:
            return u
        
    def _invert(self, u, log_det=False, **kwargs):
        inv_perm = torch.argsort(self.perm)
        x = u[:, inv_perm]
        
        if log_det:
            return x, -self._log_det(x)
        else:
            return x

Classes

class ActNorm (eps=1e-06, **kwargs)

Implementation of Activation Normalization. https://arxiv.org/pdf/1807.03039.pdf

Uses Affine implementation and provides the warm_start method to initialize Affine so that the transformed distribution has location 0 and variance 1.

Note that ActNorm expects to call warm_start. An assert blocks using it in any way before warm_start has been called.

Args

eps : float
lower-bound for the weight tensor.
Expand source code
class ActNorm(Affine):
    """Implementation of Activation Normalization.
    https://arxiv.org/pdf/1807.03039.pdf

    Uses Affine implementation and provides the warm_start method
    to initialize Affine so that the transformed distribution
    has location 0 and variance 1.

    Note that ActNorm expects to call warm_start.
    An assert blocks using it in any way before warm_start has been called.
    """

    def __init__(self, eps=1e-6, **kwargs):
        """
        Args:
            eps (float): lower-bound for the weight tensor.
        """
        super().__init__(**kwargs)

        self.register_buffer('eps', torch.tensor(eps))
        self.register_buffer('initialized', torch.tensor(False))

    def warm_start(self, x):
        """Warm start for ActNorm.

        Set loc and weight so that the transformed distribution
        has location 0 and variance 1.
        """

        self.log_weight.data = -torch.log(x.std(0, keepdim=True) + self.eps)
        self.bias.data = -(x * torch.exp(self.log_weight)).mean(0, keepdim=True)

        self.initialized.data = torch.tensor(True).to(self.device)

        return self

    def _h(self):
        assert self.initialized.item()

        return super()._h()

Ancestors

Methods

def warm_start(self, x)

Warm start for ActNorm.

Set loc and weight so that the transformed distribution has location 0 and variance 1.

Expand source code
def warm_start(self, x):
    """Warm start for ActNorm.

    Set loc and weight so that the transformed distribution
    has location 0 and variance 1.
    """

    self.log_weight.data = -torch.log(x.std(0, keepdim=True) + self.eps)
    self.bias.data = -(x * torch.exp(self.log_weight)).mean(0, keepdim=True)

    self.initialized.data = torch.tensor(True).to(self.device)

    return self

Inherited members

class Affine (weight=None, bias=None, **kwargs)

Learnable Affine Flow.

Applies weight[i] * x[i] + bias[i], where weight and bias are learnable parameters.

Args

weight : torch.Tensor
initial value for the weight parameter. If None, initialized to torch.ones(1, self.dim).
bias : torch.Tensor
initial value for the bias parameter. If None, initialized to torch.zeros(1, self.dim).
Expand source code
class Affine(Flow):
    """Learnable Affine Flow.

    Applies weight[i] * x[i] + bias[i], 
    where weight and bias are learnable parameters.
    """

    def __init__(self, weight=None, bias=None, **kwargs):
        """
        Args:
            weight (torch.Tensor): initial value for the weight parameter. 
                If None, initialized to torch.ones(1, self.dim).
            bias (torch.Tensor): initial value for the bias parameter. 
                If None, initialized to torch.zeros(1, self.dim).
        """
        super().__init__(**kwargs)

        if weight is None:
            weight = torch.ones(1, self.dim)

        assert (weight > 0).all()
        self.log_weight = nn.Parameter(torch.log(weight))

        if bias is None:
            bias = torch.zeros(1, self.dim)

        self.bias = nn.Parameter(bias)

    def _log_det(self):
        """Used to compute _log_det for _transform."""
        return self.log_weight.sum(dim=1)

    def _h(self):
        """Compute the parameters for this flow."""
        return torch.exp(self.log_weight), self.bias

    def _transform(self, x, log_det=False, **kwargs):
        weight, bias = self._h()

        u = weight * x + bias

        if log_det:
            return u, self._log_det()
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        weight, bias = self._h()

        x = (u - bias) / weight

        if log_det:
            return x, -self._log_det()
        else:
            return x

Ancestors

  • Flow
  • torch.nn.modules.module.Module

Subclasses

Inherited members

class BatchNorm (affine=True, momentum=0.1, eps=1e-05, **kwargs)

Perform BatchNormalization as a Flow class.

If not affine, just learns batch statistics to normalize the input.

Args

affine : bool
whether to learn parameters loc/scale.
momentum : float
value used for the moving average of batch statistics. Must be between 0 and 1.
eps : float
lower-bound for the scale tensor.
Expand source code
class BatchNorm(Flow):
    """Perform BatchNormalization as a Flow class.

    If not affine, just learns batch statistics to normalize the input.
    """

    @property
    def affine(self):
        return self._affine.item()
    

    def __init__(self, affine=True, momentum=.1, eps=1e-5, **kwargs):
        """
        Args:
            affine (bool): whether to learn parameters loc/scale.
            momentum (float): value used for the moving average
                of batch statistics. Must be between 0 and 1.
            eps (float): lower-bound for the scale tensor.
        """
        super().__init__(**kwargs)

        assert 0 <= momentum and momentum <= 1

        self.register_buffer('eps', torch.tensor(eps))
        self.register_buffer('momentum', torch.tensor(momentum))

        self.register_buffer('updates', torch.tensor(0))

        self.register_buffer('batch_loc', torch.zeros(1, self.dim))
        self.register_buffer('batch_scale', torch.ones(1, self.dim))

        assert isinstance(affine, bool)
        self.register_buffer('_affine', torch.tensor(affine))

        # We'll save these two parameters even if _affine is not True
        # because, otherwise, when we load the flow,
        # if affine has not the same value as the state_dict, 
        # it will raise an Exception.
        self.loc = nn.Parameter(torch.zeros(1, self.dim))
        self.log_scale = nn.Parameter(torch.zeros(1, self.dim))

    def warm_start(self, x):
        with torch.no_grad():
            self.batch_loc = x.mean(0, keepdim=True)
            self.batch_scale = x.std(0, keepdim=True) + self.eps

            self.updates.data = torch.tensor(1).to(self.device)

        return self

    def _activation(self, x=None, update=None):
        if self.training:
            assert x is not None and x.size(0) >= 2, \
                'If training BatchNorm, pass more than 1 sample.'

            bloc = x.mean(0, keepdim=True)
            bscale = x.std(0, keepdim=True) + self.eps

            # Update self.batch_loc, self.batch_scale
            with torch.no_grad():
                if self.updates.data == 0:
                    self.batch_loc.data = bloc
                    self.batch_scale.data = bscale
                else:
                    m = self.momentum
                    self.batch_loc.data = (1 - m) * self.batch_loc + m * bloc
                    self.batch_scale.data = \
                        (1 - m) * self.batch_scale + m * bscale

                self.updates += 1
        else:
            bloc, bscale = self.batch_loc, self.batch_scale

        loc, scale = self.loc, self.log_scale

        scale = torch.exp(scale) + self.eps
        # Note that batch_scale does not use activation,
        # since it is already in scale units.

        return bloc, bscale, loc, scale

    def _log_det(self, bscale):
        if self.affine:
            return (self.log_scale - torch.log(bscale)).sum(dim=1)
        else:
            return -torch.log(bscale).sum(dim=1)

    def _transform(self, x, log_det=False, **kwargs):
        bloc, bscale, loc, scale = self._activation(x)
        u = (x - bloc) / bscale 
        if self.affine:
            u = u * scale + loc
        
        if log_det:
            log_det = self._log_det(bscale)
            return u, log_det
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        assert not self.training, (
            'If using BatchNorm in reverse training mode, '
            'remember to call it reversed: inv_flow(BatchNorm)(dim=dim)'
        )

        bloc, bscale, loc, scale = self._activation()
        if self.affine:
            x = (u - loc) / scale * bscale + bloc
        else:
            x = u * bscale + bloc
        
        if log_det:
            log_det = -self._log_det(bscale)
            return x, log_det
        else:
            return x

Ancestors

  • Flow
  • torch.nn.modules.module.Module

Instance variables

var affine
Expand source code
@property
def affine(self):
    return self._affine.item()

Inherited members

class LeakyReLU (negative_slope=0.01, **kwargs)

LeakyReLU Flow.

Args

negative_slope : float
slope used for those x < 0,
Expand source code
class LeakyReLU(Flow):
    """LeakyReLU Flow."""

    def __init__(self, negative_slope=0.01, **kwargs):
        """
        Args:
            negative_slope (float): slope used for those x < 0,
        """
        super().__init__(**kwargs)

        self.negative_slope = negative_slope

    def _log_det(self, x):
        return torch.where(
            x >= 0, 
            torch.zeros_like(x), 
            torch.ones_like(x) * np.log(self.negative_slope)
        ).sum(dim=1)


    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = torch.where(x >= 0, x, x * self.negative_slope)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        x = torch.where(u >= 0, u, u / self.negative_slope)

        if log_det:
            return x, -self._log_det(x)
        else:
            return x

Ancestors

  • Flow
  • torch.nn.modules.module.Module

Inherited members

class LogSigmoid (alpha=1.0, **kwargs)

LogSigmoid Flow, defined for numerical stability.

Args

alpha : float
alpha parameter used by the Sigmoid.
Expand source code
class LogSigmoid(Flow):
    """LogSigmoid Flow, defined for numerical stability."""

    def __init__(self, alpha=1., **kwargs):
        """
        Args:
            alpha (float): alpha parameter used by the `Sigmoid`.
        """
        super().__init__(**kwargs)

        self.alpha = alpha

    def _log_det(self, x):
        """Return log|det J_T|, where T: x -> u."""

        return logsigmoid(-self.alpha * x).sum(dim=1) + np.log(self.alpha)

    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = logsigmoid(x, alpha=self.alpha)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        x = -softplus_inv(-u) / self.alpha

        if log_det:
            return x, -self._log_det(x)
        else:
            return x

Ancestors

  • Flow
  • torch.nn.modules.module.Module

Inherited members

class Shuffle (perm=None, **kwargs)

Perform a dimension-wise permutation.

Args

perm : torch.Tensor
permutation to apply.
Expand source code
class Shuffle(Flow):
    """Perform a dimension-wise permutation."""
    
    def __init__(self, perm=None, **kwargs):
        """
        Args:
            perm (torch.Tensor): permutation to apply.
        """
        
        super().__init__(**kwargs)
        
        if perm is None:
            perm = torch.randperm(self.dim)
                
        assert perm.shape == (self.dim,)
        self.register_buffer('perm', perm)
        
    def _log_det(self, x):
        # By doing a permutation, det is always 1 or -1. 
        # Hence, log|det| is always 0.
        return torch.zeros_like(x[:, 0])
        
    def _transform(self, x, log_det=False, **kwargs):
        u = x[:, self.perm]
        
        if log_det:
            return u, self._log_det(x)
        else:
            return u
        
    def _invert(self, u, log_det=False, **kwargs):
        inv_perm = torch.argsort(self.perm)
        x = u[:, inv_perm]
        
        if log_det:
            return x, -self._log_det(x)
        else:
            return x

Ancestors

  • Flow
  • torch.nn.modules.module.Module

Inherited members

class Sigmoid (alpha=1.0, eps=0.01, **kwargs)

Sigmoid Flow.

Args

alpha : float
alpha parameter for the sigmoid function: s(x, \alpha) = \frac{1}{1 + e^{-\alpha x}}. Must be bigger than 0.
eps : float
transformed values will be clamped to (eps, 1 - eps) on both _transform and _invert.
Expand source code
class Sigmoid(Flow):
    """Sigmoid Flow."""

    def __init__(self, alpha=1., eps=1e-2, **kwargs):
        r"""
        Args:
            alpha (float): alpha parameter for the sigmoid function: 
                \(s(x, \alpha) = \frac{1}{1 + e^{-\alpha x}}\).
                Must be bigger than 0.
            eps (float): transformed values will be clamped to (eps, 1 - eps) 
                on both _transform and _invert.
        """
        super().__init__(**kwargs)

        self.alpha = alpha
        self.eps = eps

    def _log_det(self, x):
        """Return log|det J_T|, where T: x -> u."""
        return (
            np.log(self.alpha) + 
            2 * logsigmoid(x, alpha=self.alpha) +
            -self.alpha * x
        ).sum(dim=1)

    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = torch.sigmoid(self.alpha * x)
        u = u.clamp(self.eps, 1 - self.eps)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        u = u.clamp(self.eps, 1 - self.eps)
        x = -torch.log(1 / self.alpha / u - 1)

        if log_det:
            return x, -self._log_det(x)
        else:
            return x

Ancestors

  • Flow
  • torch.nn.modules.module.Module

Inherited members

class Softplus (threshold=20.0, eps=1e-06, **kwargs)

Softplus Flow.

Args

threshold : float
values above this revert to a linear function. Default: 20.
eps : float
lower-bound to the softplus output.
Expand source code
class Softplus(Flow):
    """Softplus Flow."""

    def __init__(self, threshold=20., eps=1e-6, **kwargs):
        """
        Args:
            threshold (float): values above this revert to a linear function. 
                Default: 20.
            eps (float): lower-bound to the softplus output.
        """
        super().__init__(**kwargs)

        assert threshold > 0 and eps > 0
        self.threshold = threshold
        self.eps = eps

    def _log_det(self, x):
        return logsigmoid(x).sum(dim=1)

    # Override methods
    def _transform(self, x, log_det=False, **kwargs):
        u = softplus(x, threshold=self.threshold, eps=self.eps)

        if log_det:
            return u, self._log_det(x)
        else:
            return u

    def _invert(self, u, log_det=False, **kwargs):
        x = softplus_inv(u, threshold=self.threshold, eps=self.eps)

        if log_det:
            return x, -self._log_det(x)
        else:
            return x 

Ancestors

  • Flow
  • torch.nn.modules.module.Module

Inherited members