# API Reference

This section provides detailed API documentation for all modules in DiffusionLab.

- [Diffusions](diffusions/)
- [Distributions](distributions/)
- [Losses](losses/)
- [Models](models/)
- [Samplers](samplers/)
- [Schedulers](schedulers/)
- [Utils](utils/)
- [Vector Fields](vector_fields/)

# Diffusions

This module contains functionality related to diffusions.

## `DiffusionProcess`

Base class for implementing various diffusion processes.

A diffusion process defines how data evolves over time when noise is added according to specific dynamics. This class provides a framework for implementing different types of diffusion processes used in generative modeling.

The diffusion is parameterized by two functions:

- alpha(t): Controls how much of the original signal is preserved at time t
- sigma(t): Controls how much noise is added at time t

The forward process is defined as:

x_t = alpha(t) * x_0 + sigma(t) * eps,

where:

- x_0 is the original data
- x_t is the noised data at time t
- eps is random noise sampled from a standard Gaussian distribution
- t is the diffusion time parameter, typically in range [0, 1]

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `alpha` | `Callable` | Function that determines signal preservation at time t, differentiable, maps any tensor to tensor of same shape |
| `sigma` | `Callable` | Function that determines noise level at time t, differentiable, maps any tensor to tensor of same shape |
| `alpha_prime` | `Callable` | Derivative of alpha, maps any tensor to tensor of same shape |
| `sigma_prime` | `Callable` | Derivative of sigma, maps any tensor to tensor of same shape |

### `alpha = alpha`

### `alpha_prime = scalar_derivative(alpha)`

### `sigma = sigma`

### `sigma_prime = scalar_derivative(sigma)`

### `__init__(**dynamics_hparams)`

Initialize a diffusion process with specific dynamics parameters.
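
The contract described for `DiffusionProcess` (two required callables, derived derivatives, and the forward formula) can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: `ToyDiffusionProcess` and `central_difference` are hypothetical names, the sketch works on plain floats rather than tensors so it runs without PyTorch, and the finite-difference derivative is an assumption standing in for whatever `scalar_derivative` does internally.

```python
from typing import Callable


def central_difference(f: Callable[[float], float], h: float = 1e-5) -> Callable[[float], float]:
    """Numerical derivative of a scalar function (assumed stand-in for scalar_derivative)."""
    return lambda t: (f(t + h) - f(t - h)) / (2.0 * h)


class ToyDiffusionProcess:
    """Hypothetical scalar stand-in for the documented DiffusionProcess."""

    def __init__(self, **dynamics_hparams):
        # Mirror the documented contract: both callables are mandatory.
        assert "alpha" in dynamics_hparams, "alpha must be provided"
        assert "sigma" in dynamics_hparams, "sigma must be provided"
        self.alpha = dynamics_hparams["alpha"]
        self.sigma = dynamics_hparams["sigma"]
        self.alpha_prime = central_difference(self.alpha)
        self.sigma_prime = central_difference(self.sigma)

    def forward(self, x: float, t: float, eps: float) -> float:
        # x_t = alpha(t) * x_0 + sigma(t) * eps
        return self.alpha(t) * x + self.sigma(t) * eps


# Example dynamics: alpha(t) = 1 - t, sigma(t) = t.
process = ToyDiffusionProcess(alpha=lambda t: 1.0 - t, sigma=lambda t: t)
x_t = process.forward(x=2.0, t=0.25, eps=-1.0)  # 0.75 * 2.0 + 0.25 * (-1.0)
```

Omitting either callable triggers the assertion, which matches the `AssertionError` behavior documented for `__init__`.
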
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**dynamics_hparams` | `Any` | Keyword arguments containing the dynamics parameters. Must include `alpha`, a callable that maps time t to the signal coefficient, and `sigma`, a callable that maps time t to the noise coefficient. | `{}` |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If alpha or sigma is not provided in dynamics_hparams |

### `forward(x, t, eps)`

Forward pass of the dynamics model.

This method implements the forward diffusion process, which gradually adds noise to the input data according to the specified dynamics (alpha and sigma functions).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The input data tensor of shape (N, \*D), where N is the batch size and D represents the data dimensions. | *required* |
| `t` | `Tensor` | The time parameter tensor of shape (N,) or broadcastable to x's shape, with values typically in the range [0, 1]. | *required* |
| `eps` | `Tensor` | The Gaussian noise tensor of shape (N, \*D), where N is the batch size and D represents the data dimensions. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The noised data at time t, computed as alpha(t) * x + sigma(t) * eps, of shape (N, \*D) matching the input shape. |

## `FlowMatchingProcess`

Bases: `DiffusionProcess`

Implements a Flow Matching diffusion process.

Flow Matching is a technique used in generative modeling where the goal is to learn a continuous transformation (flow) between a simple distribution and a complex data distribution.

In this implementation:

- alpha(t) = 1 - t
- sigma(t) = t

This creates a linear interpolation between the original data (at t=0) and the noise (at t=1), which is useful for training flow-based generative models.

### `__init__()`

Initialize a Flow Matching diffusion process with predefined dynamics.

The process uses:

- alpha(t) = 1 - t
- sigma(t) = t

Both functions map tensors of shape (N,) to tensors of the same shape.
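
With alpha(t) = 1 - t and sigma(t) = t, the endpoints and the time derivative of the path can be checked by hand. The following is a minimal sketch on plain Python lists, not the library's code; the helper names `flow_matching_forward` and `flow_matching_velocity` are hypothetical.

```python
def alpha(t: float) -> float:
    return 1.0 - t


def sigma(t: float) -> float:
    return t


def flow_matching_forward(x_0, eps, t):
    # x_t = (1 - t) * x_0 + t * eps, applied coordinate by coordinate.
    return [alpha(t) * a + sigma(t) * e for a, e in zip(x_0, eps)]


def flow_matching_velocity(x_0, eps):
    # d/dt x_t = alpha'(t) * x_0 + sigma'(t) * eps = -x_0 + eps, independent of t.
    return [e - a for a, e in zip(x_0, eps)]


x_0 = [1.0, -2.0]
eps = [0.5, 0.5]
start = flow_matching_forward(x_0, eps, 0.0)     # the clean data
end = flow_matching_forward(x_0, eps, 1.0)       # pure noise
midpoint = flow_matching_forward(x_0, eps, 0.5)  # halfway along the straight line
velocity = flow_matching_velocity(x_0, eps)
```

Because the velocity does not depend on t, each (x_0, eps) pair traces a straight line from data to noise.
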
This creates a linear interpolation between the original data and noise.

## `OrnsteinUhlenbeckProcess`

Bases: `DiffusionProcess`

Implements an Ornstein-Uhlenbeck diffusion process.

The Ornstein-Uhlenbeck process is a mean-reverting stochastic process that describes the velocity of a particle undergoing Brownian motion while being subject to friction.

In this implementation:

- alpha(t) = sqrt(1 - t²)
- sigma(t) = t

This process has properties that make it useful for certain generative modeling tasks, particularly when a smooth transition between clean and noisy states is desired.

### `__init__()`

Initialize an Ornstein-Uhlenbeck diffusion process with predefined dynamics.

The process uses:

- alpha(t) = sqrt(1 - t²)
- sigma(t) = t

Both functions map tensors of shape (N,) to tensors of the same shape.

## `VarianceExplodingProcess`

Bases: `DiffusionProcess`

Implements a Variance Exploding (VE) diffusion process.

In a VE process, the signal component remains constant (alpha(t) = 1) while the noise component increases according to the provided sigma function. This leads to the variance of the process "exploding" as t increases.

The forward process is defined as:

x_t = x_0 + sigma(t) * eps

This is used in models like NCSN (Noise Conditional Score Network) and Score SDE.

### `__init__(sigma)`

Initialize a Variance Exploding diffusion process.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sigma` | `Callable` | Function that determines how noise scales with time t. Should map a tensor of time values of shape (N,) to noise coefficients of the same shape. | *required* |

# Distributions

This module contains functionality related to distributions.

## `base`

### `Distribution`

Base class for all distributions.

This class should be subclassed by other distributions when you want to use ground truth scores, denoisers, noise predictors, or velocity estimators.
Each distribution implementation provides methods to compute various vector fields related to the diffusion process, such as denoising (x0), noise prediction (eps), velocity estimation (v), and score estimation.

#### `batch_dist_params(N, dist_params)`

Add a batch dimension to the distribution parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | `int` | The number of samples in the batch. | *required* |
| `dist_params` | `Dict[str, Tensor]` | A dictionary of parameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Tensor]` | A dictionary of parameters for the distribution, with a batch dimension added. |

#### `eps(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the noise predictor E[eps | x_t] at a given time t and input x_t, under the data model

x_t = alpha(t) * x_0 + sigma(t) * eps

where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I). This is stateless for the same reason as the denoiser method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, \*D), where D is the shape of each data point. | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process whose forward and reverse dynamics determine the time-evolution of the vector fields corresponding to the distribution. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary of batched parameters for the distribution. Each parameter is of shape (N, \*P) where P is the shape of the parameter. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of eps, of shape (N, \*D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.

#### `get_vector_field_method(vector_field_type)`

Returns the appropriate method to compute the specified vector field type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vector_field_type` | `VectorFieldType` | The type of vector field to compute. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Callable[[Tensor, Tensor, DiffusionProcess, Dict[str, Tensor], Dict[str, Any]], Tensor]` | A method that computes the specified vector field, with signature (x_t, t, diffusion_process, batched_dist_params, dist_hparams) -> tensor. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the vector field type is not recognized. |

#### `sample(N, dist_params, dist_hparams)`

Draws N i.i.d. samples from the data distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | `int` | The number of samples to draw. | *required* |
| `dist_params` | `Dict[str, Tensor]` | A dictionary of parameters for the distribution. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, Any]` | A tuple (samples, metadata), where samples is a tensor of shape (N, \*D) and metadata is any additional information. For example, if the distribution has labels, the metadata is a tensor of shape (N, ) containing the labels. Note that the samples are always placed on the CPU. |

#### `score(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the score estimator grad_x log p(x_t, t) at a given time t and input x_t, under the data model

x_t = alpha(t) * x_0 + sigma(t) * eps

where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I). This is stateless for the same reason as the denoiser method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, \*D), where D is the shape of each data point. | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process whose forward and reverse dynamics determine the time-evolution of the vector fields corresponding to the distribution. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary of batched parameters for the distribution. Each parameter is of shape (N, \*P) where P is the shape of the parameter. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of grad_x log p(x_t, t), of shape (N, \*D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.

#### `v(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the velocity estimator E[d/dt x_t | x_t] at a given time t and input x_t, under the data model

x_t = alpha(t) * x_0 + sigma(t) * eps

where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I). This is stateless for the same reason as the denoiser method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, \*D), where D is the shape of each data point. | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process whose forward and reverse dynamics determine the time-evolution of the vector fields corresponding to the distribution. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary of batched parameters for the distribution. Each parameter is of shape (N, \*P) where P is the shape of the parameter. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of d/dt x_t, of shape (N, \*D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.

#### `validate_hparams(dist_hparams)`

Validate the hyperparameters for the distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If the hyperparameters are invalid; the assertion fails at exactly the point of failure. |

#### `validate_params(possibly_batched_dist_params)`

Validate the parameters for the distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `possibly_batched_dist_params` | `Dict[str, Tensor]` | A dictionary of parameters for the distribution. Each value is a PyTorch tensor, possibly having a batch dimension. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If the parameters are invalid; the assertion fails at exactly the point of failure. |
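
The four vector fields exposed by this class (x0, eps, v, score) are linked by the data model x_t = alpha(t) * x_0 + sigma(t) * eps, so any one of them determines the others whenever sigma(t) > 0. The sketch below derives the other three from a denoiser output using standard identities (linearity of conditional expectation and Tweedie's formula). Whether the library performs its conversions exactly this way is an assumption; `convert_from_x0` is a hypothetical helper operating on plain lists rather than tensors.

```python
import math


def convert_from_x0(x_t, x0_hat, alpha_t, sigma_t, alpha_prime_t, sigma_prime_t):
    """Derive eps, v and score estimates from a denoiser output, coordinate by coordinate."""
    # E[eps | x_t] = (x_t - alpha(t) * E[x_0 | x_t]) / sigma(t)
    eps_hat = [(xt - alpha_t * x0) / sigma_t for xt, x0 in zip(x_t, x0_hat)]
    # E[d/dt x_t | x_t] = alpha'(t) * E[x_0 | x_t] + sigma'(t) * E[eps | x_t]
    v_hat = [alpha_prime_t * x0 + sigma_prime_t * e for x0, e in zip(x0_hat, eps_hat)]
    # grad_x log p(x_t, t) = -E[eps | x_t] / sigma(t)
    score_hat = [-e / sigma_t for e in eps_hat]
    return eps_hat, v_hat, score_hat


# Example dynamics: alpha(t) = sqrt(1 - t^2), sigma(t) = t, evaluated at t = 0.6.
t = 0.6
alpha_t = math.sqrt(1.0 - t * t)
sigma_t = t
alpha_prime_t = -t / math.sqrt(1.0 - t * t)
sigma_prime_t = 1.0

eps_hat, v_hat, score_hat = convert_from_x0(
    x_t=[1.0], x0_hat=[0.5],
    alpha_t=alpha_t, sigma_t=sigma_t,
    alpha_prime_t=alpha_prime_t, sigma_prime_t=sigma_prime_t,
)
```

The division by sigma(t) is why these identities break down at t = 0 for processes with sigma(0) = 0.
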
#### `x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the denoiser E[x_0 | x_t] at a given time t and input x_t, under the data model

x_t = alpha(t) * x_0 + sigma(t) * eps

where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, \*D), where D is the shape of each data point. | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process whose forward and reverse dynamics determine the time-evolution of the vector fields corresponding to the distribution. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary of batched parameters for the distribution. Each parameter is of shape (N, \*P) where P is the shape of the parameter. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of x_0, of shape (N, \*D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.

## `empirical`

### `EmpiricalDistribution`

Bases: `Distribution`

An empirical distribution, i.e., the uniform distribution over a dataset.

Formally, the distribution is defined as:

mu(B) = (1/N) * sum\_(i=1)^(N) delta(x_i in B)

where x_i is the ith data point in the dataset, and N is the number of data points.

Distribution Parameters

- None

Distribution Hyperparameters

- labeled_data: A DataLoader over the dataset that defines the empirical distribution, where each data sample is a (data, label) tuple. Both data and label are PyTorch tensors.

Note: This class has no sample() method, as it is difficult to sample randomly from a DataLoader. In practice, you can sample directly from the DataLoader and apply filtering there.
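
For a uniform distribution over a finite dataset, the denoiser E[x_0 | x_t] has a closed form: a softmax-weighted average of the data points, with weights proportional to the Gaussian likelihood of x_t given each point. The following is a minimal sketch of that formula on plain Python lists. It is not the library's implementation, which reads batches from a DataLoader; `empirical_denoiser` is a hypothetical name.

```python
import math


def empirical_denoiser(x_t, dataset, alpha_t, sigma_t):
    """E[x_0 | x_t] for the uniform distribution over `dataset` (a list of vectors)."""
    # Log-likelihood of x_t under N(alpha(t) * x_i, sigma(t)^2 * I), up to a shared constant.
    log_w = [
        -sum((a - alpha_t * b) ** 2 for a, b in zip(x_t, x_i)) / (2.0 * sigma_t ** 2)
        for x_i in dataset
    ]
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    total = sum(w)
    # Weighted average of the data points, coordinate by coordinate.
    return [
        sum(w_i * x_i[d] for w_i, x_i in zip(w, dataset)) / total
        for d in range(len(x_t))
    ]


dataset = [[-1.0], [1.0]]
# Equidistant from both points: the weights are equal and the estimate is the dataset mean.
balanced = empirical_denoiser([0.0], dataset, alpha_t=1.0, sigma_t=1.0)
# Low noise and close to one point: the estimate collapses onto that point.
sharp = empirical_denoiser([0.9], dataset, alpha_t=1.0, sigma_t=0.1)
```

At small sigma(t) the weights concentrate on the nearest data point, which is why subtracting the maximum log-weight matters in practice.
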
#### `validate_hparams(dist_hparams)`

Validate the hyperparameters for the empirical distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. Must contain 'labeled_data' which is a DataLoader. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If the hyperparameters are invalid. |

#### `x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the denoiser E[x_0 | x_t] for an empirical distribution.

This method computes the denoiser by performing a weighted average of the dataset samples, where the weights are determined by the likelihood of x_t given each sample.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, \*D), where D is the shape of each data point. | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary of batched parameters for the distribution. Not used for empirical distribution. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. Must contain 'labeled_data' which is a DataLoader. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of x_0, of shape (N, \*D). |

## `gmm`

### `GMMDistribution`

Bases: `Distribution`

A Gaussian Mixture Model (GMM) with K components.

Formally, the distribution is defined as:

mu(B) = sum\_(i=1)^(K) pi_i * N(mu_i, Sigma_i)(B)

where mu_i is the mean of the ith component, Sigma_i is the covariance matrix of the ith component, and pi_i is the prior probability of the ith component.

Distribution Parameters

- means: A tensor of shape (K, D) containing the means of the components.
- covs: A tensor of shape (K, D, D) containing the covariance matrices of the components.
- priors: A tensor of shape (K, ) containing the prior probabilities of the components.

Distribution Hyperparameters

- None

#### `sample(N, dist_params, dist_hparams)`

#### `validate_params(possibly_batched_dist_params)`

#### `x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the denoiser E[x_0 | x_t] for a GMM distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, D). | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary containing the batched parameters of the distribution: `means`, a tensor of shape (N, K, D) containing the means of the components; `covs`, a tensor of shape (N, K, D, D) containing the covariance matrices of the components; `priors`, a tensor of shape (N, K) containing the prior probabilities of the components. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of x_0, of shape (N, D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.

### `IsoGMMDistribution`

Bases: `Distribution`

An isotropic (i.e., spherical variances) Gaussian Mixture Model (GMM) with K components.

Formally, the distribution is defined as:

mu(B) = sum\_(i=1)^(K) pi_i * N(mu_i, tau_i^2 * I_D)(B)

where mu_i is the mean of the ith component, tau_i is the standard deviation of the ith component (so tau_i^2 is its spherical variance), and pi_i is the prior probability of the ith component.

Distribution Parameters

- means: A tensor of shape (K, D) containing the means of the components.
- vars: A tensor of shape (K, ) containing the variances of the components.
- priors: A tensor of shape (K, ) containing the prior probabilities of the components.

Distribution Hyperparameters

- None

#### `sample(N, dist_params, dist_hparams)`

Draws N i.i.d. samples from the isotropic GMM distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | `int` | The number of samples to draw. | *required* |
| `dist_params` | `Dict[str, Tensor]` | A dictionary of parameters for the distribution: `means`, a tensor of shape (K, D) containing the means of the components; `vars`, a tensor of shape (K, ) containing the variances of the components; `priors`, a tensor of shape (K, ) containing the prior probabilities of the components. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, Tensor]` | A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, ) containing the component indices from which each sample was drawn. Note that the samples are always placed on the CPU. |

#### `validate_params(possibly_batched_dist_params)`

#### `x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the denoiser E[x_0 | x_t] for an isotropic GMM distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, D). | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process whose forward and reverse dynamics determine the time-evolution of the vector fields corresponding to the distribution. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary containing the batched parameters of the distribution: `means`, a tensor of shape (N, K, D) containing the means of the components; `vars`, a tensor of shape (N, K) containing the variances of the components; `priors`, a tensor of shape (N, K) containing the prior probabilities of the components. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of x_0, of shape (N, D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.

### `IsoHomoGMMDistribution`

Bases: `Distribution`

An isotropic homoscedastic (i.e., equal spherical variances) Gaussian Mixture Model (GMM) with K components.

Formally, the distribution is defined as:

mu(B) = sum\_(i=1)^(K) pi_i * N(mu_i, tau^2 * I_D)(B)

where mu_i is the mean of the ith component, tau is the standard deviation shared by all components (so tau^2 is the common spherical variance), and pi_i is the prior probability of the ith component.

Distribution Parameters

- means: A tensor of shape (K, D) containing the means of the components.
- var: A tensor of shape () containing the shared variance of all components.
- priors: A tensor of shape (K, ) containing the prior probabilities of the components.

Distribution Hyperparameters

- None

#### `sample(N, dist_params, dist_hparams)`

Draws N i.i.d. samples from the isotropic homoscedastic GMM distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | `int` | The number of samples to draw. | *required* |
| `dist_params` | `Dict[str, Tensor]` | A dictionary of parameters for the distribution: `means`, a tensor of shape (K, D) containing the means of the components; `var`, a tensor of shape () containing the shared variance of all components; `priors`, a tensor of shape (K, ) containing the prior probabilities of the components. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, Tensor]` | A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, ) containing the component indices from which each sample was drawn. Note that the samples are always placed on the CPU. |

#### `validate_params(possibly_batched_dist_params)`

#### `x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the denoiser E[x_0 | x_t] for an isotropic homoscedastic GMM distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, D). | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process whose forward and reverse dynamics determine the time-evolution of the vector fields corresponding to the distribution. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary containing the batched parameters of the distribution: `means`, a tensor of shape (N, K, D) containing the means of the components; `var`, a tensor of shape (N, ) containing the shared variance of all components; `priors`, a tensor of shape (N, K) containing the prior probabilities of the components. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of x_0, of shape (N, D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.

### `LowRankGMMDistribution`

Bases: `Distribution`

A Gaussian Mixture Model (GMM) with K low-rank components.
Formally, the distribution is defined as:

mu(B) = sum\_(i=1)^(K) pi_i * N(mu_i, Sigma_i)(B)

where mu_i is the mean of the ith component, Sigma_i is the covariance matrix of the ith component, and pi_i is the prior probability of the ith component. Notably, Sigma_i is a low-rank matrix of the form

Sigma_i = A_i @ A_i^T

Distribution Parameters

- means: A tensor of shape (K, D) containing the means of the components.
- covs_factors: A tensor of shape (K, D, P) containing the tall factors of the covariance matrices of the components.
- priors: A tensor of shape (K, ) containing the prior probabilities of the components.

Distribution Hyperparameters

- None

Note:

- The covariance matrices are not explicitly stored, but rather computed as Sigma_i = A_i @ A_i^T.
- The time and memory complexity of this class is much lower than that of the full GMM class, provided each covariance is low-rank (P \<< D).

#### `sample(N, dist_params, dist_hparams)`

Draws N i.i.d. samples from the low-rank GMM distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | `int` | The number of samples to draw. | *required* |
| `dist_params` | `Dict[str, Tensor]` | A dictionary of parameters for the distribution: `means`, a tensor of shape (K, D) containing the means of the components; `covs_factors`, a tensor of shape (K, D, P) containing the tall factors of the covariance matrices; `priors`, a tensor of shape (K, ) containing the prior probabilities of the components. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, Tensor]` | A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, ) containing the component indices from which each sample was drawn. Note that the samples are always placed on the CPU. |

#### `validate_params(possibly_batched_dist_params)`

#### `x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)`

Computes the denoiser E[x_0 | x_t] for a low-rank GMM distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The input tensor, of shape (N, D). | *required* |
| `t` | `Tensor` | The time tensor, of shape (N, ). | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process whose forward and reverse dynamics determine the time-evolution of the vector fields corresponding to the distribution. | *required* |
| `batched_dist_params` | `Dict[str, Tensor]` | A dictionary containing the batched parameters of the distribution: `means`, a tensor of shape (N, K, D) containing the means of the components; `covs_factors`, a tensor of shape (N, K, D, P) containing the tall factors of the covariance matrices; `priors`, a tensor of shape (N, K) containing the prior probabilities of the components. | *required* |
| `dist_hparams` | `Dict[str, Any]` | A dictionary of hyperparameters for the distribution. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The prediction of x_0, of shape (N, D). |

Note: The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension. The covariance matrices are implicitly defined as Sigma_i = A_i @ A_i^T, where A_i is the ith factor.

# Losses

This module contains functionality related to losses.

## `SamplewiseDiffusionLoss`

Bases: `Module`

Sample-wise loss function for training diffusion models.

This class implements various loss functions for diffusion models based on the specified target type. The loss is computed as the mean squared error between the model's prediction and the target, which depends on the chosen vector field type.
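
How the regression target depends on the vector field type can be sketched as follows. This is an illustration under assumptions, not the library's code: `TargetType` is a hypothetical stand-in for `VectorFieldType`, the helpers work on plain lists for a single sample, and averaging the squared error over data dimensions (rather than summing) is an assumption about the reduction.

```python
from enum import Enum


class TargetType(Enum):
    """Hypothetical stand-in for the library's VectorFieldType."""
    X0 = "x0"
    EPS = "eps"
    V = "v"
    SCORE = "score"


def regression_target(target_type, x_0, eps, alpha_prime_t, sigma_prime_t):
    """Return the quantity the model is trained to predict for one sample."""
    if target_type is TargetType.X0:
        return list(x_0)
    if target_type is TargetType.EPS:
        return list(eps)
    if target_type is TargetType.V:
        # d/dt x_t = alpha'(t) * x_0 + sigma'(t) * eps
        return [alpha_prime_t * a + sigma_prime_t * e for a, e in zip(x_0, eps)]
    # The score has no direct regression target in this scheme.
    raise ValueError(f"unsupported target type: {target_type}")


def samplewise_mse(prediction, target):
    """Squared error averaged over the data dimensions of one sample (assumed reduction)."""
    return sum((p - q) ** 2 for p, q in zip(prediction, target)) / len(target)


x_0 = [1.0, 2.0]
eps = [0.0, 1.0]
# Flow-matching dynamics: alpha'(t) = -1, sigma'(t) = 1, so the target is eps - x_0.
v_target = regression_target(TargetType.V, x_0, eps, alpha_prime_t=-1.0, sigma_prime_t=1.0)
loss = samplewise_mse([-1.0, 0.0], v_target)
```
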
The loss supports different target types:

- X0: Learn to predict the original clean data x_0
- EPS: Learn to predict the noise component eps
- V: Learn to predict the velocity field v
- SCORE: Not directly supported (raises ValueError)

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `diffusion_process` | `DiffusionProcess` | The diffusion process defining the forward dynamics |
| `target_type` | `VectorFieldType` | The type of target to learn via minimizing the loss function |
| `target` | `Callable` | Function that computes the target based on the specified target_type. Takes tensors of shapes (N, D) for x_t, f_x_t, x_0, eps and (N,) for t, and returns a tensor of shape (N, D). |

### `diffusion_process = diffusion_process`

### `target = target`

### `target_type = target_type`

### `__init__(diffusion_process, target_type)`

Initialize the diffusion loss function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `diffusion_process` | `DiffusionProcess` | The diffusion process to use, containing data about the forward evolution. | *required* |
| `target_type` | `VectorFieldType` | The type of target to learn via minimizing the loss function. Must be one of VectorFieldType.X0, VectorFieldType.EPS, or VectorFieldType.V. | *required* |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If target_type is VectorFieldType.SCORE, which is not directly supported. |

### `batchwise_loss_factory(N_noise_draws_per_sample)`

Create a batchwise loss function that averages the samplewise loss over multiple noise draws per sample.

This factory method returns a function that can be used during training to compute the loss for a batch of data. The returned function handles the process of:

1. Repeating each sample N_noise_draws_per_sample times to apply different noise realizations
1. Adding noise according to the diffusion process
1. Computing model predictions
1. Calculating and weighting the loss

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N_noise_draws_per_sample` | `int` | The number of different noise realizations to use for each data sample. Higher values can reduce variance but increase computation. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Callable[[VectorField, Tensor, Tensor, Tensor], Tensor]` | A function that computes the weighted average loss across a batch, with the signature (vector_field, data, timesteps, sample_weights) -> scalar_loss. |

### `forward(x_t, f_x_t, x_0, eps, t)`

Compute the loss for each sample in the batch.

This method calculates the mean squared error between the model's prediction (f_x_t) and the target value determined by the target_type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Tensor` | The noised data at time t, of shape (N, \*D) where N is the batch size and D represents the data dimensions. | *required* |
| `f_x_t` | `Tensor` | The model's prediction at time t, of shape (N, \*D). | *required* |
| `x_0` | `Tensor` | The original clean data, of shape (N, \*D). | *required* |
| `eps` | `Tensor` | The noise used to generate x_t, of shape (N, \*D). | *required* |
| `t` | `Tensor` | The time parameter, of shape (N,). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The per-sample loss values, of shape (N,) where N is the batch size. |

# Models

This module contains functionality related to models.

## `DiffusionModel`

Bases: `LightningModule`, `VectorField`

A PyTorch Lightning module for training and evaluating diffusion models.

This class implements a diffusion model that can be trained using various vector field types (score, x0, eps, v) and diffusion processes. It handles the training loop, loss computation, and evaluation metrics.
The model inherits from both LightningModule (for training) and VectorField (for sampling), making it compatible with both the Lightning training framework and the diffusion sampling algorithms.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `net` | `Module` | The neural network that predicts the vector field. |
| `vector_field_type` | `VectorFieldType` | The type of vector field the model predicts. |
| `diffusion_process` | `DiffusionProcess` | The diffusion process used for training. |
| `train_scheduler` | `Scheduler` | The scheduler for generating training time steps. |
| `optimizer` | `Optimizer` | The optimizer for training the model. |
| `lr_scheduler` | `LRScheduler` | The learning rate scheduler. |
| `batchwise_metrics` | `ModuleDict` | Metrics computed on each batch during validation. |
| `batchfree_metrics` | `ModuleDict` | Metrics computed at the end of validation epoch. |
| `t_loss_weights` | `Callable` | Function that weights loss at different time steps. |
| `t_loss_probs` | `Callable` | Function that determines sampling probability of time steps. |
| `N_noise_draws_per_sample` | `int` | Number of noise samples per data point. |
| `samplewise_loss` | `SamplewiseDiffusionLoss` | Loss function for each sample. |
| `batchwise_loss` | `Callable` | Factory-generated function that computes loss for a batch. |
| `train_ts` | `Tensor` | Precomputed time steps for training. |
| `train_ts_loss_weights` | `Tensor` | Precomputed weights for each time step. |
| `train_ts_loss_probs` | `Tensor` | Precomputed sampling probabilities for each time step. |
| `LOG_ON_STEP_TRAIN_LOSS` | `bool` | Whether to log training loss on each step. Default is True. |
| `LOG_ON_EPOCH_TRAIN_LOSS` | `bool` | Whether to log training loss on each epoch. Default is True. |
| `LOG_ON_PROGRESS_BAR_TRAIN_LOSS` | `bool` | Whether to display training loss on the progress bar. Default is True. |
| `LOG_ON_STEP_BATCHWISE_METRICS` | `bool` | Whether to log batchwise metrics on each step. Default is False. |
| `LOG_ON_EPOCH_BATCHWISE_METRICS` | `bool` | Whether to log batchwise metrics on each epoch. Default is True. |
| `LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS` | `bool` | Whether to display batchwise metrics on the progress bar. Default is False. |
| `LOG_ON_STEP_BATCHFREE_METRICS` | `bool` | Whether to log batchfree metrics on each step. Default is False. |
| `LOG_ON_EPOCH_BATCHFREE_METRICS` | `bool` | Whether to log batchfree metrics on each epoch. Default is True. |
| `LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS` | `bool` | Whether to display batchfree metrics on the progress bar. Default is False. |

### `LOG_ON_EPOCH_BATCHFREE_METRICS = True`

### `LOG_ON_EPOCH_BATCHWISE_METRICS = True`

### `LOG_ON_EPOCH_TRAIN_LOSS = True`

### `LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS = False`

### `LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS = False`

### `LOG_ON_PROGRESS_BAR_TRAIN_LOSS = True`

### `LOG_ON_STEP_BATCHFREE_METRICS = False`

### `LOG_ON_STEP_BATCHWISE_METRICS = False`

### `LOG_ON_STEP_TRAIN_LOSS = True`

### `N_noise_draws_per_sample = N_noise_draws_per_sample`

### `batchfree_metrics = nn.ModuleDict(batchfree_metrics)`

### `batchwise_loss = self.samplewise_loss.batchwise_loss_factory(N_noise_draws_per_sample=N_noise_draws_per_sample)`

### `batchwise_metrics = nn.ModuleDict(batchwise_metrics)`

### `diffusion_process = diffusion_process`

### `lr_scheduler = lr_scheduler`

### `net = net`

### `optimizer = optimizer`

### `samplewise_loss = SamplewiseDiffusionLoss(diffusion_process, vector_field_type)`

### `t_loss_probs = t_loss_probs`

### `t_loss_weights = t_loss_weights`

### `train_scheduler = train_scheduler`

### `vector_field_type = vector_field_type`

### `__init__(net, diffusion_process, train_scheduler, vector_field_type, optimizer, lr_scheduler, batchwise_metrics, batchfree_metrics, train_ts_hparams, t_loss_weights, t_loss_probs, N_noise_draws_per_sample)`

Initialize the diffusion model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `net` | `Module` | Neural network that predicts the vector field. | *required* |
| `diffusion_process` | `DiffusionProcess` | The diffusion process used for training. | *required* |
| `train_scheduler` | `Scheduler` | Scheduler for generating training time steps. | *required* |
| `vector_field_type` | `VectorFieldType` | Type of vector field the model predicts. | *required* |
| `optimizer` | `Optimizer` | Optimizer for training the model. | *required* |
| `lr_scheduler` | `LRScheduler` | Learning rate scheduler. | *required* |
| `batchwise_metrics` | `Dict[str, Module]` | Metrics computed on each batch during validation. Each metric takes in (x, metadata, model) and returns a dictionary of metric (name, value) pairs. | *required* |
| `batchfree_metrics` | `Dict[str, Module]` | Metrics computed at the end of validation epoch. Each metric takes in (model) and returns a dictionary of metric (name, value) pairs. | *required* |
| `train_ts_hparams` | `Dict[str, Any]` | Parameters for the training time step scheduler. | *required* |
| `t_loss_weights` | `Callable[[Tensor], Tensor]` | Function that weights loss at different time steps. | *required* |
| `t_loss_probs` | `Callable[[Tensor], Tensor]` | Function that determines sampling probability of time steps. | *required* |
| `N_noise_draws_per_sample` | `int` | Number of noise draws per data point. | *required* |

### `aggregate_loss(x)`

Compute the loss for a batch of data with randomly sampled time steps.

This method:

1. Samples time steps according to the training distribution
2. Computes the loss at those time steps

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input data of shape (batch_size, \*data_dims). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Scalar loss value. |

### `configure_optimizers()`

Configure optimizers and learning rate schedulers for training.

This method is called by PyTorch Lightning to set up the optimization process.

Returns:

| Type | Description |
| --- | --- |
| `OptimizerLRScheduler` | Dictionary containing the optimizer and learning rate scheduler. |

### `forward(x, t)`

Forward pass of the model.

Passes the input through the neural network to predict the vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor of shape (batch_size, \*data_dims). | *required* |
| `t` | `Tensor` | Time tensor of shape (batch_size,). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Predicted vector field of shape (batch_size, \*data_dims). |

### `loss(x, t, sample_weights)`

Compute the loss for a batch of data at specified time steps.

Uses the batchwise_loss function created from the SamplewiseDiffusionLoss factory to compute the loss for the batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input data of shape (batch_size, \*data_dims). | *required* |
| `t` | `Tensor` | Time steps of shape (batch_size,). | *required* |
| `sample_weights` | `Tensor` | Weights for each sample of shape (batch_size,). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Scalar loss value. |

### `on_validation_epoch_end()`

Perform operations at the end of a validation epoch.

This method is called by PyTorch Lightning at the end of each validation epoch. It computes and logs any batch-free metrics that require the entire validation set.

### `precompute_train_schedule(train_ts_hparams)`

Precompute time steps and their associated weights for training.

This method generates the time steps used during training and computes the loss weights and sampling probabilities for each time step.
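As a rough illustration of what such a precomputation might look like, the sketch below builds a descending grid of time steps and evaluates the weighting and probability functions on it. The helper names `precompute_schedule` and `sample_training_times`, the normalisation of the probabilities, and the use of `torch.multinomial` to draw time steps are assumptions made for illustration, not DiffusionLab's actual implementation.

```python
import torch


def precompute_schedule(t_min, t_max, L, t_loss_weights, t_loss_probs):
    """Hypothetical helper: return (ts, weights, probs), each of shape (L,)."""
    # Uniformly spaced time steps in descending order (t_max down to t_min).
    ts = torch.linspace(t_max, t_min, L)
    weights = t_loss_weights(ts)
    probs = t_loss_probs(ts)
    # Assumption: normalise so the values form a probability distribution.
    probs = probs / probs.sum()
    return ts, weights, probs


def sample_training_times(ts, weights, probs, batch_size):
    """Hypothetical helper: draw one time step (and its weight) per sample."""
    idx = torch.multinomial(probs, batch_size, replacement=True)
    return ts[idx], weights[idx]


# Constant weights and probabilities, i.e. uniform sampling over the grid.
ts, w, p = precompute_schedule(0.001, 0.999, 50, torch.ones_like, torch.ones_like)
t_batch, w_batch = sample_training_times(ts, w, p, batch_size=8)
```

Precomputing the three tensors once keeps the per-batch work down to an index lookup.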
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_ts_hparams` | `Dict[str, float]` | Parameters for the training time step scheduler. Typically includes t_min, t_max, and the number of steps L. | *required* |

### `training_step(batch, batch_idx)`

Perform a single training step.

This method is called by PyTorch Lightning during training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Tensor` | Batch of data, typically a tuple (x, metadata). | *required* |
| `batch_idx` | `int` | Index of the current batch. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Loss value for the batch. |

### `validation_step(batch, batch_idx)`

Perform a single validation step.

This method is called by PyTorch Lightning during validation. It computes the loss and any batch-wise metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Tensor` | Batch of data, typically a tuple (x, metadata). | *required* |
| `batch_idx` | `int` | Index of the current batch. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Tensor]` | Dictionary of metric values. |

# Samplers

This module contains functionality related to samplers.

## `DDMSampler`

Bases: `Sampler`

Class for sampling from diffusion models using the DDPM/DDIM sampler.
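For orientation, a deterministic DDIM-style update can be written directly in terms of alpha and sigma: estimate the noise from the current state and the predicted clean data, then re-noise to the next time. The sketch below is a generic illustration under that assumption; the function name `ddim_step_from_x0` is hypothetical and the code is not taken from `DDMSampler`.

```python
import torch


def ddim_step_from_x0(x0_hat, x_t, alpha_t, sigma_t, alpha_s, sigma_s):
    """Move from time t to time s given a prediction x0_hat of the clean data."""
    # Noise implied by x_t = alpha_t * x0 + sigma_t * eps.
    eps_hat = (x_t - alpha_t * x0_hat) / sigma_t
    # Re-assemble the state at the new time s with the same x0 and eps estimates.
    return alpha_s * x0_hat + sigma_s * eps_hat


# Check on the flow matching dynamics alpha(t) = 1 - t, sigma(t) = t.
torch.manual_seed(0)
x0 = torch.randn(4, 3)
eps = torch.randn(4, 3)
alpha = lambda u: 1.0 - u
sigma = lambda u: u
t, s = 0.8, 0.6
x_t = alpha(t) * x0 + sigma(t) * eps
x_s = ddim_step_from_x0(x0, x_t, alpha(t), sigma(t), alpha(s), sigma(s))
```

With a perfect prediction `x0_hat = x0`, the step reproduces the forward-process sample at time `s` up to floating-point error, which makes it easy to sanity-check.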
### `sample_step_deterministic_eps(eps, x, zs, idx, ts)`

### `sample_step_deterministic_score(score, x, zs, idx, ts)`

### `sample_step_deterministic_v(v, x, zs, idx, ts)`

### `sample_step_deterministic_x0(x0, x, zs, idx, ts)`

### `sample_step_stochastic_eps(eps, x, zs, idx, ts)`

### `sample_step_stochastic_score(score, x, zs, idx, ts)`

### `sample_step_stochastic_v(v, x, zs, idx, ts)`

### `sample_step_stochastic_x0(x0, x, zs, idx, ts)`

## `EulerMaruyamaSampler`

Bases: `Sampler`

### `sample_step_deterministic_eps(eps, x, zs, idx, ts)`

Perform one step of deterministic sampling using the eps vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `eps` | `VectorField` | The eps vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_deterministic_score(score, x, zs, idx, ts)`

Perform one step of deterministic sampling using the score vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `score` | `VectorField` | The score vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_deterministic_v(v, x, zs, idx, ts)`

Perform one step of deterministic sampling using the v vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `v` | `VectorField` | The velocity vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_deterministic_x0(x0, x, zs, idx, ts)`

Perform one step of deterministic sampling using the x0 vector field.
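One plausible way to take a deterministic Euler step from an x0 prediction is to convert it to a velocity and move along it. The sketch assumes the velocity is the time derivative of the forward process, v = alpha'(t) x_0 + sigma'(t) eps; the function name `euler_step_from_x0` and this exact formulation are illustrative assumptions, not the library's code.

```python
import torch


def euler_step_from_x0(x0_hat, x_t, t, s, alpha, sigma, alpha_prime, sigma_prime):
    """One Euler step of dx/dt = v from time t to time s, given an x0 prediction."""
    # Noise implied by x_t = alpha(t) * x0 + sigma(t) * eps.
    eps_hat = (x_t - alpha(t) * x0_hat) / sigma(t)
    # Assumed velocity: time derivative of the forward process.
    v = alpha_prime(t) * x0_hat + sigma_prime(t) * eps_hat
    # s < t during sampling, so (s - t) is negative.
    return x_t + (s - t) * v


# Flow matching dynamics: alpha(t) = 1 - t, sigma(t) = t.
torch.manual_seed(0)
x0 = torch.randn(4, 3)
eps = torch.randn(4, 3)
alpha = lambda u: 1.0 - u
sigma = lambda u: u
alpha_prime = lambda u: -1.0
sigma_prime = lambda u: 1.0
t, s = 0.8, 0.6
x_t = alpha(t) * x0 + sigma(t) * eps
x_s = euler_step_from_x0(x0, x_t, t, s, alpha, sigma, alpha_prime, sigma_prime)
```

Because the flow matching path is linear in t, the Euler step is exact here when `x0_hat` equals the true `x0`; for curved schedules it is only a first-order approximation.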
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x0` | `VectorField` | The x0 vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_stochastic_eps(eps, x, zs, idx, ts)`

Perform one step of stochastic sampling using the eps vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `eps` | `VectorField` | The eps vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors, of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_stochastic_score(score, x, zs, idx, ts)`

Perform a stochastic sampling step using the score vector field.

This implements the stochastic reverse SDE for score-based models using the Euler-Maruyama discretization method.
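A generic Euler-Maruyama step for the reverse SDE can be sketched as follows. It assumes the forward process x_t = alpha(t) x_0 + sigma(t) eps corresponds to a linear SDE with drift coefficient f(t) = alpha'(t) / alpha(t) and squared diffusion coefficient g(t)^2 = 2 sigma(t) sigma'(t) - 2 f(t) sigma(t)^2. The function name `em_step_score` and its argument layout are hypothetical and do not mirror the library's signature.

```python
import torch


def em_step_score(score_value, x, z, t, s, alpha, sigma, alpha_prime, sigma_prime):
    """One reverse-time Euler-Maruyama step from t to s < t.

    score_value is the score evaluated at (x, t); z is standard Gaussian noise.
    """
    f = alpha_prime(t) / alpha(t)                                # forward drift coefficient
    g2 = 2.0 * sigma(t) * sigma_prime(t) - 2.0 * f * sigma(t) ** 2  # squared diffusion
    dt = s - t                                                   # negative step
    drift = f * x - g2 * score_value                             # reverse-SDE drift
    return x + drift * dt + (g2 * (-dt)) ** 0.5 * z


# Flow matching dynamics: alpha(t) = 1 - t, sigma(t) = t.
alpha = lambda u: 1.0 - u
sigma = lambda u: u
alpha_prime = lambda u: -1.0
sigma_prime = lambda u: 1.0

torch.manual_seed(0)
x = torch.randn(4, 3)
z = torch.randn(4, 3)
placeholder_score = -x  # stand-in for a model output, used only to run the step
x_next = em_step_score(placeholder_score, x, z, 0.5, 0.4,
                       alpha, sigma, alpha_prime, sigma_prime)
```

Setting `z` to zero removes the noise term only; the drift still uses the full g(t)^2 factor, which is why a deterministic sampler is typically implemented as a separate step rather than by zeroing the noise.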
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `score` | `VectorField` | The score vector field model | *required* |
| `x` | `Tensor` | The current state tensor, of shape (N, \*D) where N is the batch size and D represents the data dimensions | *required* |
| `zs` | `Tensor` | The noise tensors for stochastic sampling, of shape (L-1, N, \*D) where L is the number of time steps | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) where L is the number of time steps | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state tensor, of shape (N, \*D) |

### `sample_step_stochastic_v(v, x, zs, idx, ts)`

Perform one step of stochastic sampling using the v vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `v` | `VectorField` | The velocity vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors, of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_stochastic_x0(x0, x, zs, idx, ts)`

Perform one step of stochastic sampling using the x0 vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x0` | `VectorField` | The x0 vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors, of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

## `Sampler`

Class for sampling from diffusion models using various vector field types.

A Sampler combines a diffusion process and a scheduler to generate samples from a trained diffusion model. It handles both the forward process (adding noise) and the reverse process (denoising/sampling).

The sampler supports different vector field types (SCORE, X0, EPS, V) and can perform both stochastic and deterministic sampling.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `diffusion_process` | `DiffusionProcess` | The diffusion process defining the forward and reverse dynamics |
| `is_stochastic` | `bool` | Whether the reverse process is stochastic or deterministic |

### `diffusion_process = diffusion_process`

### `is_stochastic = is_stochastic`

### `__init__(diffusion_process, is_stochastic)`

Initialize a sampler with a diffusion process and sampling strategy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `diffusion_process` | `DiffusionProcess` | The diffusion process to use for sampling | *required* |
| `is_stochastic` | `bool` | Whether the reverse process should be stochastic | *required* |

### `get_sample_step_function(vector_field_type)`

Get the appropriate sampling step function based on the vector field type.

This method selects the correct sampling function based on the vector field type and whether sampling is stochastic or deterministic.
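The selection described above amounts to a lookup keyed on two things: the vector field type and the stochastic flag. The sketch below shows one way such a dispatch could work using the documented `sample_step_{deterministic,stochastic}_{type}` naming. The `ToySampler` class and the local `VectorFieldType` enum are stand-ins written for this example, and the name-based `getattr` lookup is an assumption about how the selection might be done.

```python
from enum import Enum, auto


class VectorFieldType(Enum):
    """Stand-in enum mirroring the documented members."""
    SCORE = auto()
    X0 = auto()
    EPS = auto()
    V = auto()


class ToySampler:
    """Hypothetical sampler that only demonstrates the dispatch logic."""

    def __init__(self, is_stochastic):
        self.is_stochastic = is_stochastic

    def get_sample_step_function(self, vector_field_type):
        mode = "stochastic" if self.is_stochastic else "deterministic"
        # e.g. "sample_step_deterministic_x0"
        name = f"sample_step_{mode}_{vector_field_type.name.lower()}"
        return getattr(self, name)

    # Only two of the eight step functions are stubbed out here.
    def sample_step_deterministic_x0(self, x0, x, zs, idx, ts):
        return x

    def sample_step_stochastic_x0(self, x0, x, zs, idx, ts):
        return x + zs[idx]


step = ToySampler(is_stochastic=False).get_sample_step_function(VectorFieldType.X0)
print(step.__name__)  # sample_step_deterministic_x0
```

Resolving the step function once, before the sampling loop starts, keeps the type check out of the per-step code path.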
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vector_field_type` | `VectorFieldType` | The type of vector field being used (SCORE, X0, EPS, or V) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Callable[[VectorField, Tensor, Tensor, int, Tensor], Tensor]` | A function that performs one step of the sampling process, with signature (vector_field, x, zs, idx, ts) -> next_x, where vector_field is the model; x is the current state tensor of shape (N, \*D), with N the batch size and D the data dimensions; zs is the noise tensors of shape (L-1, N, \*D), with L the number of time steps; idx is the current step index; ts is the time steps tensor of shape (L,); and next_x is the next state tensor of shape (N, \*D) |

### `sample(vector_field, x_init, zs, ts)`

Sample from the model using the reverse diffusion process.

This method generates a sample by iteratively applying the appropriate sampling step function based on the vector field type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vector_field` | `VectorField` | The vector field model to use for sampling | *required* |
| `x_init` | `Tensor` | The initial noisy tensor to start sampling from, of shape (N, \*D) where N is the batch size and D represents the data dimensions | *required* |
| `zs` | `Tensor` | The noise tensors for stochastic sampling, of shape (L-1, N, \*D) where L is the number of time steps | *required* |
| `ts` | `Tensor` | The time schedule for sampling, of shape (L,) where L is the number of time steps | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The generated sample, of shape (N, \*D) |

### `sample_step_deterministic_eps(eps, x, zs, idx, ts)`

Perform one step of deterministic sampling using the eps vector field.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `eps` | `VectorField` | The eps vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_deterministic_score(score, x, zs, idx, ts)`

Perform one step of deterministic sampling using the score vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `score` | `VectorField` | The score vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_deterministic_v(v, x, zs, idx, ts)`

Perform one step of deterministic sampling using the v vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `v` | `VectorField` | The velocity vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_deterministic_x0(x0, x, zs, idx, ts)`

Perform one step of deterministic sampling using the x0 vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x0` | `VectorField` | The x0 vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors (unused in deterministic sampling), of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_stochastic_eps(eps, x, zs, idx, ts)`

Perform one step of stochastic sampling using the eps vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `eps` | `VectorField` | The eps vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors, of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_stochastic_score(score, x, zs, idx, ts)`

Perform a stochastic sampling step using the score vector field.

This method implements one step of the stochastic reverse process using the score function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `score` | `VectorField` | The score vector field model | *required* |
| `x` | `Tensor` | The current state tensor, of shape (N, \*D) where N is the batch size and D represents the data dimensions | *required* |
| `zs` | `Tensor` | The noise tensors for stochastic sampling, of shape (L-1, N, \*D) where L is the number of time steps | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) where L is the number of time steps | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state tensor, of shape (N, \*D) |

### `sample_step_stochastic_v(v, x, zs, idx, ts)`

Perform one step of stochastic sampling using the v vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `v` | `VectorField` | The velocity vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors, of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_step_stochastic_x0(x0, x, zs, idx, ts)`

Perform one step of stochastic sampling using the x0 vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x0` | `VectorField` | The x0 vector field model | *required* |
| `x` | `Tensor` | The current state, of shape (N, \*D) | *required* |
| `zs` | `Tensor` | The noise tensors, of shape (L-1, N, \*D) | *required* |
| `idx` | `int` | The current step index | *required* |
| `ts` | `Tensor` | The time steps tensor, of shape (L,) | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The next state after one sampling step, of shape (N, \*D) |

### `sample_trajectory(vector_field, x_init, zs, ts)`

Sample a trajectory from the model using the reverse diffusion process.

This method is similar to sample() but returns the entire trajectory of intermediate samples rather than just the final sample.
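The relationship between a single step function and the full trajectory can be sketched as a plain loop that records every intermediate state. The stand-alone `sample_trajectory` function, the `toy_euler_step` helper, and the toy vector field below are illustrative assumptions; only the argument shapes follow the documentation.

```python
import torch


def sample_trajectory(step_fn, vector_field, x_init, zs, ts):
    """Apply step_fn L-1 times and stack all states into shape (L, N, *D)."""
    xs = [x_init]
    x = x_init
    for idx in range(ts.shape[0] - 1):
        x = step_fn(vector_field, x, zs, idx, ts)
        xs.append(x)
    return torch.stack(xs, dim=0)


def toy_euler_step(v, x, zs, idx, ts):
    """Hypothetical deterministic step: Euler update along a velocity field."""
    t = ts[idx].expand(x.shape[0])  # time tensor of shape (N,)
    return x + (ts[idx + 1] - ts[idx]) * v(x, t)


L, N, D = 5, 2, (3,)
ts = torch.linspace(1.0, 0.0, L)   # descending time schedule, shape (L,)
x_init = torch.randn(N, *D)        # initial noisy state, shape (N, *D)
zs = torch.randn(L - 1, N, *D)     # noise, unused by the deterministic step
trajectory = sample_trajectory(toy_euler_step, lambda x, t: x, x_init, zs, ts)
final_sample = trajectory[-1]      # the last state of the trajectory
```

Keeping the whole trajectory costs L times the memory of a single sample, so it is mainly useful for visualisation and debugging.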
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vector_field` | `VectorField` | The vector field model to use for sampling | *required* |
| `x_init` | `Tensor` | The initial noisy tensor to start sampling from, of shape (N, \*D) where N is the batch size and D represents the data dimensions | *required* |
| `zs` | `Tensor` | The noise tensors for stochastic sampling, of shape (L-1, N, \*D) where L is the number of time steps | *required* |
| `ts` | `Tensor` | The time schedule for sampling, of shape (L,) where L is the number of time steps | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The generated trajectory, of shape (L, N, \*D) where L is the number of time steps |

# Schedulers

This module contains functionality related to schedulers.

## `Scheduler`

Base class for time step schedulers used in diffusion, denoising, and sampling.

A scheduler determines the sequence of time steps used during the sampling process. Different scheduling strategies can affect the quality and efficiency of the generative process.

The scheduler generates a sequence of time values, typically in the range [0, 1], which are used to control the noise level at each step of the sampling process.

### `__init__(**schedule_hparams)`

Initialize the scheduler.

This base implementation does not store any variables. Subclasses may override this method to initialize specific parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**schedule_hparams` | `Any` | Keyword arguments containing scheduler parameters. Not used in the base class but available for subclasses. | `{}` |

### `get_ts(**ts_hparams)`

Generate the sequence of time steps.

This is an abstract method that must be implemented by subclasses.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**ts_hparams` | `Any` | Keyword arguments containing parameters for generating time steps. The specific parameters depend on the scheduler implementation. Typically includes t_min (float), the minimum time value; t_max (float), the maximum time value; and L (int), the number of time steps to generate. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | A tensor of shape (L,) containing the sequence of time steps in descending order, where L is the number of time steps. |

Raises:

| Type | Description |
| --- | --- |
| `NotImplementedError` | If the subclass does not implement this method. |

## `UniformScheduler`

Bases: `Scheduler`

A scheduler that generates uniformly spaced time steps.

This scheduler creates a sequence of time steps that are uniformly distributed between a minimum and maximum time value. The time steps are returned in descending order (from t_max to t_min).

This is the simplest scheduling strategy and is often used as a baseline.

### `__init__(**schedule_hparams)`

Initialize the uniform scheduler.

This implementation does not store any variables, following the base class design.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**schedule_hparams` | `Any` | Keyword arguments containing scheduler parameters. Not used but passed to the parent class. | `{}` |

### `get_ts(**ts_hparams)`

Generate uniformly spaced time steps.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**ts_hparams` | `Any` | Keyword arguments containing t_min (float), the minimum time value, typically close to 0; t_max (float), the maximum time value, typically close to 1; and L (int), the number of time steps to generate. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | A tensor of shape (L,) containing uniformly spaced time steps in descending order (from t_max to t_min), where L is the number of time steps specified in ts_hparams. |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If t_min or t_max is outside the range [0, 1], or if t_min > t_max. |

# Utils

This module contains functionality related to utils.

## `logdet_pd(A)`

Computes the log-determinant of a positive-definite matrix A, broadcasting over A.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `Tensor` | A positive-definite matrix of shape (..., N, N) where ... represents any number of batch dimensions. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `logdet_A` | `Tensor` | The log-determinant of A of shape (...) with the same batch dimensions as A. |

## `pad_shape_back(x, target_shape)`

Pads the back of a tensor with singleton dimensions until it can broadcast with target_shape.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | A tensor of any shape, say (P, Q, R, S). | *required* |
| `target_shape` | `Size` | A shape to which x can broadcast, say (P, Q, R, S, T, U, V). | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `x_padded` | `Tensor` | The tensor x reshaped to be broadcastable with target_shape, say (P, Q, R, S, 1, 1, 1). The returned tensor has shape (\*x.shape, 1, ..., 1) with enough trailing 1s to match the dimensionality of target_shape. |

Note: This function does not use any additional memory, returning a different view of the same underlying data.

## `pad_shape_front(x, target_shape)`

Pads the front of a tensor with singleton dimensions until it can broadcast with target_shape.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | A tensor of any shape, say (P, Q, R, S). | *required* |
| `target_shape` | `Size` | A shape to which x can broadcast, say (M, N, O, P, Q, R, S). | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `x_padded` | `Tensor` | The tensor x reshaped to be broadcastable with target_shape, say (1, 1, 1, P, Q, R, S). The returned tensor has shape (1, ..., 1, \*x.shape) with enough leading 1s to match the dimensionality of target_shape. |

Note: This function does not use any additional memory, returning a different view of the same underlying data.

## `scalar_derivative(f)`

Computes the scalar derivative of a function f: R -> R. Returns a function f_prime: R -> R that computes the derivative of f at a given point, and is broadcastable with the same broadcast rules as f.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f` | `Callable[[Tensor], Tensor]` | A function whose input is a scalar (0-dimensional PyTorch tensor) and whose output is a scalar, that can be broadcasted to a tensor of any shape. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `f_prime` | `Callable[[Tensor], Tensor]` | A function that computes the derivative of f at a given point, and is broadcastable with the same broadcast rules as f. For input of shape (N,), output will be of shape (N,). |

## `sqrt_psd(A)`

Computes the matrix square root of a positive-semidefinite matrix A, broadcasting over A.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `Tensor` | A positive-semidefinite matrix of shape (..., N, N) where ... represents any number of batch dimensions. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `sqrt_A` | `Tensor` | The matrix square root of A of shape (..., N, N) with the same shape as A. |

# Vector Fields

This module contains functionality related to vector fields.

## `VectorField`

A wrapper around a function (x, t) -> f(x, t) which provides some extra data, namely the type of vector field the function f represents.
This class encapsulates a vector field function and its type, allowing for consistent handling of different vector field representations in diffusion models.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `f` | `Callable` | A function that takes tensors x of shape (N, \*D) and t of shape (N,) and returns a tensor of shape (N, \*D). |
| `vector_field_type` | `VectorFieldType` | The type of vector field the function represents. |

### `f = f`

### `vector_field_type = vector_field_type`

### `__call__(x, t)`

Call the wrapped vector field function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor of shape (N, \*D), where N is the batch size and D represents the data dimensions. | *required* |
| `t` | `Tensor` | Time parameter tensor of shape (N,). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Output of the vector field function, of shape (N, \*D). |

### `__init__(f, vector_field_type)`

Initialize a vector field wrapper.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f` | `Callable[[Tensor, Tensor], Tensor]` | A function that takes tensors x of shape (N, \*D) and t of shape (N,) and returns a tensor of shape (N, \*D). | *required* |
| `vector_field_type` | `VectorFieldType` | The type of vector field the function represents (SCORE, X0, EPS, or V). | *required* |

## `VectorFieldType`

Bases: `Enum`

### `EPS = enum.auto()`

### `SCORE = enum.auto()`

### `V = enum.auto()`

### `X0 = enum.auto()`

## `convert_vector_field_type(x, fx, alpha, sigma, alpha_prime, sigma_prime, in_type, out_type)`

Converts the output of a vector field from one type to another.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | A tensor of shape (N, \*D), where N is the batch size and D is the shape of the data (e.g., (C, H, W) for images, (D,) for vectors, or (L, D) for token sequences of length L). | *required* |
| `fx` | `Tensor` | The output of the vector field f, of shape (N, \*D). | *required* |
| `alpha` | `Tensor` | A tensor of shape (N,) representing the scale parameter. | *required* |
| `sigma` | `Tensor` | A tensor of shape (N,) representing the noise level parameter. | *required* |
| `alpha_prime` | `Tensor` | A tensor of shape (N,) representing the derivative of the scale parameter. | *required* |
| `sigma_prime` | `Tensor` | A tensor of shape (N,) representing the derivative of the noise level parameter. | *required* |
| `in_type` | `VectorFieldType` | The type of the input vector field (SCORE, X0, EPS, or V). | *required* |
| `out_type` | `VectorFieldType` | The type of the output vector field. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The converted output of the vector field, of shape (N, \*D). |
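
The conversions follow from the forward process x_t = alpha(t) * x_0 + sigma(t) * eps. The sketch below is one way to express them, assuming the common conventions score = -eps / sigma and v = alpha'(t) * x_0 + sigma'(t) * eps; the library's own conventions and implementation may differ. `FieldType`, `to_x0_eps`, `from_x0_eps`, and `convert` are hypothetical names introduced for this example, and scalars stand in for tensors so the snippet runs without PyTorch.

```python
from enum import Enum, auto


class FieldType(Enum):
    """Hypothetical stand-in for the vector field type enum."""
    SCORE = auto()
    X0 = auto()
    EPS = auto()
    V = auto()


def to_x0_eps(x, fx, alpha, sigma, alpha_prime, sigma_prime, in_type):
    """Recover (x0, eps) from a field value fx at the noised point x."""
    if in_type is FieldType.X0:
        x0 = fx
        eps = (x - alpha * x0) / sigma
    elif in_type is FieldType.EPS:
        eps = fx
        x0 = (x - sigma * eps) / alpha
    elif in_type is FieldType.SCORE:
        eps = -sigma * fx  # assumed convention: score = -eps / sigma
        x0 = (x - sigma * eps) / alpha
    elif in_type is FieldType.V:
        # Solve the 2x2 system x = alpha*x0 + sigma*eps, v = alpha'*x0 + sigma'*eps.
        det = alpha * sigma_prime - sigma * alpha_prime
        x0 = (sigma_prime * x - sigma * fx) / det
        eps = (alpha * fx - alpha_prime * x) / det
    else:
        raise ValueError(f"unknown field type: {in_type}")
    return x0, eps


def from_x0_eps(x0, eps, sigma, alpha_prime, sigma_prime, out_type):
    """Express the pair (x0, eps) as the requested field type."""
    if out_type is FieldType.X0:
        return x0
    if out_type is FieldType.EPS:
        return eps
    if out_type is FieldType.SCORE:
        return -eps / sigma
    if out_type is FieldType.V:
        return alpha_prime * x0 + sigma_prime * eps
    raise ValueError(f"unknown field type: {out_type}")


def convert(x, fx, alpha, sigma, alpha_prime, sigma_prime, in_type, out_type):
    """Convert fx from in_type to out_type by passing through (x0, eps)."""
    x0, eps = to_x0_eps(x, fx, alpha, sigma, alpha_prime, sigma_prime, in_type)
    return from_x0_eps(x0, eps, sigma, alpha_prime, sigma_prime, out_type)


# Flow matching coefficients at t = 0.25: alpha = 1 - t, sigma = t.
alpha, sigma, alpha_prime, sigma_prime = 0.75, 0.25, -1.0, 1.0
x0_true, eps_true = 2.0, -1.0
x_t = alpha * x0_true + sigma * eps_true          # noised point
v_true = alpha_prime * x0_true + sigma_prime * eps_true

x0_from_v = convert(x_t, v_true, alpha, sigma, alpha_prime, sigma_prime,
                    FieldType.V, FieldType.X0)
score_from_eps = convert(x_t, eps_true, alpha, sigma, alpha_prime, sigma_prime,
                         FieldType.EPS, FieldType.SCORE)
```

The expressions use only arithmetic operators, so they would apply elementwise to tensors as well, provided the (N,)-shaped coefficients are first given trailing singleton dimensions so they broadcast against (N, \*D). The divisions are undefined where alpha, sigma, or the determinant alpha * sigma' - sigma * alpha' is zero, so endpoint times may need special handling.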