Example: NN layer with OJAX
Here is a self-contained example showing how OJAX can be used to define a fully connected layer for deep learning. The example showcases two important characteristics of OJAX:
OJAX integrates seamlessly with JAX: JAX transforms and functions can be used anywhere and work as intended.
OJAX is “pure like JAX”: there are no in-place updates; new instances with updated state are always returned instead.
from dataclasses import field
import jax
import jax.numpy as jnp
from jax.random import PRNGKey, split as jrsplit, normal as jrnormal
from ojax import aux, child, OTree, fields
# defines a fully connected layer for neural networks
class Dense(OTree):
    input_features: int  # inferred to be auxiliary data
    output_features: int = aux()  # or explicit declaration
    weight: jnp.ndarray = field(default=..., init=False)  # inferred as child
    bias: jnp.ndarray = child(default=..., init=False)  # explicit declaration

    # # optional here, automatically generated by the dataclass
    # def __init__(self, input_features: int, output_features: int):
    #     self.assign_(
    #         input_features=input_features, output_features=output_features
    #     )

    # forward pass
    def forward(self, input_array):
        return jnp.inner(input_array, self.weight) + self.bias

    # set new parameters, notice it returns an updated version of itself
    def update_parameters(self, weight, bias):
        assert weight.shape == (self.output_features, self.input_features)
        assert bias.shape == (self.output_features,)
        return self.update(weight=weight, bias=bias)
# example usage
if __name__ == "__main__":
    # define data
    data_count, data_features, output_features = 4, 3, 2
    key = PRNGKey(0)
    key, key_data, key_weight, key_bias = jrsplit(key, 4)
    input_data = jrnormal(key_data, shape=(data_count, data_features))
    # define layer
    init_weight = jrnormal(key_weight, shape=(output_features, data_features))
    init_bias = jrnormal(key_bias, shape=(output_features,))
    layer = Dense(data_features, output_features)
    # No in-place update: use the returned, updated layer instance!
    layer = layer.update_parameters(weight=init_weight, bias=init_bias)
    for f in fields(layer):
        print(f.name, type(f), OTree.__infer_otree_field_type__(f))
    # input_features <class 'dataclasses.Field'> <class 'ojax.otree.Aux'>
    # output_features <class 'ojax.otree.Aux'> <class 'ojax.otree.Aux'>
    # weight <class 'dataclasses.Field'> <class 'ojax.otree.Child'>
    # bias <class 'ojax.otree.Child'> <class 'ojax.otree.Child'>
    # use layer as a pytree
    layer_w, layer_b = jax.tree.flatten(layer)[0]
    assert (layer_w == init_weight).all() and (layer_b == init_bias).all()
    # flattening and unflattening recovers the layer
    layer = jax.tree.unflatten(*jax.tree.flatten(layer)[::-1])
    # compute the output; note that jax.jit / jax.vmap work out of the box
    output = jax.jit(jax.vmap(layer.forward))(input_data)
    print(output)
    # [[-2.666112 -1.0220472 ]
    #  [-3.701102 -0.8207982 ]
    #  [-4.4596996 0.6687442 ]
    #  [ 0.92416656 -3.302886 ]]
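To further illustrate the two characteristics above, here is an additional sketch that is not part of the original example: it differentiates a loss with respect to the layer and performs one purely functional gradient step. It assumes the definitions from the script above (Dense, layer, input_data, data_count, output_features) are still in scope; the names mse_loss, sgd_step, new_layer and the zero targets are made up for illustration.

def mse_loss(dense: Dense, inputs, targets):
    # vmap over the batch dimension, then reduce to a scalar loss
    predictions = jax.vmap(dense.forward)(inputs)
    return jnp.mean((predictions - targets) ** 2)

@jax.jit  # Dense instances are pytrees, so they can cross the jit boundary
def sgd_step(dense: Dense, inputs, targets):
    # the gradient of a scalar loss w.r.t. the layer is itself a Dense
    # whose weight and bias children hold the parameter gradients
    grads = jax.grad(mse_loss)(dense, inputs, targets)
    # "pure like JAX": return a new layer instead of mutating the old one
    return dense.update_parameters(
        weight=dense.weight - 0.1 * grads.weight,
        bias=dense.bias - 0.1 * grads.bias,
    )

targets = jnp.zeros((data_count, output_features))
new_layer = sgd_step(layer, input_data, targets)

Only the array children (weight and bias) participate in differentiation and tracing; as the flatten call above shows, the aux fields input_features and output_features are carried in the tree definition rather than as leaves.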
For a full-fledged NN library with a module system, optimizers, interfaces to impure code (e.g., data loaders and logging), and fully jit-able and parallelizable high-level functions for NN training, stay tuned for OJAX-NN.