API Guide
matnets.MatrixParams
MatrixParams stores the weights and bias for matrix primitives:
from matnets import MatrixParams
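Dense parameters have these shapes, where p is the number of input matrices, q the number of output matrices, and n the matrix dimension: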
W: (q, p, n, n)
B: (q, n, n)
MatrixParams is registered as a JAX pytree, so it can be passed through jax.jit,
jax.vmap, and jax.grad, and nested inside dictionaries or lists of parameters.
matnets.init
params = matnets.init(key, p=2, q=3, n=4)
Creates:
params.W: (3, 2, 4, 4)
params.B: (3, 4, 4)
Weights use Glorot-uniform initialization. Bias starts at zero.
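A minimal NumPy sketch of Glorot-uniform initialization for these shapes. The fan counts are an assumption, not taken from matnets: for the dense contraction y[q] = sum over p of W[q, p] @ x[p], each output entry sums p * n products (fan_in = p * n) and each input entry feeds q * n outputs (fan_out = q * n). glorot_init is a hypothetical helper.

```python
import numpy as np

def glorot_init(rng, p, q, n):
    """Hypothetical Glorot-uniform init for dense MatrixParams-style arrays."""
    # Assumed fan counts for y[q] = sum_p W[q, p] @ x[p] + B[q].
    fan_in, fan_out = p * n, q * n
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    W = rng.uniform(-limit, limit, size=(q, p, n, n))  # weights in [-limit, limit)
    B = np.zeros((q, n, n))                            # bias starts at zero
    return W, B, limit

rng = np.random.default_rng(0)
W, B, limit = glorot_init(rng, p=2, q=3, n=4)
print(W.shape, B.shape)  # (3, 2, 4, 4) (3, 4, 4)
```

With p=2, q=3, n=4 this gives limit = sqrt(6 / 20); matnets may compute the fans differently.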
matnets.activations
from matnets.activations import (
relu, relud, leaky_relu, leaky_relud,
elu, elu_powered, elud,
sigmoid, sigmoidd,
tanh, tanhd,
softplus, softplusd,
sss, sst
)
Standard and Determinant-Gated/Scaled Activations
MATNETS supports standard element-wise activations and determinant-based structural activations, which are either gated by the sign of $\text{det}(X)$ or scaled using $\text{det}(X)^{1/n}$, the $n$-th root of the determinant, where $n$ is the matrix dimension.
Element-Wise Activations (Standard)
relu(x): Standard element-wise ReLU.
leaky_relu(x, negative_slope=0.01): Standard element-wise leaky ReLU.
elu(x, alpha=1.0): Standard element-wise ELU.
sigmoid(x): Standard element-wise sigmoid.
tanh(x): Standard element-wise tanh.
softplus(x): Standard element-wise softplus.
sss(x): Scaled squared sigmoid. Applies element-wise sigmoid, squares the resulting matrices, and scales them by $n^{-1}$.
sst(x): Scaled squared tanh. Applies element-wise tanh, squares the resulting matrices, and scales them by $n^{-1}$.
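A minimal NumPy sketch of sss and sst, assuming "squares" means the matrix product S @ S over the last two axes rather than an element-wise square. These are illustrative re-implementations, not the matnets source. Under that reading the $n^{-1}$ factor keeps outputs in the range of the underlying activation, because each entry of S @ S is a sum of n products that lie in (0, 1) for sigmoid or (-1, 1) for tanh.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sss(x):
    """Sketch of the scaled squared sigmoid over the last two axes."""
    n = x.shape[-1]
    s = sigmoid(x)       # entries in (0, 1)
    return (s @ s) / n   # each entry sums n products in (0, 1); dividing by n keeps it in (0, 1)

def sst(x):
    """Sketch of the scaled squared tanh over the last two axes."""
    n = x.shape[-1]
    t = np.tanh(x)       # entries in (-1, 1)
    return (t @ t) / n   # entries stay in (-1, 1)

x = np.random.default_rng(0).normal(size=(2, 4, 4))  # a stack of two 4 x 4 matrices
print(sss(x).shape)  # (2, 4, 4)
```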
Determinant-Gated Activations (Branching)
relud(x): Determinant-gated ReLU. Returns $X$ if $\text{det}(X) > 0$, else $0$.
leaky_relud(x, negative_slope=0.01): Determinant-gated leaky ReLU. Returns $X$ if $\text{det}(X) > 0$, else negative_slope * X.
elu_powered(x, alpha=1.0): Determinant-gated ELU (matrix exponential). Returns $X$ if $\text{det}(X) > 0$, else alpha * (expm(X) - I).
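A minimal NumPy sketch of the gating rule for relud and leaky_relud, written as illustrative re-implementations rather than the matnets code. elu_powered is left out because it needs a matrix exponential (for example scipy.linalg.expm).

```python
import numpy as np

def relud(x):
    """Sketch of a determinant-gated ReLU over the last two axes."""
    det = np.asarray(np.linalg.det(x))[..., None, None]  # one determinant per matrix
    return np.where(det > 0, x, 0.0)                     # keep X or return the zero matrix

def leaky_relud(x, negative_slope=0.01):
    """Sketch of a determinant-gated leaky ReLU over the last two axes."""
    det = np.asarray(np.linalg.det(x))[..., None, None]
    return np.where(det > 0, x, negative_slope * x)

X_pos = np.diag([2.0, 3.0])   # det = 6 > 0
X_neg = np.diag([-1.0, 2.0])  # det = -2 < 0
x = np.stack([X_pos, X_neg])
print(relud(x)[1])  # zero matrix: the whole matrix is gated, not individual entries
```

An element-wise relu would map X_neg to diag(0, 2); the gated version zeroes the whole matrix because its determinant is negative.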
Determinant-Scaled Activations (Smooth Scaling)
These functions scale the input matrix by $\text{fn}(\text{det}(X)^{1/n}) / \text{det}(X)^{1/n}$, where $\text{det}(X)^{1/n}$ is the $n$-th root of the determinant. Taking the root normalizes the scale for the matrix dimension, and $\text{det}(X)$ is clamped with a small epsilon for numerical safety:
elud(x, alpha=1.0): Scales by $\text{elu}(\text{det}(X)^{1/n}, \alpha) / \text{det}(X)^{1/n}$.
sigmoidd(x): Scales by $\text{sigmoid}(\text{det}(X)^{1/n}) / \text{det}(X)^{1/n}$.
tanhd(x): Scales by $\text{tanh}(\text{det}(X)^{1/n}) / \text{det}(X)^{1/n}$.
softplusd(x): Scales by $\text{softplus}(\text{det}(X)^{1/n}) / \text{det}(X)^{1/n}$.
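A minimal NumPy sketch of the scaling rule, with tanhd and softplusd as illustrative re-implementations. Clamping $\text{det}(X)$ from below at eps is an assumption; the exact clamping in matnets, including how it treats negative determinants, may differ.

```python
import numpy as np

def det_scaled(x, fn, eps=1e-6):
    """Scale each matrix X by fn(r) / r, where r = det(X) ** (1 / n)."""
    n = x.shape[-1]
    # Assumption: det(X) is clamped from below at eps, so r is real and positive.
    det = np.maximum(np.linalg.det(x), eps)
    r = det ** (1.0 / n)
    scale = fn(r) / r
    return x * np.asarray(scale)[..., None, None]

def tanhd(x):
    return det_scaled(x, np.tanh)

def softplusd(x):
    return det_scaled(x, lambda r: np.logaddexp(0.0, r))  # softplus(r) = log(1 + e^r)

# For X = c * I the n-th root is c for every n, so the scale ignores the dimension.
print(np.allclose(tanhd(2.0 * np.eye(3)), np.tanh(2.0) * np.eye(3)))  # True
print(np.allclose(tanhd(2.0 * np.eye(5)), np.tanh(2.0) * np.eye(5)))  # True
```

The two checks illustrate the dimension normalization: a scaled identity gets the same scalar factor whether it is 3 x 3 or 5 x 5.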
matnets.dense
y = matnets.dense(params, x)
Expected shapes:
params.W: (q, p, n, n)
params.B: (q, n, n)
x: (p, n, n)
y: (q, n, n)
With activation:
from matnets.activations import relud
y = matnets.dense(params, x, activation=relud)
The core operation is:
jnp.einsum("qpak,pkc->qac", params.W, x) + params.B
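The einsum contracts the input-channel axis p and the inner matrix axis k, so each output matrix is y[i] = sum over j of W[i, j] @ x[j] + B[i]. A small NumPy check of that reading, with np.einsum standing in for jnp.einsum (same subscript semantics) and random illustrative data:

```python
import numpy as np

rng = np.random.default_rng(0)
q, p, n = 3, 2, 4
W = rng.normal(size=(q, p, n, n))
B = rng.normal(size=(q, n, n))
x = rng.normal(size=(p, n, n))

# The core dense contraction.
y = np.einsum("qpak,pkc->qac", W, x) + B

# The same result as explicit matrix products summed over input channels.
y_loop = np.stack([sum(W[i, j] @ x[j] for j in range(p)) + B[i] for i in range(q)])
assert np.allclose(y, y_loop)
print(y.shape)  # (3, 4, 4)
```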
matnets.lax.matrix_conv1d
from matnets.lax import matrix_conv1d
y = matrix_conv1d(params, x, stride=1, padding="VALID")
Expected shapes:
params.W: (q, p, r, n, n)
params.B: (q, n, n)
x: (t, p, n, n)
y: (t_out, q, n, n)
r is the 1D kernel size.
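A minimal NumPy sketch of a "VALID" matrix conv1d under stated assumptions: output step s reads input steps s * stride through s * stride + r - 1, with no kernel flip (the usual deep-learning cross-correlation convention), and each tap applies a dense-style matrix product. matrix_conv1d_valid is a hypothetical helper, not the matnets implementation.

```python
import numpy as np

def matrix_conv1d_valid(W, B, x, stride=1):
    """Hypothetical NumPy sketch of a VALID matrix conv1d.

    W: (q, p, r, n, n), B: (q, n, n), x: (t, p, n, n) -> (t_out, q, n, n).
    """
    r = W.shape[2]
    t_out = (x.shape[0] - r) // stride + 1
    out = []
    for s in range(t_out):
        window = x[s * stride : s * stride + r]  # (r, p, n, n)
        # Sum over channels p and taps k of W[q, p, k] @ window[k, p], plus bias.
        out.append(np.einsum("qpkab,kpbc->qac", W, window) + B)
    return np.stack(out)

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 2, 3, 2, 2))  # q=3, p=2, r=3, n=2
B = rng.normal(size=(3, 2, 2))
x = rng.normal(size=(10, 2, 2, 2))    # t=10
y = matrix_conv1d_valid(W, B, x, stride=2)
print(y.shape)  # (4, 3, 2, 2): t_out = (10 - 3) // 2 + 1
```

With r = 1 and stride 1 the sketch reduces to matnets.dense applied independently at every time step.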
matnets.lax.matrix_conv2d
from matnets.lax import matrix_conv2d
y = matrix_conv2d(params, x, stride=(1, 1), padding="SAME")
Expected shapes:
params.W: (q, p, h, w, n, n)
params.B: (q, n, n)
x: (height, width, p, n, n)
y: (height_out, width_out, q, n, n)
h and w are the 2D kernel height and width.
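Assuming matrix_conv2d follows the same "VALID" and "SAME" padding semantics as jax.lax (an assumption; the guide does not spell them out), height_out and width_out can be computed per axis. conv_out_size is a hypothetical helper.

```python
import math

def conv_out_size(size, kernel, stride, padding):
    """Output length along one spatial axis under XLA-style padding rules."""
    if padding == "VALID":
        # Only windows that fit entirely inside the input.
        return (size - kernel) // stride + 1
    if padding == "SAME":
        # Input is zero-padded so the output length depends only on the stride.
        return math.ceil(size / stride)
    raise ValueError(f"unsupported padding: {padding!r}")

# height_out for a 28-pixel axis and a kernel of size 3.
print([conv_out_size(28, 3, s, "SAME") for s in (1, 2)])   # [28, 14]
print([conv_out_size(28, 3, s, "VALID") for s in (1, 2)])  # [26, 13]
```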
matnets.conv
Pooling primitives for downsampling sequential and grid data.
from matnets.conv import max_pool1d, maxd_pool1d, avg_pool1d, avgd_pool1d
from matnets.conv import max_pool2d, maxd_pool2d, avg_pool2d, avgd_pool2d
Standard pooling (max_pool, avg_pool, sum_pool) operates element-wise on the
matrices in each window. Determinant pooling (maxd_pool, avgd_pool) selects or
weights whole matrices by their determinant.
Standard Pooling
max_pool1d/2d: Element-wise maximum within the window.
avg_pool1d/2d: Standard arithmetic mean of matrices in the window.
sum_pool1d/2d: Standard sum of matrices in the window.
Determinant Pooling
maxd_pool1d/2d: Selects the single matrix in the window with the highest determinant.
avgd_pool1d/2d: Computes $\sum_{M} \frac{1}{\text{det}(M)^{1/n}} M$ over all matrices $M$ in the window, where $n$ is the matrix dimension.
Expected 1D shapes:
x: (t, p, n, n) or (batch, t, p, n, n)
y: (t_out, p, n, n) or (batch, t_out, p, n, n)
Expected 2D shapes:
x: (height, width, p, n, n) or (batch, height, width, p, n, n)
y: (height_out, width_out, p, n, n) or (batch, height_out, width_out, p, n, n)
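A minimal NumPy sketch of the two determinant pooling rules for the unbatched 1D layout (t, p, n, n). Non-overlapping windows (stride equal to the window size), the window argument, and the helper names are assumptions, since the matnets signatures are not documented here. Pooling runs independently per channel, which matches p being preserved in the output shapes.

```python
import numpy as np

def _windows(x, window):
    # Split the time axis into non-overlapping windows: (t_out, window, p, n, n).
    t_out = x.shape[0] // window
    return x[: t_out * window].reshape(t_out, window, *x.shape[1:])

def maxd_pool1d_sketch(x, window):
    """Per window and channel, keep the matrix with the largest determinant."""
    blocks = _windows(x, window)
    t_out, _, p = blocks.shape[:3]
    best = np.argmax(np.linalg.det(blocks), axis=1)  # (t_out, p)
    return blocks[np.arange(t_out)[:, None], best, np.arange(p)[None, :]]

def avgd_pool1d_sketch(x, window):
    """Per window and channel, sum det(M) ** (-1/n) * M over the window."""
    blocks = _windows(x, window)
    n = x.shape[-1]
    # Assumption: every det(M) > 0, so the n-th root is real.
    weights = np.linalg.det(blocks) ** (-1.0 / n)  # (t_out, window, p)
    return np.sum(weights[..., None, None] * blocks, axis=1)

x = np.stack([np.diag([1.0, 1.0]), np.diag([2.0, 2.0]),
              np.diag([3.0, 1.0]), np.diag([1.0, 5.0])])[:, None]  # (t=4, p=1, 2, 2)
print(maxd_pool1d_sketch(x, window=2).shape)  # (2, 1, 2, 2)
```

In the first window, det(I) = 1 and det(2I) = 4, so avgd gives I + 2I / 2 = 2I.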
matnets.lax.matrix_attention
from matnets.lax import matrix_attention
out = matrix_attention(None, Q, K, V)  # params=None: no output projection
Expected token shapes:
Q: (tokens_q, p, n, n)
K: (tokens_k, p, n, n)
V: (tokens_k, p, n, n)
out: (tokens_q, p, n, n)
By default the score is a scaled Frobenius inner product. You can pass a custom
score_fn that receives one query token and one key token and returns a scalar.
If params is not None, each aggregated output token is projected through
matnets.dense(params, token), so each output token has shape (q, n, n)
instead of (p, n, n).
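A minimal NumPy sketch of attention over matrix tokens without the optional projection. Two details are assumptions, not taken from the guide: the Frobenius score is scaled by 1 / sqrt(p * n * n), and scores are normalized with a softmax over keys. matrix_attention_sketch and frobenius_score are hypothetical helpers.

```python
import numpy as np

def frobenius_score(q_tok, k_tok):
    # Frobenius inner product over every entry of a (p, n, n) token.
    # Assumption: scaled by 1 / sqrt(number of entries); the matnets scale may differ.
    return np.sum(q_tok * k_tok) / np.sqrt(q_tok.size)

def matrix_attention_sketch(Q, K, V, score_fn=frobenius_score):
    """Hypothetical attention over matrix tokens, without the dense projection."""
    scores = np.array([[score_fn(qt, kt) for kt in K] for qt in Q])  # (tokens_q, tokens_k)
    # Assumption: scores are normalized with a softmax over the key axis.
    scores = scores - scores.max(axis=1, keepdims=True)             # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    # Each output token is a weighted sum of the value tokens.
    return np.einsum("ij,jpab->ipab", weights, V)

rng = np.random.default_rng(0)
Q = rng.normal(size=(2, 1, 3, 3))  # tokens_q=2, p=1, n=3
K = rng.normal(size=(4, 1, 3, 3))  # tokens_k=4
V = rng.normal(size=(4, 1, 3, 3))
print(matrix_attention_sketch(Q, K, V).shape)  # (2, 1, 3, 3)

def trace_score(q_tok, k_tok):
    # A custom score: one query token and one key token in, one scalar out.
    return np.trace(q_tok[0] @ k_tok[0].T)

out_custom = matrix_attention_sketch(Q, K, V, score_fn=trace_score)
```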
matnets.nn
matnets.nn contains recurrent cell steps (RNN, LSTM, GRU) built from matnets.dense.
import jax
from matnets.nn import rnn_step, lstm_step, gru_step
Each step function returns a (new_carry, output) pair, the form jax.lax.scan expects from its body function.
RNN
carry, outputs = jax.lax.scan(
lambda h, x_t: rnn_step(params, h, x_t),
h0,
sequence,
)
LSTM
from matnets.activations import sss, sst
carry, outputs = jax.lax.scan(
lambda carry, x_t: lstm_step(
params,
carry,
x_t,
activations=(sss, sst), # Optional: defaults to (sigmoid, tanh)
),
(h0, c0),
sequence,
)
LSTM params must contain keys "i", "f", "g", and "o".
lstm_step uses matrix multiplication (jnp.matmul) in the cell-state and hidden-state updates. The optional activations argument is a (gate_act, state_act) tuple that sets the gate activation and the state activation; it defaults to (sigmoid, tanh).
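A minimal NumPy sketch of a matrix LSTM step in which gates act on states by matrix multiplication. The gate wiring is an assumption, not the matnets implementation: each gate reads x_t and h stacked on the channel axis through a dense-style contraction, and params maps "i", "f", "g", "o" to hypothetical (W, B) pairs. lstm_step_sketch is a hypothetical helper.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def dense(W, B, x):
    # Dense-style contraction: (q, p, n, n) with (p, n, n) -> (q, n, n).
    return np.einsum("qpak,pkc->qac", W, x) + B

def lstm_step_sketch(params, carry, x_t, activations=(sigmoid, np.tanh)):
    """Hypothetical matrix LSTM step returning ((h, c), h) like a scan body."""
    gate_act, state_act = activations
    h, c = carry
    z = np.concatenate([x_t, h], axis=0)   # (p + q, n, n), assumed gate input
    i = gate_act(dense(*params["i"], z))   # input gate
    f = gate_act(dense(*params["f"], z))   # forget gate
    g = state_act(dense(*params["g"], z))  # candidate cell state
    o = gate_act(dense(*params["o"], z))   # output gate
    c_new = f @ c + i @ g                  # matrix products, batched over q
    h_new = o @ state_act(c_new)
    return (h_new, c_new), h_new

rng = np.random.default_rng(0)
p, q, n = 2, 3, 4
params = {k: (0.1 * rng.normal(size=(q, p + q, n, n)), np.zeros((q, n, n))) for k in "ifgo"}
carry = (np.zeros((q, n, n)), np.zeros((q, n, n)))
for x_t in rng.normal(size=(5, p, n, n)):  # a Python loop standing in for jax.lax.scan
    carry, h_t = lstm_step_sketch(params, carry, x_t)
print(h_t.shape)  # (3, 4, 4)
```

With matrix products, every entry of the new cell state mixes a full row of the forget gate with a full column of the old state, unlike the element-wise gating of a standard LSTM.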
GRU
carry, outputs = jax.lax.scan(
lambda h, x_t: gru_step(params, h, x_t),
h0,
sequence,
)
GRU params must contain keys "z", "r", and "n".
matnets.utils
Data preprocessing utilities for MATNETS.
embed_pixels
from matnets.utils import embed_pixels
import numpy as np
imgs = np.zeros((2, 10, 10, 3)) # (Batch, H, W, Channels)
windows = embed_pixels(imgs, n=3, spatial_axes=(1, 2), interleave=False)
# Shape: (2, 10, 10, 3, 3, 3)
embed_pixels extracts an n x n local neighborhood (a length-n window when only
one spatial axis is given) around each position on the spatial axes. Zero
padding is applied automatically so the output spatial dimensions match the
input spatial dimensions, and the new window dimensions are appended to the end of the shape.
If interleave=True (or a tuple of booleans per axis), the order of elements
along the spatial axes is permuted according to an interleaved block pattern.
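A minimal NumPy sketch of the neighborhood extraction without interleaving, assuming odd n, centered windows, and n // 2 zeros of padding on each side of every spatial axis. embed_pixels_sketch is a hypothetical helper; numpy.lib.stride_tricks.sliding_window_view (NumPy 1.20+) appends one window axis per entry of axis, which reproduces the documented output shape.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def embed_pixels_sketch(imgs, n, spatial_axes=(1, 2)):
    """Hypothetical centered n x n neighborhoods with zero padding (no interleave)."""
    pad = [(0, 0)] * imgs.ndim
    for ax in spatial_axes:
        pad[ax] = (n // 2, n // 2)      # assumes odd n so the window is centered
    padded = np.pad(imgs, pad)          # constant zero padding by default
    # One window axis per spatial axis is appended at the end of the shape.
    return sliding_window_view(padded, (n,) * len(spatial_axes), axis=spatial_axes)

imgs = np.arange(2 * 10 * 10 * 3, dtype=float).reshape(2, 10, 10, 3)
windows = embed_pixels_sketch(imgs, n=3, spatial_axes=(1, 2))
print(windows.shape)  # (2, 10, 10, 3, 3, 3)
```

The center of each window, windows[..., 1, 1], is the original pixel, and windows at the image border pick up the zero padding.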
embed_sequence
from matnets.utils import embed_sequence
import numpy as np
seq = np.array([1, 2, 3, 4, 5]) # 1D Sequence: (T,)
out = embed_sequence(seq, n=3, axis=0)
# Shape: (5, 3, 3)
# For multiple channels / batch:
seq_mc = np.zeros((10, 5)) # (T, C)
out_mc = embed_sequence(seq_mc, n=3, axis=0)
# Shape: (10, 5, 3, 3)
embed_sequence builds a symmetric n x n time-history embedding along a given time axis, suited to audio, time-series, and other streaming sequential data.
For every time step t along the sequence axis, it gathers the current value and the n - 1 preceding values into a symmetric matrix whose entries at distance d from the diagonal hold the value from d steps earlier. Values before t = 0 are zero-padded.
It works on 1D, 2D, and 3D inputs (e.g. (T,), (N, T), (N, T, C)), leaves the non-time dimensions unchanged, and appends the two matrix dimensions at the end of the shape.
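A minimal NumPy sketch of the time-history embedding, assuming entry (i, j) of the matrix at step t holds seq[t - |i - j|], with zeros when t - |i - j| < 0. embed_sequence_sketch is a hypothetical helper, not the matnets implementation.

```python
import numpy as np

def embed_sequence_sketch(seq, n, axis=0):
    """Hypothetical symmetric delay embedding: out[..., t, ..., i, j] = seq[t - |i - j|]."""
    seq = np.moveaxis(np.asarray(seq, dtype=float), axis, 0)            # time first
    T = seq.shape[0]
    padded = np.concatenate([np.zeros((n - 1,) + seq.shape[1:]), seq])  # zeros before t = 0
    delay = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :])       # (n, n) distances
    t_idx = np.arange(T)[:, None, None] + (n - 1) - delay               # (T, n, n) indices
    out = padded[t_idx]                                                 # (T, n, n, *rest)
    out = np.moveaxis(out, (1, 2), (-2, -1))                            # matrix axes last
    return np.moveaxis(out, 0, axis)                                    # restore time axis

out = embed_sequence_sketch(np.array([1, 2, 3, 4, 5]), n=3, axis=0)
print(out.shape)  # (5, 3, 3)
print(out[2])     # [[3, 2, 1], [2, 3, 2], [1, 2, 3]] as floats
```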