# Plaintext API Reference for diffusionlab

## Module: `diffusionlab`

*No module docstring.*

## Module: `diffusionlab.distributions`

*No module docstring.*

## Module: `diffusionlab.distributions.base`

*No module docstring.*

### Class: `Distribution`

```
Base class for all distributions.

This class should be subclassed by other distributions when you want to use ground truth scores, denoisers, noise predictors, or velocity estimators.

Each distribution implementation provides functions to sample from it and compute various vector fields related to a diffusion process, such as denoising (``x0``), noise prediction (``eps``), velocity estimation (``v``), and score estimation (``score``).

Attributes:
    dist_params (``Dict[str, Array]``): Dictionary containing distribution parameters as JAX arrays. Shapes depend on the specific distribution.
    dist_hparams (``Dict[str, Any]``): Dictionary containing distribution hyperparameters (non-array values).
```

#### Method: `eps(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Predict the noise component ``ε`` corresponding to the noisy state ``x_t`` at time ``t``, given the ``diffusion_process``.

Args:
    x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step.
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[*data_dims]``: The predicted noise ``ε``.
```

#### Method: `get_vector_field(self, vector_field_type: diffusionlab.vector_fields.VectorFieldType) -> Callable[[jax.Array, jax.Array, diffusionlab.dynamics.DiffusionProcess], jax.Array]`

```
Get the vector field function of a given type associated with this distribution.

Args:
    vector_field_type (``VectorFieldType``): The type of vector field to retrieve (e.g., ``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``).
Returns:
    ``Callable[[Array[*data_dims], Array[], DiffusionProcess], Array[*data_dims]]``: The requested vector field function. It takes the current state ``x_t`` (``Array[*data_dims]``), time ``t`` (``Array[]``), and the ``diffusion_process`` as input and returns the corresponding vector field value (``Array[*data_dims]``).
```

#### Method: `sample(self, key: jax.Array, num_samples: int) -> Tuple[jax.Array, Any]`

```
Sample from the distribution.

Args:
    key (``Array``): The JAX PRNG key to use for sampling.
    num_samples (``int``): The number of samples to draw.

Returns:
    ``Tuple[Array[num_samples, *data_dims], Any]``: A tuple containing the samples and any additional information.
```

#### Method: `score(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Compute the score function (``∇_x log p_t(x)``) of the distribution at time ``t``, given the noisy state ``x_t`` and the ``diffusion_process``.

Args:
    x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step.
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[*data_dims]``: The score of the distribution at ``(x_t, t)``.
```

#### Method: `v(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Compute the velocity field ``v(x_t, t)`` corresponding to the noisy state ``x_t`` at time ``t``, given the ``diffusion_process``.

Args:
    x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step.
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[*data_dims]``: The computed velocity field ``v``.
```

#### Method: `x0(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Predict the initial state ``x0`` (denoised sample) from the noisy state ``x_t`` at time ``t``, given the ``diffusion_process``.
Args:
    x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step.
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[*data_dims]``: The predicted initial state ``x0``.
```

## Module: `diffusionlab.distributions.empirical`

*No module docstring.*

### Class: `EmpiricalDistribution`

```
An empirical distribution, i.e., the uniform distribution over a dataset.

The probability measure is defined as:

``μ(A) = (1/N) * sum_{i=1}^{N} delta(x_i in A)``

where ``x_i`` is the i-th data point in the dataset, ``N`` is the number of data points, and ``delta(x_i in A)`` is 1 if ``x_i`` lies in ``A`` and 0 otherwise.

This class provides methods for sampling from the empirical distribution and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.

Attributes:
    dist_params (``Dict[str, Array]``): Dictionary containing distribution parameters (currently unused).
    dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters. It may contain the following keys:
        - ``labeled_data`` (``Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]``): An iterable whose elements are tuples ``(data_batch, label_batch)``. The label batch can be ``None`` if the data is unlabelled.
```

#### Method: `eps(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the noise field ``eps(x_t, t)`` for an empirical distribution w.r.t. a given diffusion process.

Args:
    x_t (``Array[*data_dims]``): The input tensor.
    t (``Array[]``): The time tensor.
    diffusion_process (``DiffusionProcess``): The diffusion process.

Returns:
    ``Array[*data_dims]``: The noise field at ``(x_t, t)``.
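# Example: a minimal sketch, assuming a flattened dataset `data` of shape
# (N, data_dim) and scalar schedule values alpha = α(t), sigma = σ(t).
# The helpers `empirical_x0` and `eps_from_x0` are hypothetical illustrations,
# not part of the diffusionlab API. The denoiser is a softmax-weighted average
# of the data points (weights proportional to the likelihood of x_t under
# N(alpha * x_i, sigma^2 I)), and eps is recovered by inverting
# x_t = alpha * x_0 + sigma * eps.
import jax
import jax.numpy as jnp


def empirical_x0(x_t, data, alpha, sigma):
    # Squared distance from x_t to each scaled data point alpha * x_i.
    sq_dists = jnp.sum((x_t[None, :] - alpha * data) ** 2, axis=-1)
    # All components share the same Gaussian normalizer, so a softmax suffices.
    weights = jax.nn.softmax(-sq_dists / (2.0 * sigma**2))
    return weights @ data


def eps_from_x0(x_t, x0_hat, alpha, sigma):
    # Solve x_t = alpha * x0_hat + sigma * eps for eps.
    return (x_t - alpha * x0_hat) / sigma


data = jnp.array([[0.0, 0.0], [4.0, 4.0]])
x_t = jnp.array([3.9, 4.1])
x0_hat = empirical_x0(x_t, data, alpha=1.0, sigma=0.5)    # close to [4.0, 4.0]
eps_hat = eps_from_x0(x_t, x0_hat, alpha=1.0, sigma=0.5)  # close to [-0.2, 0.2]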
```

#### Method: `sample(self, key: jax.Array, num_samples: int) -> Union[Tuple[jax.Array, jax.Array], Tuple[jax.Array, NoneType]]`

```
Sample from the empirical distribution using reservoir sampling.

Assumes all batches in ``labeled_data`` are consistent: either all have labels (``Array``) or none have labels (``None``).

Args:
    key (``Array``): The JAX PRNG key to use for sampling.
    num_samples (``int``): The number of samples to draw.

Returns:
    ``Tuple[Array[num_samples, *data_dims], Array[num_samples, *label_dims]] | Tuple[Array[num_samples, *data_dims], None]``: A tuple ``(samples, labels)`` containing the samples and corresponding labels (stacked into an ``Array``), or ``(samples, None)`` if the data is unlabelled.
```

#### Method: `score(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the score function (``∇_x log p_t(x)``) of the empirical distribution at time ``t``, given the noisy state ``x_t`` and the diffusion process.

Args:
    x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time tensor.
    diffusion_process (``DiffusionProcess``): The diffusion process.

Returns:
    ``Array[*data_dims]``: The score of the empirical distribution at ``(x_t, t)``.
```

#### Method: `v(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the velocity field ``v(x_t, t)`` for an empirical distribution w.r.t. a given diffusion process.

Args:
    x_t (``Array[*data_dims]``): The input tensor.
    t (``Array[]``): The time tensor.
    diffusion_process (``DiffusionProcess``): The diffusion process.

Returns:
    ``Array[*data_dims]``: The velocity field at ``(x_t, t)``.
```

#### Method: `x0(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the denoiser ``E[x_0 | x_t]`` for an empirical distribution w.r.t. a given diffusion process.
This method computes the denoiser as a weighted average of the dataset samples, where the weights are determined by the likelihood of ``x_t`` given each sample.

Args:
    x_t (``Array[*data_dims]``): The input tensor.
    t (``Array[]``): The time tensor.
    diffusion_process (``DiffusionProcess``): The diffusion process.

Returns:
    ``Array[*data_dims]``: The prediction of ``x_0``.
```

## Module: `diffusionlab.distributions.gmm`

*No module docstring.*

## Module: `diffusionlab.distributions.gmm.gmm`

*No module docstring.*

### Class: `GMM`

```
Implements a Gaussian Mixture Model (GMM) distribution.

The probability measure is given by:

``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], covs[i])``

This class provides methods for sampling from the GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.

Attributes:
    dist_params (``Dict[str, Array]``): Dictionary containing the core GMM parameters.
        - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components.
        - ``covs`` (``Array[num_components, data_dim, data_dim]``): The covariance matrices of the GMM components.
        - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components.
    dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused).
```

#### Method: `eps(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the noise prediction ``ε`` for the GMM distribution.

This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.
Returns:
    ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``.
```

#### Method: `sample(self, key: jax.Array, num_samples: int) -> Tuple[jax.Array, jax.Array]`

```
Draws samples from the GMM distribution.

Args:
    key (``Array``): JAX PRNG key for random sampling.
    num_samples (``int``): The total number of samples to generate.

Returns:
    ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn.
```

#### Method: `score(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the score vector field ``∇_x log p_t(x_t)`` for the GMM distribution.

This is calculated with respect to the perturbed distribution ``p_t`` induced by the ``diffusion_process`` at time ``t``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
```

#### Method: `v(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the velocity vector field ``v`` for the GMM distribution.

This is the conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
```

#### Method: `x0(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the GMM distribution.
This represents the expected original sample ``x_0`` given the noisy observation ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``.
```

### Function: `gmm_x0(x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess, means: jax.Array, covs: jax.Array, priors: jax.Array) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for a GMM.

This implements the closed-form solution for the conditional expectation ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the GMM distribution defined by ``means``, ``covs``, and ``priors``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
    means (``Array[num_components, data_dim]``): Means of the GMM components.
    covs (``Array[num_components, data_dim, data_dim]``): Covariances of the GMM components.
    priors (``Array[num_components]``): Mixture weights of the GMM components.

Returns:
    ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
```

## Module: `diffusionlab.distributions.gmm.iso_gmm`

*No module docstring.*

### Class: `IsoGMM`

```
Implements an isotropic Gaussian Mixture Model (GMM) distribution.

The probability measure is given by:

``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variances[i] * I)``

This class provides methods for sampling from the GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.
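# Example: a minimal sketch of the closed-form denoiser for an isotropic GMM,
# assuming scalar schedule values alpha = α(t), sigma = σ(t) and the forward
# model x_t = alpha * x_0 + sigma * eps. `iso_gmm_x0_sketch` is a hypothetical
# name, not the library's iso_gmm_x0. Given component i, x_t is Gaussian with
# mean alpha * means[i] and covariance (alpha^2 * variances[i] + sigma^2) I, so
# the component posterior is a softmax of log-prior plus log-likelihood, and
# each component's conditional mean shrinks x_t linearly toward means[i].
import jax
import jax.numpy as jnp


def iso_gmm_x0_sketch(x_t, alpha, sigma, means, variances, priors):
    data_dim = x_t.shape[-1]
    total_var = alpha**2 * variances + sigma**2            # (num_components,)
    resid = x_t[None, :] - alpha * means                   # (num_components, data_dim)
    log_lik = (-0.5 * jnp.sum(resid**2, axis=-1) / total_var
               - 0.5 * data_dim * jnp.log(total_var))
    weights = jax.nn.softmax(jnp.log(priors) + log_lik)   # p(component | x_t)
    gain = alpha * variances / total_var                   # (num_components,)
    comp_x0 = means + gain[:, None] * resid                # E[x_0 | x_t, component]
    return weights @ comp_x0


means = jnp.array([[-2.0, 0.0], [2.0, 0.0]])
variances = jnp.array([0.5, 0.5])
priors = jnp.array([0.5, 0.5])
# By symmetry both components are equally likely here, so x0_hat is [0, 1/3].
x0_hat = iso_gmm_x0_sketch(jnp.array([0.0, 1.0]), 1.0, 1.0, means, variances, priors)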
Attributes:
    dist_params (``Dict[str, Array]``): Dictionary containing the core GMM parameters.
        - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components.
        - ``variances`` (``Array[num_components]``): The variances of the GMM components.
        - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components.
    dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused).
```

#### Method: `eps(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the noise prediction ``ε`` for the isotropic GMM distribution.

This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``.
```

#### Method: `sample(self, key: jax.Array, num_samples: int) -> Tuple[jax.Array, jax.Array]`

```
Draws samples from the isotropic GMM distribution.

Args:
    key (``Array``): JAX PRNG key for random sampling.
    num_samples (``int``): The total number of samples to generate.

Returns:
    ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn.
```

#### Method: `score(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the score vector field ``∇_x log p_t(x_t)`` for the isotropic GMM distribution.

This is calculated with respect to the perturbed distribution ``p_t`` induced by the ``diffusion_process`` at time ``t``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
```

#### Method: `v(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the velocity vector field ``v`` for the isotropic GMM distribution.

This is the conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
```

#### Method: `x0(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the isotropic GMM distribution.

This represents the expected original sample ``x_0`` given the noisy observation ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``.
```

### Function: `iso_gmm_x0(x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess, means: jax.Array, variances: jax.Array, priors: jax.Array) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for an isotropic GMM.

This implements the closed-form solution for the conditional expectation ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the isotropic GMM distribution defined by ``means``, ``variances``, and ``priors``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
    means (``Array[num_components, data_dim]``): Means of the GMM components.
    variances (``Array[num_components]``): Variances of the GMM components; component ``i`` has covariance ``variances[i] * I``.
    priors (``Array[num_components]``): Mixture weights of the GMM components.

Returns:
    ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
```

## Module: `diffusionlab.distributions.gmm.iso_hom_gmm`

*No module docstring.*

### Class: `IsoHomGMM`

```
Implements an isotropic homoscedastic Gaussian Mixture Model (GMM) distribution.

The probability measure is given by:

``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variance * I)``

This class provides methods for sampling from the isotropic homoscedastic GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.

Attributes:
    dist_params (``Dict[str, Array]``): Dictionary containing the core GMM parameters.
        - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components.
        - ``variance`` (``Array[]``): The variance shared by all GMM components.
        - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components.
    dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused).
```

#### Method: `eps(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the noise prediction ``ε`` for the isotropic homoscedastic GMM distribution.

This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``.
```

#### Method: `sample(self, key: jax.Array, num_samples: int) -> Tuple[jax.Array, jax.Array]`

```
Draws samples from the isotropic homoscedastic GMM distribution.

Args:
    key (``Array``): JAX PRNG key for random sampling.
    num_samples (``int``): The total number of samples to generate.

Returns:
    ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn.
```

#### Method: `score(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the score vector field ``∇_x log p_t(x_t)`` for the isotropic homoscedastic GMM distribution.

This is calculated with respect to the perturbed distribution ``p_t`` induced by the ``diffusion_process`` at time ``t``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
```

#### Method: `v(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the velocity vector field ``v`` for the isotropic homoscedastic GMM distribution.

This is the conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
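# Example: a minimal sketch, assuming the identity
#     v = α'(t) * E[x_0 | x_t] + σ'(t) * E[ε | x_t],
# which follows from differentiating x_t = α(t) * x_0 + σ(t) * ε in t and
# conditioning on x_t. `v_from_x0` is a hypothetical helper, not part of
# diffusionlab; the flow-matching schedule α(t) = 1 - t, σ(t) = t is used
# only for illustration.
import jax
import jax.numpy as jnp


def fm_alpha(t):
    return 1.0 - t


def fm_sigma(t):
    return t


def v_from_x0(x_t, t, x0_hat, alpha, sigma):
    # Schedule derivatives α'(t) and σ'(t) via autodiff of the scalar functions.
    alpha_prime = jax.grad(alpha)(t)
    sigma_prime = jax.grad(sigma)(t)
    # Noise estimate implied by the denoiser.
    eps_hat = (x_t - alpha(t) * x0_hat) / sigma(t)
    return alpha_prime * x0_hat + sigma_prime * eps_hat


x_0 = jnp.array([1.0, -2.0])
eps = jnp.array([0.5, 0.5])
t = jnp.array(0.25)
x_t = fm_alpha(t) * x_0 + fm_sigma(t) * eps
# With the exact x_0 plugged in, flow matching gives v = eps - x_0 = [-0.5, 2.5].
v = v_from_x0(x_t, t, x_0, fm_alpha, fm_sigma)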
```

#### Method: `x0(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the isotropic homoscedastic GMM distribution.

This represents the expected original sample ``x_0`` given the noisy observation ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``.
```

### Function: `iso_hom_gmm_x0(x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess, means: jax.Array, variance: jax.Array, priors: jax.Array) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for an isotropic homoscedastic GMM.

This implements the closed-form solution for the conditional expectation ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the GMM distribution defined by ``means``, ``variance``, and ``priors``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
    means (``Array[num_components, data_dim]``): Means of the GMM components.
    variance (``Array[]``): Variance shared by all GMM components; each component has covariance ``variance * I``.
    priors (``Array[num_components]``): Mixture weights of the GMM components.

Returns:
    ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
```

## Module: `diffusionlab.distributions.gmm.low_rank_gmm`

*No module docstring.*

### Class: `LowRankGMM`

```
Implements a low-rank Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:

``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], cov_factors[i] @ cov_factors[i].T)``

This class provides methods for sampling from the GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.

Attributes:
    dist_params (``Dict[str, Array]``): Dictionary containing the core low-rank GMM parameters.
        - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components.
        - ``cov_factors`` (``Array[num_components, data_dim, rank]``): The low-rank covariance matrix factors of the GMM components.
        - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components.
    dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused).
```

#### Method: `eps(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the noise prediction ``ε`` for the low-rank GMM distribution.

This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``.
```

#### Method: `sample(self, key: jax.Array, num_samples: int) -> Tuple[jax.Array, jax.Array]`

```
Draws samples from the low-rank GMM distribution.

Args:
    key (``Array``): JAX PRNG key for random sampling.
    num_samples (``int``): The total number of samples to generate.
Returns:
    ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn.
```

#### Method: `score(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the score vector field ``∇_x log p_t(x_t)`` for the low-rank GMM distribution.

This is calculated with respect to the perturbed distribution ``p_t`` induced by the ``diffusion_process`` at time ``t``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
```

#### Method: `v(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the velocity vector field ``v`` for the low-rank GMM distribution.

This is the conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.

Returns:
    ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
```

#### Method: `x0(self, x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the low-rank GMM distribution.

This represents the expected original sample ``x_0`` given the noisy observation ``x_t`` at time ``t`` under the ``diffusion_process``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): The diffusion process definition.
Returns:
    ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``.
```

### Function: `low_rank_gmm_x0(x_t: jax.Array, t: jax.Array, diffusion_process: diffusionlab.dynamics.DiffusionProcess, means: jax.Array, cov_factors: jax.Array, priors: jax.Array) -> jax.Array`

```
Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for a low-rank GMM.

This implements the closed-form solution for the conditional expectation ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the low-rank GMM distribution defined by ``means``, ``cov_factors``, and ``priors``.

Args:
    x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
    t (``Array[]``): The time step (scalar).
    diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
    means (``Array[num_components, data_dim]``): Means of the GMM components.
    cov_factors (``Array[num_components, data_dim, rank]``): Low-rank covariance factors of the GMM components; component ``i`` has covariance ``cov_factors[i] @ cov_factors[i].T``.
    priors (``Array[num_components]``): Mixture weights of the GMM components.

Returns:
    ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
```

## Module: `diffusionlab.distributions.gmm.utils`

*No module docstring.*

### Function: `create_gmm_vector_field_fns(x0_fn: Callable[[jax.Array, jax.Array, diffusionlab.dynamics.DiffusionProcess, jax.Array, jax.Array, jax.Array], jax.Array]) -> Tuple[Callable[[jax.Array, jax.Array, diffusionlab.dynamics.DiffusionProcess, jax.Array, jax.Array, jax.Array], jax.Array], Callable[[jax.Array, jax.Array, diffusionlab.dynamics.DiffusionProcess, jax.Array, jax.Array, jax.Array], jax.Array], Callable[[jax.Array, jax.Array, diffusionlab.dynamics.DiffusionProcess, jax.Array, jax.Array, jax.Array], jax.Array]]`

```
Factory to create ``eps``, ``score``, and ``v`` functions from a given ``x0`` function.

Args:
    x0_fn: The ``x0`` function to wrap (e.g., ``gmm_x0``, ``iso_gmm_x0``). It must accept ``(x_t, t, diffusion_process, means, specific_cov, priors)``, where ``specific_cov`` is the model-specific covariance parameter (for example ``covs``, ``variances``, ``variance``, or ``cov_factors``).
Returns:
    ``Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]``: A tuple containing the generated ``(eps_fn, score_fn, v_fn)``. These functions will have the same signature as ``x0_fn``, accepting ``(x_t, t, diffusion_process, means, specific_cov, priors)``.
```

## Module: `diffusionlab.dynamics`

*No module docstring.*

### Class: `DiffusionProcess`

```
Base class for implementing various diffusion processes.

A diffusion process defines how data evolves over time when noise is added according to specific dynamics operating on scalar time inputs. This class provides a framework to implement diffusion processes based on a schedule defined by ``α(t)`` and ``σ(t)``.

The diffusion is parameterized by two scalar functions of scalar time ``t``:

- ``α(t)``: Controls how much of the original signal is preserved at time ``t``.
- ``σ(t)``: Controls how much noise is added at time ``t``.

The forward process for a single data point ``x_0`` is defined as:

``x_t = α(t) * x_0 + σ(t) * ε``

where:

- ``x_0`` is the original data (``Array[*data_dims]``)
- ``x_t`` is the noised data at time ``t`` (``Array[*data_dims]``)
- ``ε`` is random noise sampled from a standard Gaussian distribution (``Array[*data_dims]``)
- ``t`` is the scalar diffusion time parameter (``Array[]``)

Attributes:
    alpha (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar signal coefficient ``α(t)``.
    sigma (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar noise coefficient ``σ(t)``.
    alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` w.r.t. scalar time ``t``.
    sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` w.r.t. scalar time ``t``.
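# Example: a minimal sketch of the forward process x_t = α(t) * x_0 + σ(t) * ε
# in plain JAX, using the variance-preserving schedule α(t) = sqrt(1 - t²),
# σ(t) = t. The names `vp_alpha`, `vp_sigma`, and `forward_sketch` are
# hypothetical stand-ins, not the diffusionlab API. Because α(t) and σ(t) are
# scalars, they broadcast over any *data_dims.
import jax
import jax.numpy as jnp


def vp_alpha(t):
    return jnp.sqrt(1.0 - t**2)


def vp_sigma(t):
    return t


def forward_sketch(x_0, t, eps, alpha=vp_alpha, sigma=vp_sigma):
    # Scalar coefficients times tensors of shape *data_dims.
    return alpha(t) * x_0 + sigma(t) * eps


key_x, key_eps = jax.random.split(jax.random.PRNGKey(0))
x_0 = jax.random.normal(key_x, (3, 4))
eps = jax.random.normal(key_eps, x_0.shape)         # noise has the same shape as x_0
x_clean = forward_sketch(x_0, jnp.array(0.0), eps)  # t = 0: α = 1, σ = 0, so x_t = x_0
x_half = forward_sketch(x_0, jnp.array(0.6), eps)   # t = 0.6: α = 0.8, σ = 0.6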
```

#### Method: `forward(self, x: jax.Array, t: jax.Array, eps: jax.Array) -> jax.Array`

```
Applies the forward diffusion process to a data tensor ``x`` at time ``t`` using noise ``ε``.

Computes ``x_t = α(t) * x + σ(t) * ε``.

Args:
    x (``Array[*data_dims]``): The input data tensor ``x_0``.
    t (``Array[]``): The scalar time parameter ``t``.
    eps (``Array[*data_dims]``): The Gaussian noise tensor ``ε``, matching the shape of ``x``.

Returns:
    ``Array[*data_dims]``: The noised data tensor ``x_t`` at time ``t``.
```

### Class: `FlowMatchingProcess`

```
Implements a diffusion process based on Flow Matching principles.

This process defines dynamics that linearly interpolate between the data distribution at ``t=0`` and a noise distribution (standard Gaussian) at ``t=1``.

Uses the following scalar dynamics:

- ``α(t) = 1 - t``
- ``σ(t) = t``

Forward process: ``x_t = (1 - t) * x_0 + t * ε``.

Attributes:
    alpha (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar signal coefficient ``α(t)``. Set to ``1 - t``.
    sigma (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar noise coefficient ``σ(t)``. Set to ``t``.
    alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` w.r.t. scalar time ``t``. Set to ``-1``.
    sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` w.r.t. scalar time ``t``. Set to ``1``.
```

### Class: `VarianceExplodingProcess`

```
Implements a Variance Exploding (VE) diffusion process.

In this process, the signal component is constant (``α(t) = 1``), while the noise component increases over time according to the provided ``σ(t)`` function. The variance of the noised data ``x_t`` explodes as ``t`` increases.

Forward process: ``x_t = x_0 + σ(t) * ε``.

This process uses:

- ``α(t) = 1``
- ``σ(t)``: provided by the user

Attributes:
    alpha (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar signal coefficient ``α(t)``. Set to ``1``.
sigma (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar noise coefficient ``σ(t)``. Provided by the user. alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` w.r.t. scalar time ``t``. Set to ``0``. sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` w.r.t. scalar time ``t``. ``` ### Class: `VariancePreservingProcess` ``` Implements a Variance Preserving (VP) diffusion process, often used in DDPMs. Because ``α(t)² + σ(t)² = 1`` for every ``t``, the variance of the noised data ``x_t`` stays equal to 1 throughout the diffusion (assuming ``x_0`` and ``ε`` are independent and each has unit variance). Uses the following scalar dynamics: - ``α(t) = sqrt(1 - t²)`` - ``σ(t) = t`` Forward process: ``x_t = sqrt(1 - t²) * x_0 + t * ε``. Attributes: alpha (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar signal coefficient ``α(t)``. Set to ``sqrt(1 - t²)``. sigma (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar noise coefficient ``σ(t)``. Set to ``t``. alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` w.r.t. scalar time ``t``. Set to ``-t / sqrt(1 - t²)``. sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` w.r.t. scalar time ``t``. Set to ``1``. ``` ## Module: `diffusionlab.losses` *No module docstring.* ### Class: `DiffusionLoss` ``` Loss function for training diffusion models. This dataclass implements diffusion-model losses that differ in their regression target. The loss is the mean squared error between the model's prediction and that target, which is determined by the chosen vector field type.
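The NumPy sketch below is a rough illustration of the Monte Carlo loss described above. It averages a squared error over several noise draws for one `(x_0, t)` pair. Several details are assumptions, not confirmed library behavior:

- The error is averaged over data dimensions rather than summed.
- The `V` target is the time derivative of the forward path, `α'(t) x_0 + σ'(t) ε`.
- `target` and `mc_loss` are hypothetical names.

The schedule is the VP one documented above.

```python
import numpy as np


# VP schedule from this reference: alpha(t) = sqrt(1 - t^2), sigma(t) = t
def alpha(t):
    return np.sqrt(1.0 - t**2)


def sigma(t):
    return t


def alpha_prime(t):
    return -t / np.sqrt(1.0 - t**2)


def sigma_prime(t):
    return 1.0


def target(kind, x0, eps, t):
    """Regression target per vector field type (assumed definitions)."""
    if kind == "x0":
        return x0
    if kind == "eps":
        return eps
    if kind == "v":
        # Time derivative of the forward path alpha(t) * x0 + sigma(t) * eps
        return alpha_prime(t) * x0 + sigma_prime(t) * eps
    # Mirrors the documented behavior for SCORE targets
    raise ValueError(f"unsupported target type: {kind}")


def mc_loss(rng, vector_field, kind, x0, t, num_noise_draws=8):
    """Squared error averaged over data dims, then over num_noise_draws noise samples."""
    eps = rng.standard_normal((num_noise_draws,) + x0.shape)
    x_t = alpha(t) * x0 + sigma(t) * eps                 # x0 broadcasts over the draw axis
    preds = np.stack([vector_field(x, t) for x in x_t])  # a JAX version would use jax.vmap
    targets = np.stack([target(kind, x0, e, t) for e in eps])
    data_axes = tuple(range(1, preds.ndim))
    return np.mean((preds - targets) ** 2, axis=data_axes).mean()
```

Averaging over more noise draws per `(x_0, t)` pair reduces the variance of the estimate, roughly in proportion to `1 / num_noise_draws`. The cost is one extra model evaluation per draw.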
The loss supports different target types: - ``VectorFieldType.X0``: Learn to predict the original clean data ``x_0``. - ``VectorFieldType.EPS``: Learn to predict the noise component ``ε``. - ``VectorFieldType.V``: Learn to predict the velocity field ``v``. - ``VectorFieldType.SCORE``: Not directly supported (raises ``ValueError``). Attributes: diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics. vector_field_type (``VectorFieldType``): The type of vector field the model learns to estimate by minimizing the loss. num_noise_draws_per_sample (``int``): The number of noise draws per data sample used to compute the batchwise loss. target (``Callable[[Array, Array, Array, Array, Array], Array]``): Function that computes the target for the specified ``vector_field_type``. Signature: ``(x_t: Array[*data_dims], f_x_t: Array[*data_dims], x_0: Array[*data_dims], eps: Array[*data_dims], t: Array[]) -> Array[*data_dims]`` ``` #### Method: `loss(self, key: jax.Array, vector_field: Callable[[jax.Array, jax.Array], jax.Array], x_0: jax.Array, t: jax.Array) -> jax.Array` ``` Compute the average loss over multiple noise draws for a single data point and time. This method estimates the expected loss at a given time ``t`` for a clean data sample ``x_0``. It does this by drawing ``num_noise_draws_per_sample`` noise vectors (``eps``), generating the corresponding noisy samples ``x_t`` using the ``diffusion_process``, predicting the target quantity ``f_x_t`` using the provided ``vector_field`` (vmapped internally), and then calculating the ``prediction_loss`` for each noise sample. The final loss is the average over these samples. Args: key (``Array``): The PRNG key for noise generation. vector_field (``Callable[[Array, Array], Array]``): The vector field function that takes a single noisy data sample ``x_t`` and its corresponding time ``t``, and returns the model's prediction ``f_x_t``.
This function will be vmapped internally over the batch dimension created by ``num_noise_draws_per_sample``. Signature: ``(x_t: Array[*data_dims], t: Array[]) -> Array[*data_dims]``. x_0 (``Array[*data_dims]``): The original clean data sample. t (``Array[]``): The scalar time parameter. Returns: ``Array[]``: The scalar loss value, averaged over ``num_noise_draws_per_sample`` noise instances. ``` #### Method: `prediction_loss(self, x_t: jax.Array, f_x_t: jax.Array, x_0: jax.Array, eps: jax.Array, t: jax.Array) -> jax.Array` ``` Compute the loss for a single prediction, given the inputs from which the target is formed. This method calculates the mean squared error between the model's prediction (``f_x_t``) and the target value determined by ``vector_field_type`` (computed by ``self.target``). Args: x_t (``Array[*data_dims]``): The noised data at time ``t``. f_x_t (``Array[*data_dims]``): The model's prediction at time ``t``. x_0 (``Array[*data_dims]``): The original clean data. eps (``Array[*data_dims]``): The noise used to generate ``x_t``. t (``Array[]``): The scalar time parameter. Returns: ``Array[]``: The scalar loss value for the given sample. ``` ## Module: `diffusionlab.samplers` *No module docstring.* ### Class: `DDMSampler` ``` Class for sampling from diffusion models using the Denoising Diffusion Probabilistic Models (DDPM) or Denoising Diffusion Implicit Models (DDIM) sampling strategy. This sampler first converts the output of ``vector_field``, whatever its type (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, or ``VectorFieldType.V``), into an equivalent ``x_0`` prediction using the ``convert_vector_field_type`` utility. It then applies the DDPM (if ``use_stochastic_sampler`` is ``True``) or DDIM (if ``use_stochastic_sampler`` is ``False``) update rule based on this ``x_0`` prediction. Attributes: diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics.
vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function predicting the vector field. Takes the current state ``x_t`` and time ``t`` as input. vector_field_type (``VectorFieldType``): The type of the vector field predicted by ``vector_field``. use_stochastic_sampler (``bool``): If ``True``, uses DDPM (stochastic); otherwise, uses DDIM (deterministic). sample_step (``Callable[[int, Array, Array, Array], Array]``): The DDPM or DDIM step function. ``` #### Method: `get_sample_step_function(self) -> Callable[[int, jax.Array, jax.Array, jax.Array], jax.Array]` ``` Get the DDPM or DDIM sampling step function, depending on ``use_stochastic_sampler``. Returns: ``Callable[[int, Array, Array, Array], Array]``: The DDPM (stochastic) or DDIM (deterministic) step function, which has signature: ``(idx: int, x: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]`` ``` ### Class: `EulerMaruyamaSampler` ``` Class for sampling from diffusion models using the first-order Euler-Maruyama integrator for the reverse process SDE/ODE. This sampler implements the step function based on the Euler-Maruyama discretization of the reverse SDE (if ``use_stochastic_sampler`` is ``True``) or of the corresponding probability flow ODE (if ``use_stochastic_sampler`` is ``False``; with no noise term, this reduces to the explicit Euler method). It supports all vector field types (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``). Attributes: diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics. vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function predicting the vector field. Takes the current state ``x_t`` and time ``t`` as input. vector_field_type (``VectorFieldType``): The type of the vector field predicted by ``vector_field``. use_stochastic_sampler (``bool``): If ``True``, discretizes the reverse SDE (stochastic); otherwise, discretizes the probability flow ODE (deterministic).
sample_step (``Callable[[int, Array, Array, Array], Array]``): The specific function used to perform one sampling step. Takes step index ``idx``, current state ``x_t``, noise array ``zs``, and time schedule ``ts`` as input. Set during initialization based on the sampler type and ``use_stochastic_sampler``. ``` #### Method: `get_sample_step_function(self) -> Callable[[int, jax.Array, jax.Array, jax.Array], jax.Array]` ``` Get the appropriate Euler-Maruyama sampling step function based on the vector field type and ``use_stochastic_sampler``. Returns: ``Callable[[int, Array, Array, Array], Array]``: The specific Euler-Maruyama step function to use. Signature: ``(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]`` ``` ### Class: `Sampler` ``` Base class for sampling from diffusion models using various vector field types. A Sampler combines a diffusion process and a vector field prediction function with a time schedule (the ``ts`` argument of ``sample``, e.g. one produced by a ``Scheduler``). It uses them to generate samples from a trained diffusion model via the reverse process (denoising/sampling). The sampler supports different vector field types (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``). It can perform both stochastic and deterministic sampling, depending on the subclass implementation and the ``use_stochastic_sampler`` flag. Attributes: diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics. vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function predicting the vector field. Takes the current state ``x_t`` and time ``t`` as input. vector_field_type (``VectorFieldType``): The type of the vector field predicted by ``vector_field``. use_stochastic_sampler (``bool``): Whether to use a stochastic or deterministic reverse process. sample_step (``Callable[[int, Array, Array, Array], Array]``): The specific function used to perform one sampling step.
Takes step index ``idx``, current state ``x_t``, noise array ``zs``, and time schedule ``ts`` as input. Set during initialization based on the sampler type and ``use_stochastic_sampler``. ``` #### Method: `get_sample_step_function(self) -> Callable[[int, jax.Array, jax.Array, jax.Array], jax.Array]` ``` Abstract method to get the appropriate sampling step function. Subclasses must implement this method to return the specific function used for performing one step of the reverse process, based on the sampler's implementation details (e.g., integrator type) and the ``use_stochastic_sampler`` flag. Returns: ``Callable[[int, Array, Array, Array], Array]``: The sampling step function, which has signature: ``(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]`` ``` #### Method: `sample(self, x_init: jax.Array, zs: jax.Array, ts: jax.Array) -> jax.Array` ``` Sample from the model using the reverse diffusion process. This method generates a final sample by applying the ``sample_step`` function once per step (``num_steps`` times), starting from an initial state ``x_init`` and using the provided noise ``zs`` and time schedule ``ts``. Args: x_init (``Array[*data_dims]``): The initial noisy tensor from which to start sampling (typically sampled from the prior distribution at ``ts[0]``). zs (``Array[num_steps, *data_dims]``): The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers. ts (``Array[num_steps+1]``): The time schedule for sampling, sorted in decreasing order from ``t_max`` to ``t_min``. Returns: ``Array[*data_dims]``: The generated sample at the final time ``ts[-1]``. ``` #### Method: `sample_trajectory(self, x_init: jax.Array, zs: jax.Array, ts: jax.Array) -> jax.Array` ``` Sample a trajectory from the model using the reverse diffusion process. This method generates the entire trajectory of intermediate samples by iteratively applying the ``sample_step`` function.
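Below is a minimal sketch of how `sample` and `sample_trajectory` can drive a step function with the `(idx, x, zs, ts)` signature. It assumes a flow-matching schedule and a deterministic DDIM-style update (eta = 0) built from an `x_0` prediction. `uniform_ts`, `make_ddim_step`, and the free functions `sample_trajectory`/`sample` are hypothetical stand-ins, not the library's methods. The time-range assertion in `uniform_ts` is an assumed constraint.

```python
import numpy as np


# Flow-matching schedule from this reference: alpha(t) = 1 - t, sigma(t) = t
def alpha(t):
    return 1.0 - t


def sigma(t):
    return t


def uniform_ts(t_min, t_max, num_steps):
    """num_steps + 1 evenly spaced times, descending from t_max to t_min."""
    assert 0.0 <= t_min < t_max <= 1.0 and num_steps >= 1  # assumed constraints
    return np.linspace(t_max, t_min, num_steps + 1)


def make_ddim_step(x0_fn):
    """Deterministic DDIM-style step (eta = 0) built from an x0 predictor; zs is unused."""
    def step(idx, x, zs, ts):
        t, s = ts[idx], ts[idx + 1]
        x0_hat = x0_fn(x, t)
        eps_hat = (x - alpha(t) * x0_hat) / sigma(t)   # implied noise, requires sigma(t) > 0
        return alpha(s) * x0_hat + sigma(s) * eps_hat  # re-noise to the next (smaller) time
    return step


def sample_trajectory(step, x_init, zs, ts):
    """Apply step for idx = 0, ..., num_steps - 1 and keep every intermediate state."""
    xs = [x_init]
    for idx in range(len(ts) - 1):  # a JAX version would typically use jax.lax.scan
        xs.append(step(idx, xs[-1], zs, ts))
    return np.stack(xs)  # shape (num_steps + 1, *data_dims)


def sample(step, x_init, zs, ts):
    """Final state only, at time ts[-1]."""
    return sample_trajectory(step, x_init, zs, ts)[-1]


# Toy check: data concentrated at a single point mu, so the exact denoiser is constant.
mu = np.array([2.0, -1.0])
ts = uniform_ts(0.0, 1.0, num_steps=10)
zs = np.zeros((10, 2))  # ignored by the deterministic step
x_init = np.random.default_rng(0).standard_normal(2)  # draw from the prior at ts[0] = 1
traj = sample_trajectory(make_ddim_step(lambda x, t: mu), x_init, zs, ts)
```

The toy data distribution is a single point `mu`, so the exact denoiser is the constant `mu`. The deterministic sampler therefore reaches `mu` exactly at `t = 0`. In JAX, the Python loop would typically become `jax.lax.scan`, so the whole trajectory can be jit-compiled.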
Args: x_init (``Array[*data_dims]``): The initial noisy tensor from which to start sampling (at time ``ts[0]``). zs (``Array[num_steps, *data_dims]``): The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers. ts (``Array[num_steps+1]``): The time schedule for sampling, sorted in decreasing order from ``t_max`` to ``t_min``. Returns: ``Array[num_steps+1, *data_dims]``: The complete generated trajectory including the initial state ``x_init``. ``` ## Module: `diffusionlab.schedulers` *No module docstring.* ### Class: `Scheduler` ``` Base class for time step schedulers used in diffusion, denoising, and sampling. Subclasses can define their own initialization and time step generation parameters via ``**kwargs``, which keeps the interface extensible. ``` #### Method: `get_ts(self, **ts_hparams: Any) -> jax.Array` ``` Generate the sequence of time steps. This is an abstract method that must be implemented by subclasses. Subclasses should define the specific keyword arguments they expect within ``**ts_hparams``. Args: **ts_hparams (``Dict[str, Any]``): Keyword arguments containing parameters for generating time steps. Returns: ``Array``: A tensor containing the sequence of time steps in descending order. Raises: NotImplementedError: If the subclass does not implement this method. KeyError: If a required parameter is missing from ``**ts_hparams`` (raised by subclasses). ``` ### Class: `UniformScheduler` ``` A scheduler that generates uniformly spaced time steps. Requires ``t_min``, ``t_max``, and ``num_steps`` to be passed to the ``get_ts`` method as keyword arguments. The generated schedule contains ``num_steps + 1`` points. ``` #### Method: `get_ts(self, **ts_hparams: Any) -> jax.Array` ``` Generate uniformly spaced time steps. Args: **ts_hparams (``Dict[str, Any]``): Keyword arguments, which must contain: - ``t_min`` (``float``): The minimum time value, typically close to 0.
- ``t_max`` (``float``): The maximum time value, typically close to 1. - ``num_steps`` (``int``): The number of diffusion steps. The function will generate ``num_steps + 1`` time points. Returns: ``Array[num_steps+1]``: A JAX array containing uniformly spaced time steps in descending order (from ``t_max`` to ``t_min``). Raises: KeyError: If ``t_min``, ``t_max``, or ``num_steps`` is not found in ``ts_hparams``. AssertionError: If ``t_min``/``t_max`` constraints are violated or ``num_steps < 1``. ``` ## Module: `diffusionlab.vector_fields` *No module docstring.* ### Class: `VectorFieldType` ``` Enum representing the type of a vector field. A vector field is a function that takes ``x_t`` (``Array[*data_dims]``) and ``t`` (``Array[]``) and returns a vector of the same shape as ``x_t`` (``Array[*data_dims]``). DiffusionLab supports the following vector field types: - ``VectorFieldType.SCORE``: The score ``∇_x log p_t(x)`` of the noised distribution at time ``t``. - ``VectorFieldType.X0``: The denoised state (an estimate of ``x_0``). - ``VectorFieldType.EPS``: The noise ``ε`` in ``x_t = α(t) * x_0 + σ(t) * ε``. - ``VectorFieldType.V``: The velocity field ``v``. ``` ### Function: `convert_vector_field_type(x: jax.Array, f_x: jax.Array, alpha: jax.Array, sigma: jax.Array, alpha_prime: jax.Array, sigma_prime: jax.Array, in_type: diffusionlab.vector_fields.VectorFieldType, out_type: diffusionlab.vector_fields.VectorFieldType) -> jax.Array` ``` Converts the output of a vector field from one type to another. Args: x (``Array[*data_dims]``): The input tensor (the noisy state ``x_t``). f_x (``Array[*data_dims]``): The output of the vector field ``f`` evaluated at ``x``. alpha (``Array[]``): The scalar signal coefficient ``α(t)``. sigma (``Array[]``): The scalar noise coefficient ``σ(t)``. alpha_prime (``Array[]``): The scalar derivative ``α'(t)``. sigma_prime (``Array[]``): The scalar derivative ``σ'(t)``. in_type (``VectorFieldType``): The type of the input vector field (e.g.
``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``). out_type (``VectorFieldType``): The type of the output vector field. Returns: ``Array[*data_dims]``: The vector field output converted to ``out_type``. ```
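The conversions performed by `convert_vector_field_type` follow from `x_t = α(t) x_0 + σ(t) ε`. The sketch below routes every conversion through an `x_0` estimate. It assumes the score is `-ε / σ(t)` (Tweedie's formula for Gaussian noising) and the velocity is `α'(t) x_0 + σ'(t) ε`. `Kind` is a hypothetical stand-in for `VectorFieldType`. The library may use different, possibly more numerically careful, formulas.

```python
from enum import Enum

import numpy as np


class Kind(Enum):  # hypothetical stand-in for VectorFieldType
    SCORE = "score"
    X0 = "x0"
    EPS = "eps"
    V = "v"


def _to_x0(x, f_x, a, s, a_p, s_p, kind):
    """Recover an x0 estimate from a vector field output of the given kind."""
    if kind is Kind.X0:
        return f_x
    if kind is Kind.EPS:
        return (x - s * f_x) / a                        # invert x = a*x0 + s*eps
    if kind is Kind.SCORE:
        return (x + s**2 * f_x) / a                     # score = (a*x0 - x) / s^2
    if kind is Kind.V:
        return (s * f_x - s_p * x) / (s * a_p - a * s_p)  # v = a_p*x0 + s_p*(x - a*x0)/s
    raise ValueError(kind)


def _from_x0(x, x0, a, s, a_p, s_p, kind):
    """Express an x0 estimate as a vector field output of the given kind."""
    if kind is Kind.X0:
        return x0
    eps = (x - a * x0) / s                              # implied noise
    if kind is Kind.EPS:
        return eps
    if kind is Kind.SCORE:
        return -eps / s
    if kind is Kind.V:
        return a_p * x0 + s_p * eps
    raise ValueError(kind)


def convert(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type, out_type):
    """Convert via an x0 estimate (needs alpha != 0 and sigma != 0)."""
    if in_type is out_type:
        return f_x
    x0 = _to_x0(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type)
    return _from_x0(x, x0, alpha, sigma, alpha_prime, sigma_prime, out_type)
```

Routing through `x_0` keeps the sketch to eight short formulas, but it divides by `α(t)`. Near `α(t) = 0` (for example, `t = 1` under flow matching), a direct formula between the input and output types, or keeping `t` away from that endpoint, would be needed.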