Examples
Basic Forward Pass
.\.venv\Scripts\python.exe examples\basic_forward.py
This creates one dense layer and applies it to an input shaped (2, 2, 2).
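A minimal sketch of what such a dense layer could look like, assuming it stores one n x n weight matrix per (output neuron, input neuron) pair and uses the same einsum contraction as the dense kernel described at the end of this page. The names init_dense and dense are hypothetical and are not the example file's actual code. Here the (2, 2, 2) input is read as 2 input neurons, each a 2 x 2 matrix, and 3 output neurons are an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def init_dense(key, out_neurons, in_neurons, n):
    # Hypothetical initializer: weight shape (q, p, a, k) = (out, in, n, n),
    # scaled by 1/sqrt(in_neurons * n) to keep activations moderate.
    scale = (in_neurons * n) ** -0.5
    return scale * jax.random.normal(key, (out_neurons, in_neurons, n, n))

def dense(W, x):
    # y[q] = sum over p of W[q, p] @ x[p]
    return jnp.einsum("qpak,pkc->qac", W, x)

w_key, x_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(x_key, (2, 2, 2))          # 2 input neurons, each 2 x 2
W = init_dense(w_key, out_neurons=3, in_neurons=2, n=2)
y = dense(W, x)
print(y.shape)  # (3, 2, 2)
```

Each output neuron is itself a 2 x 2 matrix, so the layer maps a stack of matrices to another stack of matrices rather than a vector to a vector.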
Five Hidden Layers
.\.venv\Scripts\python.exe examples\five_hidden_net.py
This example defines a small FiveHiddenNet class and runs its forward pass under jax.jit:
model = FiveHiddenNet(key, input_neurons=3, hidden_neurons=4, n=2)
y = jax.jit(model.forward)(model.params, x)
The output shape is (1, 2, 2).
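The class could be structured roughly as follows. This is a hypothetical stand-in (FiveHiddenNetSketch), assuming five hidden layers of hidden_neurons matrix-neurons, a single linear output neuron, a tanh nonlinearity, and the einsum dense kernel; the real example may differ in initialization and activation. Passing params explicitly to forward is what lets jax.jit trace the weights as inputs.

```python
import jax
import jax.numpy as jnp

class FiveHiddenNetSketch:
    """Hypothetical sketch: 5 hidden matrix-neuron layers plus 1 output neuron."""

    def __init__(self, key, input_neurons, hidden_neurons, n, output_neurons=1):
        sizes = [input_neurons] + [hidden_neurons] * 5 + [output_neurons]
        keys = jax.random.split(key, len(sizes) - 1)
        # One weight tensor of shape (q, p, n, n) per layer.
        self.params = [
            jax.random.normal(k, (q, p, n, n)) * (p * n) ** -0.5
            for k, p, q in zip(keys, sizes[:-1], sizes[1:])
        ]

    def forward(self, params, x):
        # x: (input_neurons, n, n)
        for W in params[:-1]:
            x = jnp.tanh(jnp.einsum("qpak,pkc->qac", W, x))  # tanh is an assumption
        return jnp.einsum("qpak,pkc->qac", params[-1], x)    # linear output layer

key = jax.random.PRNGKey(0)
model = FiveHiddenNetSketch(key, input_neurons=3, hidden_neurons=4, n=2)
x = jnp.ones((3, 2, 2))
y = jax.jit(model.forward)(model.params, x)
print(y.shape)  # (1, 2, 2)
```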
Architecture Walkthrough
.\.venv\Scripts\python.exe examples\matrix_architectures.py
This file checks shape flow through:
- MLP
- batched MLP with jax.vmap
- gradients with jax.grad
- RNN with jax.lax.scan
- LSTM with jax.lax.scan
- Frobenius attention
- residual block
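A compact sketch of how several of these pieces fit together, assuming the einsum dense kernel, tanh nonlinearities, and a plain (non-gated) matrix RNN. The Frobenius attention form used here, scores given by the Frobenius inner product of matrix tokens scaled by 1/n, is an assumption about what the example means by that name. The LSTM and residual block are not sketched. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def dense(W, x):
    # W: (q, p, n, n), x: (p, n, n) -> (q, n, n)
    return jnp.einsum("qpak,pkc->qac", W, x)

def mlp(params, x):
    # Two-layer matrix MLP; tanh between the layers is an assumption.
    W1, W2 = params
    return dense(W2, jnp.tanh(dense(W1, x)))

n, T = 2, 5
keys = jax.random.split(jax.random.PRNGKey(0), 7)
params = (jax.random.normal(keys[0], (4, 3, n, n)),
          jax.random.normal(keys[1], (1, 4, n, n)))
x = jax.random.normal(keys[2], (3, n, n))       # one example: 3 input neurons
xs = jax.random.normal(keys[3], (8, 3, n, n))   # batch of 8 examples

# Batched MLP: vmap over the batch axis of xs, sharing params (in_axes=None).
ys = jax.vmap(mlp, in_axes=(None, 0))(params, xs)    # (8, 1, n, n)

# Gradients of a toy loss (squared Frobenius norm of the output).
def loss(p, x):
    return jnp.sum(mlp(p, x) ** 2)

grads = jax.grad(loss)(params, x)                     # same pytree structure as params

# Matrix RNN: h_t = tanh(W_h h_{t-1} + W_x x_t), scanned over T steps.
W_h = 0.1 * jax.random.normal(keys[4], (4, 4, n, n))
W_x = 0.1 * jax.random.normal(keys[5], (4, 3, n, n))
seq = jax.random.normal(keys[6], (T, 3, n, n))

def rnn_step(h, x_t):
    h = jnp.tanh(dense(W_h, h) + dense(W_x, x_t))
    return h, h  # (new carry, per-step output)

h_T, hs = jax.lax.scan(rnn_step, jnp.zeros((4, n, n)), seq)  # (4, n, n), (T, 4, n, n)

# Frobenius attention (assumed form): score_ij = <Q_i, K_j>_F / n.
def frobenius_attention(Q, K, V):
    scores = jnp.einsum("iab,jab->ij", Q, K) / n
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("ij,jab->iab", weights, V)

tokens = seq[:, 0]                                    # (T, n, n) matrix tokens
att = frobenius_attention(tokens, tokens, tokens)     # (T, n, n)
```

The same dense function is reused unchanged in every case; vmap, grad, and scan are applied around it rather than inside it.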
Pooling and CNNs
.\.venv\Scripts\python.exe examples\10_pooling.py
Demonstrates downsampling a 1D sequence using maxd_pool1d and avgd_pool1d.
Shows how pooling integrates into a standard Flax model with matrix
convolutions.
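The signatures of maxd_pool1d and avgd_pool1d are not shown here. By analogy with the determinant-gated activations below, one plausible reading is that maxd_pool1d keeps, in each window, the matrix with the largest determinant; the sketch implements that reading together with a plain window average. Both functions (maxd_pool1d_sketch, avgd_pool1d_sketch) are hypothetical, use non-overlapping windows, and drop any trailing remainder.

```python
import jax.numpy as jnp

def maxd_pool1d_sketch(seq, window):
    # Hypothetical: seq is (T, n, n); in each non-overlapping window,
    # keep the whole matrix with the largest determinant.
    T, n, _ = seq.shape
    windows = seq[: T - T % window].reshape(-1, window, n, n)
    dets = jnp.linalg.det(windows)                  # (T // window, window)
    idx = jnp.argmax(dets, axis=1)
    return windows[jnp.arange(windows.shape[0]), idx]

def avgd_pool1d_sketch(seq, window):
    # Hypothetical: element-wise mean of the matrices in each window.
    T, n, _ = seq.shape
    windows = seq[: T - T % window].reshape(-1, window, n, n)
    return windows.mean(axis=1)

I = jnp.eye(2)
seq = jnp.stack([c * I for c in [1.0, 3.0, 2.0, 0.5, -1.0, 4.0]])  # (6, 2, 2)
pooled_max = maxd_pool1d_sketch(seq, window=2)   # picks 3I, 2I, 4I
pooled_avg = avgd_pool1d_sketch(seq, window=2)   # gives 2I, 1.25I, 1.5I
```

With window=2 a length-6 sequence is halved to length 3, while each element stays an n x n matrix.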
Activations
.\.venv\Scripts\python.exe examples\11_activations.py
Contrasts element-wise relu with structural relud and elud, showing how
determinant-gating preserves or filters entire matrix-neurons.
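The contrast can be illustrated with a small sketch, assuming determinant gating means "keep the whole matrix-neuron if its determinant is positive, otherwise zero it". relud_sketch is a hypothetical stand-in, not the library's relud, and elud is not sketched because its exact form is not described here.

```python
import jax.numpy as jnp

def relu(X):
    # Element-wise: each entry is clipped independently.
    return jnp.maximum(X, 0.0)

def relud_sketch(X):
    # Hypothetical determinant gate. X: (neurons, n, n).
    # A matrix-neuron passes unchanged if det > 0, else it is zeroed entirely.
    det = jnp.linalg.det(X)
    return jnp.where(det[:, None, None] > 0, X, 0.0)

X = jnp.array([[[2.0, -1.0], [0.0, 1.0]],   # det = 2: kept whole, including the -1
               [[0.0, 1.0], [1.0, 0.0]]])   # det = -1: zeroed entirely
out_relu = relu(X)
out_relud = relud_sketch(X)
```

relu alters the first matrix (its -1 becomes 0) and keeps part of the second, while the gated version leaves the first matrix intact and removes the second one as a unit.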
Where Parallelization Happens
The dense operation:
jnp.einsum("qpak,pkc->qac", W, x)
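is the main kernel: for each output neuron q it computes y[q] = sum over p of W[q, p] @ x[p], a sum of ordinary matrix products. JAX can compile it with jax.jit, map it over batches or token sequences with jax.vmap, differentiate it with jax.grad, and call it repeatedly inside jax.lax.scan.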
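To make the contraction concrete, the sketch below compares the einsum kernel, compiled with jax.jit, against an explicit loop of matrix products. dense and dense_loop are illustrative names; the loop version exists only to show what the einsum string computes and would be far slower in practice.

```python
import jax
import jax.numpy as jnp

def dense(W, x):
    # W: (q, p, a, k), x: (p, k, c) -> (q, a, c)
    return jnp.einsum("qpak,pkc->qac", W, x)

def dense_loop(W, x):
    # Same contraction written out: y[q] = sum over p of W[q, p] @ x[p]
    q, p = W.shape[:2]
    return jnp.stack([sum(W[i, j] @ x[j] for j in range(p)) for i in range(q)])

kW, kx = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(kW, (4, 3, 2, 2))
x = jax.random.normal(kx, (3, 2, 2))
y_fast = jax.jit(dense)(W, x)   # compiled einsum kernel
y_ref = dense_loop(W, x)        # reference loop
```

Because the whole layer is a single einsum, XLA sees one fused contraction instead of q * p separate small matmuls, which is what makes it a good target for jit and vmap.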