OJAX: make JAX “classy” again!
JAX, a Python-based numerical computing library, does not pair well with Python classes?! Fret not, the OJAX library provides a solution to this rather awkward situation.
The functional paradigm used by JAX is both a blessing and a curse: it enables powerful transformations like vmap and jit, but it also expects immutable and pure code, which is not the usual OOP style of Python. Since Python does not enforce this, JAX code can easily run into sharp edges and subtle bugs for the unwary, especially once one ventures into using Python classes, which are a core OOP structure.
This post introduces some context and motivation for OJAX, a Python library I wrote as a small extension of JAX. OJAX proposes a “class-like” base type that combines the best of both worlds:
- It is simple to use and flexible like vanilla Python classes. And it supports inheritance for code reuse.
- It is compatible with JAX transformations and is immutable like JAX arrays. It simply works with any JAX code without hassle.
Issues with JAX + class
Directly mingling OOP-style Python classes with FP-style JAX code can lead to hiccups. A typical issue many new JAX users will encounter is the fact that methods in vanilla Python classes cannot be jitted. Below is an example from the JAX FAQ section discussing this issue:
>>> import jax.numpy as jnp
>>> from jax import jit
>>> class CustomClass:
...     def __init__(self, x: jnp.ndarray, mul: bool):
...         self.x = x
...         self.mul = mul
...
...     @jit  # <---- How to do this correctly?
...     def calc(self, y):
...         if self.mul:
...             return self.x * y
...         return y
...
>>> c = CustomClass(2, True)
>>> c.calc(3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module>
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.
The same FAQ section offers three solutions to fix the jit issue. However, they are more like patches that introduce overhead for JAX programs using Python classes, and they serve as a discouraging reminder that mixing the two might not be best practice.
To class or not to class?
Given the incompatibilities between JAX and vanilla Python classes, should we simply avoid using Python classes when doing JAX coding? If not, how can we introduce classes into JAX code without worries and hassles? Let’s have a closer look in this section.
The function-only style
First of all, is it possible to implement JAX projects without using a Python class at all?
The answer is yes. You can actually write complex JAX code without using classes at all. A prominent example is the Stax deep learning library which is entirely composed of Python functions.
This is typically achieved with closures (or functionals), which are basically functions that return JAX functions. The arguments passed to the outer closures serve to configure the returned JAX functions.
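As a minimal sketch of this closure style, the CustomClass example from above can be rewritten without any class. The name make_calc is made up for illustration and does not come from Stax or any other library.

```python
import jax
import jax.numpy as jnp


def make_calc(x, mul: bool):
    """Outer function: its arguments configure the returned JAX function."""

    @jax.jit
    def calc(y):
        # `x` and `mul` are captured from the enclosing scope,
        # so they never appear in the signature of `calc`.
        if mul:
            return x * y
        return y

    return calc


calc = make_calc(jnp.asarray(2.0), mul=True)
result = calc(3.0)

passthrough = make_calc(jnp.asarray(2.0), mul=False)
unchanged = passthrough(3.0)
```

Note that x and mul live only in the closure: once make_calc has returned, there is no obvious way to inspect or change them, which is exactly the kind of implicit context criticized below.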
Is this a good style? Judging from my personal experience, I would say no. In fact, this was the style I used for my first serious JAX project, which implemented an early version of the IVON optimizer (it won first place in the NeurIPS BDL competition in 2021). My main gripes were:
- It is harder to reuse code, since the content inside functions is not easily accessible.
- Closures typically configure their returned JAX functions with their contexts. Such contexts are implicit and hard to examine.
- Functions and data are separated, making things harder to manage. Complex data such as neural network weights tend to be packed into a monolithic black-box and juggled among the functions.
Some existing alternatives
Given the popularity of deep learning applications, many solutions have been proposed to improve on the inconvenient function-only style.
- Use NamedTuple instead of classes:
  Named tuples are handled by JAX as tuples and can have methods like vanilla classes. However, this approach can be too simplistic: named tuples don’t support inheritance or custom initialization, and they introduce an index over the fields which is often meaningless.
- OOP declaration, then transform to FP:
  This is adopted by Haiku. It has the advantage that users can define neural network architectures with familiar OOP classes.[1] This, however, creates a barrier between model definition and execution, and many JAX features cannot be directly used in the OOP part.
- Custom classes as PyTrees:
  This is used by libraries like Flax and Equinox. It is a more flexible approach in general and is also used by OJAX. However, many such libraries are designed specifically for deep learning, and they typically offer mutable assignments, which tend to do more harm than good in JAX programs. There is a reason why JAX uses immutable arrays!
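To make the first alternative concrete, here is a small sketch of the NamedTuple approach using only JAX and the standard library. The Affine class and its fields are hypothetical and chosen purely for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Affine(NamedTuple):
    w: jax.Array
    b: jax.Array

    @jax.jit
    def apply(self, x):
        # `self` is an ordinary tuple to JAX, so jit can trace it.
        return self.w * x + self.b


layer = Affine(w=jnp.asarray(2.0), b=jnp.asarray(1.0))
out = layer.apply(3.0)

# The positional index comes for free, whether it is meaningful or not.
first_field = layer[0]
```

Every field of a named tuple is treated as traced content here, so there is no natural place for a static flag such as mul, and the generated constructor cannot be customized in the usual way.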
Desiderata of a JAX “base class”
While a variety of options exist for JAX programming, what motivated me to implement OJAX is the fact that a good JAX “base class” is still missing, making JAX code either cumbersome or hacky.
For me, a good JAX base class should fulfill three desiderata:
- Simple to understand and use.
  It should follow the Zen of Python: offer one obvious way to do things right, and be simple and explicit. It should also discourage users from doing things wrong.
- Flexible custom classes for general JAX computation.
  In particular, it should be easy to use beyond defining neural network layers. After all, JAX is a general numerical computing library just like NumPy.
- Compatible with JAX and its functional paradigm.
  It should seamlessly integrate into JAX code and follow the same functional principles. It should feel like a natural extension of JAX where all transforms and operations simply work as expected.
OJAX is my take on what a JAX base class should be. Hopefully it fulfills these desiderata and provides a natural way to structure custom JAX code.
Core ideas of OJAX
OJAX offers the OTree class as a base class that extends the JAX API. It is similar to the “record” type in FP languages like OCaml and Haskell. OTree combines two core components: a frozen dataclass to emulate the immutable “record” type, and a PyTree for JAX integration.
Dataclass: benefits and quirks
Frozen dataclasses from Python offer a convenient way to emulate the “record” type in FP languages. One key difference from vanilla Python classes is that dataclasses use annotated field declarations:
import dataclasses
import jax.numpy as jnp

# CustomClass, frozen dataclass version
@dataclasses.dataclass(frozen=True)
class CustomClass:
    x: jnp.ndarray
    mul: bool
    ...
Python dataclasses automatically generate an __init__ function whose signature is based on the annotated field declarations. Moreover, they support inheritance and field updates with the dataclasses.replace() function. Problem solved, right?
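These three features can be tried with the standard library alone. The sketch below uses made-up class names (Config, NamedConfig) and plain Python values instead of JAX arrays.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    rate: float
    mul: bool


# Inheritance: the subclass adds a field on top of the inherited ones.
@dataclasses.dataclass(frozen=True)
class NamedConfig(Config):
    name: str = "default"


# The generated __init__ takes the fields in declaration order.
cfg = NamedConfig(rate=0.1, mul=True)

# replace() builds a new instance and leaves the original untouched.
updated = dataclasses.replace(cfg, rate=0.5)

# Direct assignment is rejected because the dataclass is frozen.
try:
    cfg.rate = 0.9
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False
```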
Unfortunately, not quite. The automatic __init__ generation is limited, and the dataclass-specific way to customize it (__post_init__ and InitVar) is less flexible and steepens the learning curve. It is simply easier to write custom __init__ functions for more advanced use cases.
To use custom __init__ functions, however, it turns out there are two additional things to take care of:
- dataclasses.replace() does not pair well with a custom __init__, as it assumes the generated __init__ and uses it to create the updated instance. OJAX offers the new() function as an alternative, which is based on the “copy-on-write” principle and uses a shallow copy instead of calling the __init__ function again.
- The usual self.arg = val assignment is forbidden by frozen dataclasses. For this, OJAX uses a PureClass class which offers an .assign_() method. It also makes all its subclasses frozen dataclasses, so users don’t need to write @dataclasses.dataclass(frozen=True) everywhere.
import jax.numpy as jnp
import ojax

# CustomClass, ojax.PureClass version
class CustomClass(ojax.PureClass):
    x: jnp.ndarray
    mul: bool

    # optional in this case, can be auto-generated
    def __init__(self, x: jnp.ndarray, mul: bool):
        self.assign_(x=x, mul=mul)

    ...
The underscore should remind you that .assign_ is a low-level in-place “dark magic” that should not be used outside of initialization functions. In particular, it is not a method intended for updating parameters in the middle of JAX computations! For this, use OTree.update() instead, which is introduced below.
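The two pitfalls of custom __init__ functions can be reproduced with the standard library alone. The sketch below is not how OJAX is implemented: Scaled and shallow_update are hypothetical names, and the helper only illustrates the copy-on-write idea with a shallow copy.

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True, init=False)
class Scaled:
    x: float
    mul: bool

    def __init__(self, value: float):
        # A plain `self.x = ...` would raise FrozenInstanceError here,
        # so initialization has to bypass the frozen __setattr__.
        object.__setattr__(self, "x", value * 2.0)
        object.__setattr__(self, "mul", True)


def shallow_update(obj, **changes):
    """Hypothetical helper: shallow-copy `obj`, then set fields on the copy."""
    new_obj = copy.copy(obj)  # does not call __init__
    for name, value in changes.items():
        object.__setattr__(new_obj, name, value)
    return new_obj


s = Scaled(1.5)

# replace() calls Scaled(x=..., mul=...), which the custom
# __init__(value) does not accept.
try:
    dataclasses.replace(s, mul=False)
    replace_failed = False
except TypeError:
    replace_failed = True

t = shallow_update(s, mul=False)
```

Because the copy never goes through __init__, the doubled value of x is carried over as is, and the original instance s stays unchanged.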
PyTree: how JAX handles data
There is still a step missing: we should make PureClass compatible with JAX. For this, we need to register it as a JAX PyTree, which is the data structure JAX can work with. OJAX achieves this with the OTree class, which inherits from ojax.PureClass.
A JAX PyTree is composed of two parts:
- Definition (PyTreeDef): this part encodes the structure of the PyTree and additional metadata.
- Content (PyTree leaves): these are the main data that JAX computations act on, typically JAX arrays.
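The split between definition and content can be seen by registering the vanilla CustomClass by hand with jax.tree_util, which is roughly what one of the JAX FAQ solutions does. This sketch uses plain JAX only and is not how OJAX performs the registration.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class CustomClass:
    def __init__(self, x, mul: bool):
        self.x = x
        self.mul = mul

    def tree_flatten(self):
        children = (self.x,)    # content: leaves that JAX traces
        aux_data = (self.mul,)  # definition: static and hashable metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data[0])

    @jax.jit
    def calc(self, y):
        # `self.mul` is static metadata, so a Python `if` is fine here.
        if self.mul:
            return self.x * y
        return y


c = CustomClass(jnp.asarray(2.0), True)
leaves, treedef = tree_util.tree_flatten(c)
result = c.calc(3.0)
```

Only x shows up in leaves, while mul is stored inside treedef. Writing tree_flatten and tree_unflatten for every class is the boilerplate that a base class can take over.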
For PyTree registration, we should figure out which part each dataclass field belongs to. Luckily, we can already make a guess based on the type annotation: non-numerical types like strings are usually metadata, and JAX arrays and OTree subclasses probably encode main content.
Ambiguous cases do exist: a float scalar can be a static configuration like a dropout rate or a scalar model parameter to be updated. Thus OTree allows users to declare the intention explicitly, overriding a simple default inference that exists for convenience:
- OJAX provides aux()/child() field functions to declare fields as metadata or as containing main content. For completeness, an ignore() function excludes a field from PyTree registration.
- By default, OTree assumes subclasses of jax.Array and ojax.OTree are child nodes holding main content, and the rest is assumed to be metadata. Subclasses can alter this logic by overriding the .__infer_otree_field_type__() method.
import jax.numpy as jnp
import ojax

# CustomClass, ojax.OTree explicit declaration version
class CustomClass(ojax.OTree):
    x: jnp.ndarray = ojax.child()
    mul: bool = ojax.aux()
    ...
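Why the float case is ambiguous can be shown with plain JAX, independent of OJAX. In the sketch below, Dropout and Scale are hypothetical classes registered by hand: the same kind of float is stored once as metadata and once as a leaf.

```python
import jax
from jax import tree_util


class Dropout:
    """The float is static configuration, part of the tree definition."""

    def __init__(self, rate: float):
        self.rate = rate


tree_util.register_pytree_node(
    Dropout,
    lambda d: ((), d.rate),                 # no leaves, rate is aux data
    lambda rate, children: Dropout(rate),
)


class Scale:
    """The float is a parameter, part of the tree content."""

    def __init__(self, factor):
        self.factor = factor


tree_util.register_pytree_node(
    Scale,
    lambda s: ((s.factor,), None),          # factor is a leaf, no aux data
    lambda aux, children: Scale(children[0]),
)

dropout_leaves = tree_util.tree_leaves(Dropout(0.1))
scale_leaves = tree_util.tree_leaves(Scale(0.1))

# Gradients are taken with respect to leaves only.
grads = jax.grad(lambda s, x: s.factor * x)(Scale(2.0), 3.0)
```

A type annotation of float cannot tell these two intentions apart, which is why an explicit declaration per field is useful.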
Moreover, OTree has an .update() method to update its content without changing the definition part. This should be used for field updates by default. ojax.new() should only be used when one explicitly wishes to change the “type” of the PyTree.
What OJAX looks like
In summary, ojax.OTree is a “jaxified” base class type. It effectively makes all its subclasses immutable PyTrees which can be bundled with pure JAX methods. Revisiting the “CustomClass” example, this is how it works with OJAX:
>>> import jax.numpy as jnp
>>> from jax import jit
>>> import ojax
>>> # CustomClass, ojax.OTree version
>>> class CustomClass(ojax.OTree):
...     x: jnp.ndarray
...     mul: bool
...
...     @jit
...     def calc(self, y):
...         if self.mul:
...             return self.x * y
...         return y
...
>>> c = CustomClass(2, True)
>>> c.calc(3)
Array(6, dtype=int32, weak_type=True)
We see that it works as expected without any bad surprises!
Conclusion
If you are interested in this library, feel free to install it directly with pip (make sure JAX for your hardware is already installed):
pip install ojax
I have also written documentation for OJAX, which contains a quickstart guide and a standalone example. Feel free to check it out!
Until next time
[1] In fact, PyTorch 2.0+ also employs this approach for the torch.func part to introduce JAX-like transformations such as vmap.