OJAX Quickstart#

The core component of OJAX is the ojax.OTree class, which represents an immutable PyTree and uses Python dataclass field declarations. It is implemented as a frozen dataclass, a standard Python feature, and serves as the base class for all custom JAX-compatible classes.

Declaring annotated fields#

Unlike a standard Python class, where fields are added dynamically via self.field_name = value, OTree adopts the annotated field syntax of Python dataclasses and expects you to declare fields with type annotations. This is more in line with the functional paradigm employed by JAX.

Here is an example (code excerpt from the full example):

from dataclasses import field

import jax.numpy as jnp
from ojax import OTree, aux, child

class Dense(OTree):
    input_features: int  # inferred to be auxiliary data
    output_features: int = aux()  # or explicit declaration
    weight: jnp.ndarray = field(default=..., init=False)  # inferred as child
    bias: jnp.ndarray = child(default=..., init=False)  # explicit declaration

Annotated fields can have the following patterns:

Note

Type annotation is required because Python dataclass uses it to identify fields. Attributes without type annotation are ignored by dataclass and become class variables instead of dataclass fields. OTree raises a warning for you in this case so don’t worry about accidentally making class variables instead of fields :)

If you want to create a class variable, declare it explicitly with typing.ClassVar so you won’t be nagged by this warning.
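The difference between annotated fields, unannotated attributes, and typing.ClassVar can be checked with a plain standard-library dataclass. The sketch below does not use OJAX, so it does not show the OTree warning, and the class name Example and its attributes are made up for illustration.

```python
import dataclasses
from typing import ClassVar


@dataclasses.dataclass(frozen=True)
class Example:
    width: int = 3              # annotated: becomes a dataclass field
    scale = 2.0                 # no annotation: a plain class variable, not a field
    version: ClassVar[int] = 1  # explicit class variable, also not a field


# Only the annotated, non-ClassVar attribute is reported as a field.
field_names = [f.name for f in dataclasses.fields(Example)]
print(field_names)
```

Both scale and version remain accessible as Example.scale and Example.version, but neither takes part in the generated dataclass machinery.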

Note

Accurate type annotations are helpful but not required, since Python does not check them. That said, annotations can help OTree infer how to handle the fields as a PyTree. No need to think too hard though: declaring everything as typing.Any is also fine, and the PyTree handling can be specified explicitly in any case.

The explicit declarations ojax.child() / ojax.aux() / ojax.ignore() / ojax.alien() are variants of dataclasses.field that also specify how an OTree should handle a field as a PyTree. Let’s discuss this point further.

Field types for OTree#

A PyTree is the data structure JAX uses to operate on collections of data. It is composed of a definition part and a content part, and JAX operations act on the content part. OTree is registered as a PyTree and must therefore decide how to handle each of its data fields. To this end, fields in OTree are partitioned into four field types:

  • Auxiliary fields

    These are the fields that will be part of the PyTree definition. They are supposed to be static metadata that describe the characteristics of the OTree. They can be explicitly declared with ojax.aux().

  • Children fields

    These are the numerical contents of the OTree, which are the subject of JAX operations. They are typically JAX arrays and sub-PyTrees and can be marked explicitly with ojax.child().

  • Ignored fields

    These are dataclass fields that are omitted by the PyTree. They are declared with ojax.ignore().

  • Alien fields (since 3.1)

    These are dataclass fields that are incompatible with PyTree flattening. ojax.AlienException is raised when flattening is attempted on an Alien field holding a value that is not None. They are declared with ojax.alien().

Warning

Ignored fields are not preserved when an OTree goes through jax.tree.flatten followed by jax.tree.unflatten. Since this round trip is used by common JAX operations to handle PyTrees, ignored fields are easily lost. Users should stick with auxiliary and children fields by default, or use alien fields to declare incompatible fields.
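To see why ignored fields disappear, it can help to model the flatten and unflatten round trip by hand. The following is a toy model in plain Python and is not how JAX or OJAX implement flattening: the "kind" metadata key and the helpers names_of, toy_flatten, and toy_unflatten are hypothetical names invented for this sketch.

```python
from dataclasses import dataclass, field, fields


@dataclass(frozen=True)
class Layer:
    # The "kind" metadata key is a hypothetical marker used only in this sketch.
    features: int = field(metadata={"kind": "aux"})
    weight: tuple = field(metadata={"kind": "child"})
    note: str = field(default="", metadata={"kind": "ignore"})


def names_of(cls, kind):
    """Names of the fields of `cls` tagged with the given kind."""
    return [f.name for f in fields(cls) if f.metadata["kind"] == kind]


def toy_flatten(tree):
    """Split a tree into its content (children) and its definition (class + aux)."""
    cls = type(tree)
    children = [getattr(tree, name) for name in names_of(cls, "child")]
    aux = {name: getattr(tree, name) for name in names_of(cls, "aux")}
    return children, (cls, aux)


def toy_unflatten(definition, children):
    """Rebuild a tree from its definition and new children; ignored fields are not carried."""
    cls, aux = definition
    return cls(**aux, **dict(zip(names_of(cls, "child"), children)))


layer = Layer(features=2, weight=(1.0, 2.0), note="debug info")
children, definition = toy_flatten(layer)

# An operation acts on the content part only, then the tree is rebuilt.
doubled = [tuple(2 * v for v in w) for w in children]
restored = toy_unflatten(definition, doubled)
```

In this model the auxiliary value travels with the definition and the child values are transformed, while note falls back to its default because nothing carried it across the round trip.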

For fields without explicit field type declaration, OTree infers the field type based on the field annotation: subclasses of jax.Array and ojax.OTree are assumed to be child nodes, while the rest are assumed to be aux nodes. This inference logic is specified in the ojax.OTree.__infer_otree_field_type__() method and can be overridden by subclasses.

Warning

Non-OTree PyTrees such as lists of JAX arrays are not automatically detected as child nodes. The current inference logic only conservatively tackles the obvious case for your convenience. You need to explicitly handle child node declarations for more complex cases.
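The annotation-based rule described above, and the reason a list of arrays slips through it, can be sketched in a few lines. This is only an illustration under stated assumptions: ArrayLike and TreeLike are stand-ins for jax.Array and ojax.OTree, and infer_field_kind is a hypothetical function, not the actual body of ojax.OTree.__infer_otree_field_type__().

```python
from typing import Any, List


class ArrayLike:
    """Stand-in for jax.Array in this sketch."""


class TreeLike:
    """Stand-in for ojax.OTree in this sketch."""


class MyTree(TreeLike):
    """A user-defined subclass of the stand-in tree type."""


def infer_field_kind(annotation: Any) -> str:
    """Classify a field from its annotation: known array/tree classes are children."""
    if isinstance(annotation, type) and issubclass(annotation, (ArrayLike, TreeLike)):
        return "child"
    # Everything else, including generic aliases such as List[ArrayLike], is aux.
    return "aux"


kinds = {
    "int": infer_field_kind(int),
    "array": infer_field_kind(ArrayLike),
    "subtree": infer_field_kind(MyTree),
    "list_of_arrays": infer_field_kind(List[ArrayLike]),
}
print(kinds)
```

Under a rule of this shape, List[ArrayLike] is not a class at all, so it is classified as auxiliary unless the field is declared as a child explicitly.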

The __init__ method#

After the declaration comes the instantiation part. The following code segment from the full example shows an example __init__ method.

    # use .assign_ only in __init__ function
    def __init__(self, input_features: int, output_features: int):
        self.assign_(
            input_features=input_features, output_features=output_features
        )

One thing to note is that the usual self.field = value assignment pattern is no longer possible for OTree, as it is a frozen dataclass. Instead, OTree offers an .assign_ method to achieve this (inherited from the ojax.PureClass.assign_() method, where ojax.PureClass defines an immutable dataclass).
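The restriction comes from the frozen dataclass itself and can be reproduced without OJAX. The sketch below uses only the standard library. Bypassing the frozen check with object.__setattr__ inside __init__ is a common general technique, and whether .assign_ works this way internally is an assumption here. The Point class is made up for illustration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True, init=False)
class Point:
    x: float
    y: float

    def __init__(self, x, y):
        # `self.x = x` would raise FrozenInstanceError, so initialization
        # has to go through a lower-level setter instead.
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)


p = Point(1.0, 2.0)

try:
    p.x = 5.0  # ordinary assignment on a frozen dataclass instance
    assignment_allowed = True
except dataclasses.FrozenInstanceError:
    assignment_allowed = False

print(p, assignment_allowed)
```

Once construction is finished, the instance rejects ordinary assignment, which is the behavior that makes the usual self.field = value pattern unavailable.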

Note

Standard dataclasses can generate the __init__ method automatically, and one can use the __post_init__ method and dataclasses.InitVar to customize the initialization process of OTrees. While this might be convenient in simple cases, it requires more expertise with the Python dataclass machinery and can be problematic, especially with class inheritance (e.g., [1]). Moreover, the automatically generated __init__ silently replaces the __init__ method that would otherwise be inherited from the parent class, which is error-prone. Thus OJAX has disabled the automatic __init__ method generation feature (since 3.1).
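The inheritance pitfall mentioned in this note can be reproduced with plain standard-library dataclasses. The classes Base, Child, and ChildNoInit below are made up for illustration and are unrelated to OJAX.

```python
import dataclasses


@dataclasses.dataclass(init=False)
class Base:
    size: int

    def __init__(self, size):
        # Custom initialization logic that subclasses might expect to inherit.
        self.size = size * 2


@dataclasses.dataclass  # init=True by default: a new __init__ is generated
class Child(Base):
    pass


@dataclasses.dataclass(init=False)  # keeps the inherited __init__
class ChildNoInit(Base):
    pass


generated = Child(3)
inherited = ChildNoInit(3)
print(generated.size, inherited.size)
```

Child silently loses the doubling logic because its generated __init__ assigns size directly, while ChildNoInit still runs Base.__init__.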

Updating the fields#

During JAX computations, it is sometimes desirable to update the numerical fields of OTrees. To achieve this, OTree provides the ojax.OTree.update() method, which returns an updated OTree instance with the specified new numerical data for the children fields. The following code segment from the full example illustrates how it is used:

class Dense(OTree):
    def update_parameters(self, weight, bias):
        assert weight.shape == (self.output_features, self.input_features)
        assert bias.shape == (self.output_features,)
        return self.update(weight=weight, bias=bias)

ojax.OTree.update() preserves the PyTree structure and disallows updating auxiliary fields. This is usually the intended behavior since JAX operations don’t alter the PyTree structure either. It is also required for some arguments in JAX functions such as jax.lax.scan.

Note

None is a special empty PyTree container. Changing a numerical field to None, or vice versa, therefore changes the PyTree structure and triggers an error from the .update() method.

In case you need to create a derived OTree with different auxiliary data or PyTree structure, the ojax.new() method should be used instead.

Note

Python dataclasses also offer the dataclasses.replace function for updating dataclass fields. However, like ojax.new(), it does not preserve the PyTree structure. Furthermore, it re-calls the __init__ method to generate the new instance, which can be problematic for OTrees with a custom __init__ method. ojax.OTree.update() and ojax.new() rely on a shallow copy instead and do not have this issue.
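The __init__ problem with dataclasses.replace is easy to trigger with a standard-library dataclass whose constructor arguments differ from its fields. The sketch below contrasts it with a shallow-copy approach. The helper replace_by_copy is a hypothetical illustration of the general idea and is not the implementation of ojax.OTree.update() or ojax.new().

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True, init=False)
class Scaled:
    value: float

    def __init__(self, raw):
        # The constructor argument `raw` is not a field name.
        object.__setattr__(self, "value", raw * 10)


def replace_by_copy(obj, **changes):
    """Return a shallow copy of `obj` with some fields overwritten, without calling __init__."""
    new = copy.copy(obj)
    for name, new_value in changes.items():
        object.__setattr__(new, name, new_value)
    return new


original = Scaled(1.0)

try:
    # dataclasses.replace calls Scaled(value=3.0), which __init__ does not accept.
    dataclasses.replace(original, value=3.0)
    replace_failed = False
except TypeError:
    replace_failed = True

updated = replace_by_copy(original, value=3.0)
print(replace_failed, original.value, updated.value)
```

The copy-based helper leaves the original instance untouched and never re-runs the custom constructor.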

Warning

The .assign_ method introduced previously should not be used for updating the fields of OTrees. It is a low-level in-place operation that should only be used within the initialization functions. Otherwise the immutable paradigm is easily violated, potentially creating issues and subtle bugs for JAX code.

Adding JAX methods#

Of course, you are free to add custom methods just as you would for vanilla Python classes. Here is an example (another excerpt from the full example):

class Dense(OTree):
    def forward(self, input_array):
        return jnp.inner(input_array, self.weight) + self.bias

As an OTree is also a PyTree that JAX operations can handle, all its methods (which have self as their first argument) can directly work with JAX transforms including jax.jit.

While there is no easy way to enforce it in Python, the OTree methods should be pure JAX functions so as to avoid potential issues when working with JAX.

General coding tips#

To avoid “bad surprises” when coding with JAX / OJAX, always follow the two principles below:

  • Data should be persistent: no in-place operations, data should be immutable.

  • Functions should be pure: no side effects. A function called with the same arguments should always return the same result and have no other effects.

These principles are the basis of functional programming and are assumed by JAX. However, they are not enforced by Python, which is an object-oriented language.
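The two principles can be made concrete with a small plain-Python comparison. All function and variable names below are made up for illustration, and the sketch deliberately avoids JAX so that it only shows the coding style.

```python
call_log = []


def impure_scale(x, factor):
    """Impure: mutates global state as a side effect."""
    call_log.append(x)
    return x * factor


def pure_scale(x, factor):
    """Pure: the result depends only on the arguments and nothing else changes."""
    return x * factor


def inplace_append(items, value):
    """Not persistent: modifies the caller's list in place."""
    items.append(value)
    return items


def persistent_append(items, value):
    """Persistent: returns a new tuple and leaves the input untouched."""
    return items + (value,)


history = [1, 2]
inplace_append(history, 3)             # `history` itself is changed

frozen_history = (1, 2)
extended = persistent_append(frozen_history, 3)

impure_scale(2.0, 3.0)                 # leaves a trace in `call_log`
result = pure_scale(2.0, 3.0)
```

Code written in the persistent and pure style is easier to reason about under tracing and compilation, because calling a function again with the same inputs cannot change its outcome.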

Additionally, to make your code jax.jit friendly, be mindful of control flow (e.g., if, for, while), which might break or slow down compilation, and consider using JAX structured control flow primitives.

OJAX helps you follow these principles: it is designed to strongly encourage a persistent data style, and its codebase is jax.jit friendly.