OJAX Quickstart#

The core component of OJAX is the ojax.OTree class, which represents an immutable PyTree whose fields are declared with Python dataclass syntax. It is implemented as a frozen dataclass, a standard Python feature, and serves as the base class for all custom JAX-compatible classes.

Declaring annotated fields#

Unlike a standard Python class, where fields are added dynamically via self.field_name = value, OTree adopts the annotated field syntax of Python dataclasses and expects you to declare fields with type annotations. This is more in line with the functional paradigm employed by JAX.

Here is an example (code excerpt from the full example):

from dataclasses import field
import jax.numpy as jnp
from ojax import OTree, aux, child

class Dense(OTree):
    input_features: int  # inferred to be auxiliary data
    output_features: int = aux()  # or explicit declaration
    weight: jnp.ndarray = field(default=..., init=False)  # inferred as child
    bias: jnp.ndarray = child(default=..., init=False)  # explicit declaration

Annotated fields can have the following patterns:

Note

Type annotations are required because Python dataclasses use them to identify fields. Attributes without a type annotation are ignored by dataclass and become class variables instead of dataclass fields. OTree raises a warning in this case, so don’t worry about accidentally creating class variables instead of fields :)

If you want to create a class variable, declare it explicitly with typing.ClassVar so you won’t be nagged by this warning.
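A minimal sketch of the behavior described in this note, using only the standard dataclasses module rather than OTree. The Config class and its attribute names are hypothetical, and the warning raised by OTree is not reproduced here:

```python
from dataclasses import dataclass, fields
from typing import ClassVar


@dataclass(frozen=True)
class Config:  # hypothetical class, for illustration only
    width: int = 4                 # annotated: becomes a dataclass field
    depth = 2                      # not annotated: a plain class variable
    kind: ClassVar[str] = "dense"  # explicit class variable, also not a field


# Only the annotated, non-ClassVar attribute is registered as a field.
field_names = [f.name for f in fields(Config)]

# Class variables are not accepted by the generated __init__.
try:
    Config(depth=3)
    depth_accepted = True
except TypeError:
    depth_accepted = False
```

Here field_names contains only "width", while depth and kind remain shared class-level attributes.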

Note

Accurate type annotations are helpful but not required, since Python does not check them. That said, annotations help OTree infer how to handle the fields as a PyTree. There is no need to think too hard, though: declaring everything as typing.Any is also fine, and the PyTree handling can always be specified explicitly.

In the last pattern, ojax.child() / ojax.aux() / ojax.ignore() are variants of dataclasses.field that also specify how an OTree should handle a field as a PyTree. Let’s discuss this point further.

Field types for OTree#

PyTree is the data structure used by JAX to operate on data collections. It is composed of a definition part and a content part, and JAX operations act on the content part. OTree is registered as a PyTree, and thus should decide on how to handle its data fields. For this, fields in OTree are partitioned into three field types:

  • Auxiliary fields

    These are the fields that will be part of the PyTree definition. They are supposed to be static metadata that describe the characteristics of the OTree. They can be explicitly declared with ojax.aux().

  • Children fields

    These are the numerical contents of the OTree and are the subject of JAX operations. They are typically JAX arrays and sub-PyTrees, and can be marked explicitly with ojax.child().

  • Ignored fields

    These are dataclass fields that are not part of the PyTree. They are declared with ojax.ignore().

Warning

Ignored fields are not preserved after a jax.tree.flatten followed by a jax.tree.unflatten. Since this round trip is used by common JAX operations to handle PyTrees, ignored fields are easily lost. Users should stick with auxiliary and children fields by default.
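The three field types can be illustrated without OJAX by registering a plain class as a PyTree by hand. This is only a sketch under assumptions: the Layer class and its attributes are hypothetical, and the registration goes through jax.tree_util directly, which is not necessarily how OTree is implemented.

```python
import jax
import jax.numpy as jnp


class Layer:  # hypothetical class, registered manually as a PyTree
    def __init__(self, features, weight, note=None):
        self.features = features  # plays the role of an auxiliary field
        self.weight = weight      # plays the role of a child field
        self.note = note          # plays the role of an ignored field


def _flatten(layer):
    # Children carry the content; aux data becomes part of the definition.
    return (layer.weight,), layer.features


def _unflatten(features, children):
    # The ignored attribute is not stored anywhere, so it cannot be restored.
    return Layer(features, children[0])


jax.tree_util.register_pytree_node(Layer, _flatten, _unflatten)

layer = Layer(3, jnp.ones(3), note="scratch")
leaves, treedef = jax.tree_util.tree_flatten(layer)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

After the round trip, restored keeps features and weight, but restored.note is back to None because it was never part of the PyTree.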

For fields without explicit field type declaration, OTree infers the field type based on the field annotation: subclasses of jax.Array and ojax.OTree are assumed to be child nodes, while the rest are assumed to be aux nodes. This inference logic is specified in the ojax.OTree.__infer_otree_field_type__() method and can be overridden by subclasses.

Warning

Non-OTree PyTrees such as lists of JAX arrays are not automatically detected as child nodes. The current inference logic is deliberately conservative and only handles the obvious cases for your convenience. For more complex cases, you need to declare child nodes explicitly.
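To see why a container annotation is not picked up, here is a hypothetical sketch of annotation-based inference. It is not OJAX's actual implementation: Array and Tree are stand-ins for jax.Array and ojax.OTree, and infer_field_type is an invented name.

```python
import inspect
from typing import Any, List


class Array:  # stand-in for jax.Array (assumption, not the real class)
    pass


class Tree:  # stand-in for ojax.OTree (assumption, not the real class)
    pass


def infer_field_type(annotation) -> str:
    """Return "child" for array-like or tree-like classes, "aux" otherwise."""
    if inspect.isclass(annotation) and issubclass(annotation, (Array, Tree)):
        return "child"
    # Generic aliases such as List[Array] are not classes, so they land here.
    return "aux"


inferred = {
    "weight": infer_field_type(Array),
    "features": infer_field_type(int),
    "weights": infer_field_type(List[Array]),
    "anything": infer_field_type(Any),
}
```

Under this kind of rule, a field annotated as a list of arrays falls through to the auxiliary case, which is why an explicit child declaration is needed for it.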

The __init__ method#

After the declaration comes the instantiation. Conveniently, Python dataclass can automatically generate an __init__ method from the field declarations. The generated method directly assigns its arguments to the corresponding fields. For instance, the following commented-out code segment from the full example shows a custom __init__ method that implements what the dataclass generates automatically.

    # # optional here, automatically generated by the dataclass
    # def __init__(self, input_features: int, output_features: int):
    #     self.assign_(
    #         input_features=input_features, output_features=output_features
    #     )

This automatic generation is convenient for basic cases. However, it reaches its limits in more complex cases, such as when additional checks are needed or in some inheritance scenarios [1]. In these cases, one should implement a custom __init__ method instead.

One thing to note is that the usual self.field = value assignment pattern is no longer possible for OTree, as it is a frozen dataclass. Instead, OTree offers an .assign_ method to achieve this (inherited from the ojax.PureClass.assign_() method, where ojax.PureClass defines an immutable dataclass).
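The following sketch shows the same constraint with a plain frozen dataclass. It does not use OJAX: the _assign helper is a hypothetical stand-in for the role played by .assign_, built on object.__setattr__, and the weight placeholder is an assumption made to keep the example dependency-free.

```python
from dataclasses import FrozenInstanceError, dataclass, field


def _assign(obj, **kwargs):
    # Hypothetical helper: bypasses the frozen __setattr__ during initialization.
    for name, value in kwargs.items():
        object.__setattr__(obj, name, value)


@dataclass(frozen=True, init=False)
class Dense:
    input_features: int
    output_features: int
    weight: tuple = field(default=(), init=False)

    def __init__(self, input_features, output_features):
        # A custom __init__ leaves room for additional checks.
        if input_features <= 0 or output_features <= 0:
            raise ValueError("feature counts must be positive")
        # self.input_features = input_features would fail here: the class is frozen.
        _assign(self, input_features=input_features,
                output_features=output_features, weight=())


layer = Dense(3, 2)
try:
    layer.input_features = 5  # regular assignment is rejected
    assignment_rejected = False
except FrozenInstanceError:
    assignment_rejected = True
```

Regular attribute assignment raises FrozenInstanceError, so all initialization has to go through a dedicated helper.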

Note

As OTree is a dataclass, it is also possible to use the __post_init__ method and dataclasses.InitVar to customize the initialization of OTrees. However, this requires more expertise with the Python dataclass machinery and can be less intuitive than a custom __init__ method for many Python users. Moreover, inheriting from a base class with a custom __init__ can be problematic.
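For comparison, here is a rough sketch of the __post_init__ and InitVar route on a plain frozen dataclass. The class, the scale parameter, and the tuple-based weight are hypothetical choices for illustration, not taken from the OJAX example:

```python
from dataclasses import InitVar, dataclass, field, fields


@dataclass(frozen=True)
class Dense:
    input_features: int
    output_features: int
    scale: InitVar[float] = 1.0  # init-only argument, not stored as a field
    weight: tuple = field(default=(), init=False)

    def __post_init__(self, scale):
        # InitVar values are passed to __post_init__ by the generated __init__.
        rows = tuple(
            (scale,) * self.input_features for _ in range(self.output_features)
        )
        # The class is frozen, so the field is set via object.__setattr__.
        object.__setattr__(self, "weight", rows)


layer = Dense(2, 3, scale=0.5)
stored_fields = [f.name for f in fields(layer)]
```

The generated __init__ is kept, and scale only exists during initialization: it does not appear in stored_fields.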

Updating the fields#

During JAX computations, it is sometimes desirable to update the numerical fields of OTrees. To achieve this, OTree provides the ojax.OTree.update() method, which returns an updated OTree instance with the specified new numerical data for the children fields. The following code segment from the full example illustrates how it is used:

class Dense(OTree):
    def update_parameters(self, weight, bias):
        assert weight.shape == (self.output_features, self.input_features)
        assert bias.shape == (self.output_features,)
        return self.update(weight=weight, bias=bias)

ojax.OTree.update() preserves the PyTree structure and disallows updating auxiliary fields. This is usually the intended behavior since JAX operations don’t alter the PyTree structure either. It is also required for some arguments in JAX functions such as jax.lax.scan.

Note

None is a special empty PyTree container. Thus updating numerical fields to None or vice versa will change the PyTree structure and trigger an error from the .update() method.
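The effect of None on the PyTree structure can be checked with JAX alone. This small sketch uses plain dictionaries instead of OTrees, with hypothetical key names:

```python
import jax
import jax.numpy as jnp

with_bias = {"weight": jnp.ones((2, 3)), "bias": jnp.zeros(2)}
without_bias = {"weight": jnp.ones((2, 3)), "bias": None}

# None contributes no leaves at all.
none_leaves = jax.tree_util.tree_leaves(None)

# Replacing an array with None therefore changes the tree structure.
same_structure = (
    jax.tree_util.tree_structure(with_bias)
    == jax.tree_util.tree_structure(without_bias)
)
leaf_counts = (
    len(jax.tree_util.tree_leaves(with_bias)),
    len(jax.tree_util.tree_leaves(without_bias)),
)
```

The two dictionaries have different leaf counts and different tree definitions, which is the kind of structural change that a structure-preserving update has to reject.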

In case you need to create a derived OTree with different auxiliary data or PyTree structure, use the ojax.new() function instead.

Note

Again, Python dataclass offers the dataclasses.replace function for updating dataclass fields. However, like ojax.new(), it does not preserve the PyTree structure. Furthermore, it relies on calling the __init__ method again to generate the new instance, which can be problematic for OTrees with a custom __init__ method. ojax.OTree.update() and ojax.new() rely on a shallow copy instead and do not have this issue.
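The difference between the two strategies can be sketched with the standard library only. Buffer and with_values are hypothetical names, and the copy-based helper is merely an illustration of the shallow-copy idea, not the OJAX implementation:

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True, init=False)
class Buffer:
    size: int
    values: tuple = dataclasses.field(default=(), init=False)

    def __init__(self, size):
        # Custom __init__: values is derived from size, not passed in.
        object.__setattr__(self, "size", size)
        object.__setattr__(self, "values", (0,) * size)


def with_values(buffer, values):
    # Shallow-copy update: __init__ is not called again.
    new_buffer = copy.copy(buffer)
    object.__setattr__(new_buffer, "values", tuple(values))
    return new_buffer


original = Buffer(3)
updated = with_values(original, (1, 2, 3))

# dataclasses.replace calls Buffer(size=3) again, which resets values.
replaced = dataclasses.replace(updated, size=3)

# Fields declared with init=False cannot be passed to replace at all.
try:
    dataclasses.replace(updated, values=(9, 9, 9))
    replace_failed = False
except ValueError:
    replace_failed = True
```

The shallow copy keeps the updated values, whereas dataclasses.replace silently discards them by re-running the custom __init__.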

Warning

The .assign_ method introduced previously should not be used for updating the fields of OTrees. It is a low-level in-place operation that should only be used within the initialization functions. Otherwise the immutable paradigm is easily violated, potentially creating issues and subtle bugs for JAX code.

Adding JAX methods#

Of course, you are free to add custom methods, just as you would for vanilla Python classes. Here is an example (another excerpt from the full example):

class Dense(OTree):
    def forward(self, input_array):
        return jnp.inner(input_array, self.weight) + self.bias

As an OTree is also a PyTree that JAX operations can handle, all its methods (which have self as their first argument) can directly work with JAX transforms including jax.jit.
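As a rough illustration of this point without depending on OJAX, the sketch below registers a simplified Dense class as a PyTree through jax.tree_util.register_pytree_node_class (an assumption standing in for what OTree provides) and passes the unbound method to jax.jit so that self is treated as a regular PyTree argument:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Dense:  # simplified stand-in for an OTree subclass
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def tree_flatten(self):
        # Both arrays are children; there is no auxiliary data here.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def forward(self, input_array):
        return jnp.inner(input_array, self.weight) + self.bias


layer = Dense(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.arange(3.0)

# self is the first argument, so the instance is passed like any other PyTree.
jitted_forward = jax.jit(Dense.forward)
y = jitted_forward(layer, x)
```

The jitted call returns the same values as layer.forward(x); only the way the instance is passed differs.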

While there is no easy way to enforce it in Python, the OTree methods should be pure JAX functions so as to avoid potential issues when working with JAX.

General coding tips#

To avoid “bad surprises” when coding with JAX / OJAX, always follow the two principles below:

  • Data should be persistent: no in-place operations, data should be immutable.

  • Functions should be pure: no side effects; a function called with the same arguments should always return the same result and do nothing else.

These principles are the basis of functional programming and are assumed by JAX. However, they are not enforced in Python, which is an OOP language.
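A short sketch of what can go wrong, and of the persistent alternative, using plain JAX. The function and variable names are hypothetical:

```python
import jax
import jax.numpy as jnp

trace_log = []


@jax.jit
def impure_double(x):
    # Side effect: this line runs while JAX traces the function,
    # not on every call of the compiled version.
    trace_log.append("traced")
    return 2 * x


impure_double(jnp.ones(3))
impure_double(jnp.ones(3))  # same shape and dtype: the cached trace is reused

# Persistent data: .at[...].set(...) returns a new array, the original is untouched.
original = jnp.ones(3)
updated = original.at[0].set(5.0)
```

The side effect is recorded once even though the function is called twice, which is why impure functions tend to behave unexpectedly under JAX transforms.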