jelli.utils.jax_helpers
batched_outer_ravel(arr)
Compute the outer product of each 1D array in a batch with itself and return the results as raveled 1D arrays.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arr | ndarray | A JAX array of shape (..., N), where ... represents any number of batch dimensions. | required |
Returns:
Type | Description |
---|---|
ndarray | A JAX array of shape (..., N*N), where each slice along the batch dimensions is the raveled outer product of the corresponding length-N input vector. |
Source code in jelli/utils/jax_helpers.py (lines 19-44)
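The shape contract above can be illustrated with a minimal sketch. It assumes that the outer product is taken of each vector with itself and flattened in row-major order, so that element i*N + j of the output equals arr[..., i] * arr[..., j]. The name batched_outer_ravel_sketch is hypothetical; this is not necessarily how jelli implements batched_outer_ravel.

```python
import jax.numpy as jnp


def batched_outer_ravel_sketch(arr: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical re-implementation: raveled self outer product over the last axis.

    arr has shape (..., N); the result has shape (..., N*N).
    """
    n = arr.shape[-1]
    # Broadcast (..., N, 1) * (..., 1, N) -> (..., N, N), i.e. arr[..., i] * arr[..., j]
    outer = arr[..., :, None] * arr[..., None, :]
    # Flatten the last two axes in row-major order, so index i*N + j holds arr[..., i] * arr[..., j]
    return outer.reshape(arr.shape[:-1] + (n * n,))


# Usage: a batch of shape (2, 3, 4) becomes (2, 3, 16)
x = jnp.arange(24.0).reshape(2, 3, 4)
y = batched_outer_ravel_sketch(x)
print(y.shape)  # (2, 3, 16)
```

Broadcasting with explicit new axes handles any number of leading batch dimensions, including none: a plain (N,) input yields an (N*N,) output.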
outer_ravel(arr)
Compute the outer product of a 1D array with itself and return it as a raveled 1D array.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arr | ndarray | A 1D JAX array. | required |
Returns:
Type | Description |
---|---|
ndarray | A 1D JAX array representing the raveled outer product of the input array. |
Source code in jelli/utils/jax_helpers.py (lines 3-17)
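For a single vector, the documented behaviour matches taking jnp.outer(arr, arr) and flattening the result. The sketch below assumes exactly that (self outer product, row-major ravel); outer_ravel_sketch is a hypothetical name, not jelli's code. It also shows how jax.vmap lifts the single-vector function over one batch axis, which is one plausible way to relate it to batched_outer_ravel.

```python
import jax
import jax.numpy as jnp


def outer_ravel_sketch(arr: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical re-implementation: raveled outer product of a 1D array with itself."""
    # jnp.outer gives an (N, N) matrix with entries arr[i] * arr[j];
    # ravel flattens it row by row into shape (N*N,)
    return jnp.outer(arr, arr).ravel()


v = jnp.array([1.0, 2.0, 3.0])
print(outer_ravel_sketch(v))  # [1. 2. 3. 2. 4. 6. 3. 6. 9.]

# jax.vmap maps the 1D function over a leading batch axis: (B, N) -> (B, N*N)
batch = jnp.stack([v, 2.0 * v])
print(jax.vmap(outer_ravel_sketch)(batch).shape)  # (2, 9)
```

A single jax.vmap handles exactly one batch axis; supporting an arbitrary number of batch dimensions, as batched_outer_ravel documents, would need nested vmaps or a broadcasting formulation.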