OJAX Quickstart
The core component of OJAX is the ojax.OTree class, which represents an immutable PyTree and uses Python dataclass field declarations. It is implemented as a frozen dataclass, a standard Python feature, and serves as the base class for all custom JAX-compatible classes.
Declaring annotated fields
Unlike standard Python classes, where fields are added dynamically via self.field_name = value, OTree adopts the annotated field syntax of Python dataclasses and expects you to declare fields with type annotations. This is more in line with the functional paradigm employed by JAX. Here is an example (an excerpt from the full example):
```python
import jax.numpy as jnp
from dataclasses import field
from ojax import OTree, aux, child

class Dense(OTree):
    input_features: int  # inferred to be auxiliary data
    output_features: int = aux()  # or explicit declaration
    weight: jnp.ndarray = field(default=..., init=False)  # inferred as child
    bias: jnp.ndarray = child(default=..., init=False)  # explicit declaration
```
Annotated fields can take the following forms:
- field_name: type
- field_name: type = default_value
- field_name: type = dataclasses.field(), with optional dataclasses.field arguments such as default / default_factory and init
- field_name: type = ojax.child() / ojax.aux() / ojax.ignore() / ojax.alien(), with the same optional arguments as dataclasses.field
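The first three forms are plain dataclass syntax, so they can be tried without OJAX. The minimal standard-library sketch below (the Layer class and its fields are hypothetical) shows which fields the dataclass machinery collects and which of them become __init__ parameters; per the list above, the OJAX helpers in the last form accept the same optional arguments.

```python
import dataclasses
import inspect
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Layer:
    # Form 1: annotation only, so the field is a required __init__ argument.
    name: str
    # Form 2: annotation with a plain default value.
    scale: float = 1.0
    # Form 3: dataclasses.field() with optional arguments.
    tags: tuple = field(default_factory=tuple)
    cache: dict = field(default_factory=dict, init=False)  # not an __init__ argument


layer = Layer("dense", scale=0.5)
print([f.name for f in dataclasses.fields(layer)])  # ['name', 'scale', 'tags', 'cache']
print(list(inspect.signature(Layer).parameters))    # ['name', 'scale', 'tags']
```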
Note
Type annotations are required because Python dataclasses use them to identify fields. Attributes without type annotations are ignored by the dataclass machinery and become class variables instead of dataclass fields. OTree raises a warning in this case, so don't worry about accidentally creating class variables instead of fields :)
If you want to create a class variable, declare it explicitly with typing.ClassVar so you won't be nagged by this warning.
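Since OTree builds on Python dataclasses, the behavior described in this note can be reproduced with a plain frozen dataclass. In the standard-library sketch below (the Config class is hypothetical), only the annotated attribute becomes a field. A plain dataclass accepts the unannotated attribute silently, which is exactly the case OTree warns about, while the typing.ClassVar declaration marks the class variable as intentional.

```python
import dataclasses
from dataclasses import dataclass
from typing import ClassVar


@dataclass(frozen=True)
class Config:
    units: int                     # annotated: a dataclass field
    activation = "relu"            # no annotation: a class variable (OTree would warn)
    registry: ClassVar[dict] = {}  # explicit class variable, not a field


print([f.name for f in dataclasses.fields(Config)])  # ['units']
```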
Note
Accurate type annotations are helpful but not required, since Python does not check them at runtime. That said, annotations help OTree infer how to handle each field as part of a PyTree. No need to think too hard, though: declaring everything as typing.Any is also fine, and the PyTree handling can always be specified explicitly.
In the last form, ojax.child() / ojax.aux() / ojax.ignore() / ojax.alien() are variants of dataclasses.field that also specify how an OTree should handle a field as a PyTree. Let's discuss this point further.
Field types for OTree
A PyTree is the data structure JAX uses to operate on collections of data. It consists of a definition part (the tree structure, or treedef) and a content part (the leaves), and JAX operations act on the content part. OTree is registered as a PyTree, so it must decide how to handle each of its data fields. To this end, OTree fields are partitioned into four field types:
- Auxiliary fields
These fields are part of the PyTree definition. They are meant to be static metadata describing the characteristics of the OTree, and can be declared explicitly with ojax.aux().
- Children fields
These fields hold the numerical content of the OTree, which is what JAX operations act on. They are typically JAX arrays and sub-PyTrees, and can be declared explicitly with ojax.child().
- Ignored fields
These dataclass fields are omitted from the PyTree. They are declared with ojax.ignore().
- Alien fields (since 3.1.0)
These dataclass fields are incompatible with PyTree flattening / unflattening: attempting such operations raises ojax.AlienException. They are declared with ojax.alien().
Warning
Ignored fields are not preserved by a jax.tree.flatten followed by jax.tree.unflatten round trip. Since common JAX operations rely on this combination to handle PyTrees, ignored fields are easily lost. Stick with auxiliary and children fields by default, and use alien fields to declare incompatible fields.
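To make the partition and this warning concrete, here is a minimal pure-Python sketch. It is not OJAX's implementation and does not use JAX's PyTree registry: it assumes a hypothetical scheme in which each field's kind is stored in dataclass metadata under a made-up "kind" key, and the tagged, flatten and unflatten helpers are illustrative only. Children form the content, aux values go into the definition, and the round trip drops the ignored field, as the warning describes.

```python
import dataclasses
from dataclasses import dataclass, field


def tagged(kind, **kwargs):
    # Hypothetical helper: record the field kind in dataclass metadata.
    return field(metadata={"kind": kind}, **kwargs)


@dataclass(frozen=True)
class Model:
    size: int = tagged("aux")                             # PyTree definition
    scratch: str = tagged("ignore")                       # omitted from the PyTree
    params: list = tagged("child", default_factory=list)  # PyTree content


def flatten(tree):
    """Split fields into (children, definition); ignored fields are dropped."""
    kinds = {f.name: f.metadata.get("kind") for f in dataclasses.fields(tree)}
    children = {n: getattr(tree, n) for n, k in kinds.items() if k == "child"}
    aux = {n: getattr(tree, n) for n, k in kinds.items() if k == "aux"}
    return children, (type(tree), aux)


def unflatten(definition, children):
    """Rebuild an instance from the definition and the children only."""
    cls, aux = definition
    new = object.__new__(cls)  # skip __init__ and write fields directly
    for name, value in {**aux, **children}.items():
        object.__setattr__(new, name, value)
    return new


m = Model(size=2, scratch="debug info", params=[1.0, 2.0])
children, definition = flatten(m)
restored = unflatten(definition, children)
print(restored.size, restored.params)  # 2 [1.0, 2.0]
print(hasattr(restored, "scratch"))    # False: the ignored field was lost
```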
For fields without an explicit field type declaration, OTree infers the field type from the annotation: subclasses of jax.Array and ojax.OTree are assumed to be child nodes, while everything else is assumed to be an aux node. This inference logic is implemented in the ojax.OTree.__infer_otree_field_type__() method and can be overridden by subclasses.
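As an illustration of this kind of rule, here is a hypothetical sketch; it is not the actual body of __infer_otree_field_type__(). So that it runs without JAX, it uses the stand-in classes Array and OTreeLike in place of jax.Array and ojax.OTree. It also shows why a parameterized annotation such as list[Array] falls back to aux, which is the situation covered by the warning that follows.

```python
import typing


class Array:
    """Hypothetical stand-in for jax.Array."""


class OTreeLike:
    """Hypothetical stand-in for ojax.OTree."""


def infer_field_type(annotation) -> str:
    """Guess 'child' or 'aux' from a field annotation, following the rule above."""
    if typing.get_origin(annotation) is not None:
        # Parameterized generics such as list[Array] are not inspected further.
        return "aux"
    if isinstance(annotation, type) and issubclass(annotation, (Array, OTreeLike)):
        return "child"
    return "aux"


class Linear(OTreeLike):
    pass


print(infer_field_type(Array))        # child
print(infer_field_type(Linear))       # child
print(infer_field_type(int))          # aux
print(infer_field_type(list[Array]))  # aux: a list of arrays is not detected
print(infer_field_type(typing.Any))   # aux
```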
Warning
Non-OTree PyTrees, such as lists of JAX arrays, are not automatically detected as child nodes. For your convenience, the current inference logic conservatively handles only the obvious cases. For more complex cases, declare child fields explicitly.
The __init__ method
After declaration comes instantiation. Conveniently, Python dataclasses can automatically generate an __init__ method from the field declarations; it simply assigns each argument to the corresponding field. For instance, the following commented-out segment from the full example shows a custom __init__ method equivalent to the one the dataclass generates automatically.
```python
# # optional here, automatically generated by the dataclass
# def __init__(self, input_features: int, output_features: int):
#     self.assign_(
#         input_features=input_features, output_features=output_features
#     )
```
This automatic generation is convenient for basic cases, but it falls short in more complex ones, such as when additional checks are needed or with some inheritance patterns [1]. In these cases, implement a custom __init__ method instead.
Note that the usual self.field = value assignment pattern does not work for OTree, since it is a frozen dataclass. Instead, OTree offers an .assign_ method for this purpose (inherited from ojax.PureClass.assign_(), where ojax.PureClass defines an immutable dataclass).
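The frozen-dataclass behavior behind this can be seen with the standard library alone. The sketch below assumes, as is common for frozen dataclasses, that an assign_-style helper writes through object.__setattr__. This is an assumption about the mechanism, not a description of OJAX's code, and both the Scaled class and its _assign helper are hypothetical. It also shows a custom __init__ with an extra check, which the dataclass decorator keeps instead of generating its own.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class Scaled:
    factor: float
    label: str

    def __init__(self, factor: float):
        # Custom __init__ with an extra check; @dataclass keeps it because
        # __init__ is already defined in the class body.
        if factor <= 0:
            raise ValueError("factor must be positive")
        self._assign(factor=factor, label=f"x{factor}")

    def _assign(self, **values):
        # Hypothetical stand-in for an assign_-style helper: bypasses the frozen check.
        for name, value in values.items():
            object.__setattr__(self, name, value)


s = Scaled(2.0)
print(s.label)  # x2.0

try:
    s.factor = 3.0  # the usual pattern is rejected on a frozen dataclass
except dataclasses.FrozenInstanceError as err:
    print("rejected:", err)
```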
Note
Since OTree is a dataclass, it is also possible to customize the initialization of OTrees with the dataclass __post_init__ method and dataclasses.InitVar. However, this requires more familiarity with how Python dataclasses work and can be less intuitive than a custom __init__ method for many Python users. Moreover, this approach can be problematic when inheriting from a base class with a custom __init__.
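For comparison, here is the __post_init__ / InitVar route on a plain frozen dataclass (standard library only; Standardizer is a hypothetical class). Even in this small case, the frozen class forces the post-init hook to fall back on object.__setattr__ to set the derived field, which illustrates why a custom __init__ can be easier to follow.

```python
from dataclasses import InitVar, dataclass, field, fields


@dataclass(frozen=True)
class Standardizer:
    mean: float
    std: float = field(init=False)
    variance: InitVar[float] = 1.0  # passed to __init__, not stored as a field

    def __post_init__(self, variance: float):
        # Runs after the generated __init__; the class is frozen, so the
        # derived field has to be written through object.__setattr__.
        if variance <= 0:
            raise ValueError("variance must be positive")
        object.__setattr__(self, "std", variance ** 0.5)


z = Standardizer(mean=0.0, variance=4.0)
print(z.std)  # 2.0
```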
Updating the fields
During JAX computations, it is sometimes desirable to update the numerical fields of OTrees. For this, OTree provides the ojax.OTree.update() method, which returns a new OTree instance whose children fields hold the specified new numerical data. The following code segment from the full example illustrates how it is used:
```python
class Dense(OTree):
    def update_parameters(self, weight, bias):
        assert weight.shape == (self.output_features, self.input_features)
        assert bias.shape == (self.output_features,)
        return self.update(weight=weight, bias=bias)
```
ojax.OTree.update() preserves the PyTree structure and disallows updating auxiliary fields. This is usually the intended behavior, since JAX operations don't alter the PyTree structure either. Preserving the structure is also required for some arguments of JAX functions, such as the carry of jax.lax.scan.
Note
None is a special empty PyTree container. Thus, updating a numerical field to None, or vice versa, changes the PyTree structure and triggers an error from the .update() method.
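The rules above can be mimicked in plain Python. The sketch below is a hypothetical update method, not OJAX's code. It assumes child and aux kinds are tagged through dataclass metadata (a made-up "kind" key) and uses a hypothetical Params class. It works on a shallow copy, rejects auxiliary fields, and rejects switching a field to or from None, since that would change the structure.

```python
import copy
import dataclasses
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Params:
    n: int = field(metadata={"kind": "aux"})
    weight: object = field(default=None, metadata={"kind": "child"})

    def update(self, **changes):
        """Hypothetical sketch: replace child fields on a shallow copy."""
        kinds = {f.name: f.metadata.get("kind") for f in dataclasses.fields(self)}
        new = copy.copy(self)  # shallow copy; __init__ is not re-run
        for name, value in changes.items():
            if kinds.get(name) != "child":
                raise ValueError(f"{name!r} is not a child field")
            if (getattr(self, name) is None) != (value is None):
                # None is an empty PyTree node, so this would change the structure.
                raise ValueError(f"updating {name!r} would change the structure")
            object.__setattr__(new, name, value)
        return new


p = Params(n=3, weight=[0.0, 0.0, 0.0])
q = p.update(weight=[1.0, 2.0, 3.0])
print(p.weight, q.weight)  # [0.0, 0.0, 0.0] [1.0, 2.0, 3.0]
```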
If you need to create a derived OTree with different auxiliary data or a different PyTree structure, use ojax.new() instead.
Note
Python dataclasses also offer the dataclasses.replace function for updating dataclass fields. However, like ojax.new(), it does not guarantee that the PyTree structure is preserved. Furthermore, it calls __init__ again to create the new instance, which can be problematic for OTrees with a custom __init__ method. ojax.OTree.update() and ojax.new() rely on shallow copying instead and don't have this issue.
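The __init__ issue is easy to observe with the standard library alone. In the sketch below, the Layer class and its init_calls counter are hypothetical. dataclasses.replace runs the custom __init__ again and rejects a field declared with init=False, while copy.copy followed by object.__setattr__ never touches __init__. The shallow-copy step only illustrates the idea; this sketch does not claim to show how ojax.OTree.update() and ojax.new() perform it internally.

```python
import copy
import dataclasses
from dataclasses import dataclass, field
from typing import ClassVar


@dataclass(frozen=True)
class Layer:
    size: int
    weight: tuple = field(init=False)
    init_calls: ClassVar[int] = 0  # counts __init__ calls (class variable, not a field)

    def __init__(self, size: int):
        type(self).init_calls += 1
        object.__setattr__(self, "size", size)
        object.__setattr__(self, "weight", (0.0,) * size)


layer = Layer(2)

# dataclasses.replace re-runs the custom __init__ ...
resized = dataclasses.replace(layer, size=3)
print(Layer.init_calls)  # 2

# ... and refuses fields declared with init=False.
try:
    dataclasses.replace(layer, weight=(1.0, 2.0))
except ValueError as err:
    print("replace failed:", err)

# A shallow copy sidesteps __init__ entirely.
updated = copy.copy(layer)
object.__setattr__(updated, "weight", (1.0, 2.0))
print(Layer.init_calls, updated.weight)  # 2 (1.0, 2.0)
```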
Warning
The .assign_ method introduced earlier should not be used to update the fields of existing OTrees. It is a low-level in-place operation meant only for use within initialization functions; otherwise the immutable paradigm is easily violated, potentially causing subtle bugs in JAX code.
Adding JAX methods
Of course, you are free to add custom methods just as you would for vanilla Python classes. Here is an example (another excerpt from the full example):
```python
class Dense(OTree):
    def forward(self, input_array):
        return jnp.inner(input_array, self.weight) + self.bias
```
Since an OTree is also a PyTree that JAX operations can handle, all its methods (which take self as their first argument) work directly with JAX transforms, including jax.jit.
Although Python offers no easy way to enforce it, OTree methods should be pure JAX functions to avoid potential issues when working with JAX.
General coding tips
To avoid “bad surprises” when coding with JAX / OJAX, always follow the two principles below:
- Data should be persistent: no in-place operations; data should be immutable.
- Functions should be pure: no side effects; a function called with the same arguments should always return the same result, without any other effects.
These principles are the basis of functional programming and are assumed by JAX; however, Python, as an OOP language, does not enforce them.