
jelli.utils.jax_helpers

batched_outer_ravel(arr)

Compute the outer product of each 1D array in a batch with itself and return each result raveled into a 1D array.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| arr | ndarray | A JAX array of shape (..., N), where ... represents any number (possibly zero) of leading batch dimensions. | required |

Returns:

| Type | Description |
|------|-------------|
| ndarray | A JAX array of shape (..., N*N), where each length-N*N vector along the last axis is the raveled outer product of the corresponding input vector. |
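The shape contract can be illustrated without jelli itself. The sketch below is a minimal sketch, assuming only that JAX is installed. It uses a hypothetical reference function, `outer_ravel_reference`, which is not part of jelli and is built on `jnp.einsum` with an ellipsis instead of `vmap`. It shows how inputs with zero, one, or two batch dimensions map to outputs whose last axis has length N*N.

```python
import jax.numpy as jnp


def outer_ravel_reference(arr):
    """Hypothetical einsum-based reference (not part of jelli).

    Maps an array of shape (..., N) to shape (..., N*N).
    """
    n = arr.shape[-1]
    # '...i,...j->...ij' forms the outer product over the last axis
    # and broadcasts over every leading batch axis.
    outer = jnp.einsum("...i,...j->...ij", arr, arr)
    # Merge the last two axes (row-major, like ravel) into one of length N*N.
    return outer.reshape(arr.shape[:-1] + (n * n,))


# No batch dimensions: (4,) -> (16,)
print(outer_ravel_reference(jnp.ones(4)).shape)  # (16,)
# One batch dimension: (5, 4) -> (5, 16)
print(outer_ravel_reference(jnp.ones((5, 4))).shape)  # (5, 16)
# Two batch dimensions: (2, 3, 4) -> (2, 3, 16)
print(outer_ravel_reference(jnp.ones((2, 3, 4))).shape)  # (2, 3, 16)
```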

Source code in jelli/utils/jax_helpers.py
```python
from jax import vmap


# outer_ravel is defined in the same module (documented below).
def batched_outer_ravel(arr):
    '''
    Compute the outer product for each 1D array in a batch and return them as raveled 1D arrays.

    Parameters
    ----------
    arr : jnp.ndarray
        A JAX array of shape (..., N), where ... represents any number of batch dimensions.

    Returns
    -------
    jnp.ndarray
        A JAX array of shape (..., N*N), where each slice along the batch dimensions
        corresponds to the raveled outer product of the respective input array.
    '''
    # Dynamically detect batch dimensions
    batch_shape = arr.shape[:-1]  # All dimensions except the last one

    # Flatten all batch dimensions into a single axis so one `vmap` covers them
    arr = arr.reshape((-1, arr.shape[-1]))

    # Vectorize outer_ravel over the flattened batch axis
    result = vmap(outer_ravel)(arr)

    # Restore the original batch shape; the last axis has length N*N
    return result.reshape(batch_shape + (-1,))
```
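Flattening the batch axes means a single `vmap` handles inputs of any rank. Nesting `jax.vmap` would instead need one call per batch axis, so the number of axes must be known in advance. The sketch below, assuming JAX is installed, shows that the two approaches agree. It uses a hypothetical generic helper, `vmap_over_leading_dims`, and a local function, `row_outer`; neither is part of jelli. The helper restores the trailing axis from `out.shape[-1:]` rather than `-1`, which also keeps the final reshape well-defined when a batch axis has size zero.

```python
import jax
import jax.numpy as jnp


def vmap_over_leading_dims(fn, arr):
    """Hypothetical helper (not part of jelli): apply `fn`, which maps
    shape (N,) to shape (M,), over any number of leading batch axes."""
    batch_shape = arr.shape[:-1]
    flat = arr.reshape((-1, arr.shape[-1]))  # (prod(batch_shape), N)
    out = jax.vmap(fn)(flat)                 # (prod(batch_shape), M)
    return out.reshape(batch_shape + out.shape[-1:])


def row_outer(v):
    # Outer product of a single vector with itself, raveled.
    return jnp.outer(v, v).ravel()


x = jnp.arange(24.0).reshape(2, 3, 4)

# One vmap, regardless of how many batch axes x has.
flat_result = vmap_over_leading_dims(row_outer, x)
# Nested vmap: one jax.vmap per batch axis (two here).
nested_result = jax.vmap(jax.vmap(row_outer))(x)

print(flat_result.shape)  # (2, 3, 16)
print(bool(jnp.allclose(flat_result, nested_result)))  # True
```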

outer_ravel(arr)

Compute the outer product of a 1D array with itself and return it raveled (in row-major order) into a 1D array.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| arr | ndarray | A 1D JAX array of shape (N,). | required |

Returns:

| Type | Description |
|------|-------------|
| ndarray | A 1D JAX array of shape (N*N,) containing the raveled outer product of the input array. |
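The following worked example shows the output layout, assuming JAX is installed. It calls `jnp.outer` and `.ravel()` directly, as the source below does. Because `ravel` flattens in row-major (C) order, entry `k` of the result equals `arr[k // N] * arr[k % N]`.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
n = a.shape[0]

# Outer product: [[1, 2, 3], [2, 4, 6], [3, 6, 9]], then flattened row by row
v = jnp.outer(a, a).ravel()
print(v)  # [1. 2. 3. 2. 4. 6. 3. 6. 9.]

# Row-major index mapping: k = i * n + j  <->  (i, j) = divmod(k, n)
for k in range(n * n):
    i, j = divmod(k, n)
    assert v[k] == a[i] * a[j]
```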

Source code in jelli/utils/jax_helpers.py
```python
import jax.numpy as jnp


def outer_ravel(arr):
    '''
    Compute the outer product of a 1D array and return it as a raveled 1D array.

    Parameters
    ----------
    arr : jnp.ndarray
        A 1D JAX array.

    Returns
    -------
    jnp.ndarray
        A 1D JAX array representing the raveled outer product of the input array.
    '''
    # jnp.outer gives an (N, N) matrix; ravel flattens it in row-major order
    return jnp.outer(arr, arr).ravel()
```
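`outer_ravel` is a single `jnp` expression, so it should compose with JAX transformations such as `jax.jit` and `jax.jacfwd`. The sketch below assumes JAX is installed and uses a local stand-in with the same one-line body as the source above. It checks the forward-mode Jacobian against the analytic result $\partial (a_i a_j) / \partial a_k = \delta_{ik} a_j + a_i \delta_{jk}$, written with Kronecker products so that row `i*N + j` matches the ravel order.

```python
import jax
import jax.numpy as jnp


# Local stand-in with the same body as outer_ravel above.
def outer_ravel(arr):
    return jnp.outer(arr, arr).ravel()


a = jnp.array([1.0, 2.0, 3.0])
n = a.shape[0]

# jit-compiling gives the same values as eager evaluation.
assert jnp.allclose(jax.jit(outer_ravel)(a), outer_ravel(a))

# Forward-mode Jacobian: shape (N*N, N); row i*N + j, column k.
jac = jax.jacfwd(outer_ravel)(a)
print(jac.shape)  # (9, 3)

# kron(I, a[:, None])[i*N + j, k] = delta_ik * a_j
# kron(a[:, None], I)[i*N + j, k] = a_i * delta_jk
eye = jnp.eye(n)
col = a[:, None]
jac_analytic = jnp.kron(eye, col) + jnp.kron(col, eye)
assert jnp.allclose(jac, jac_analytic)
```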