ojax package

class ojax.OTree

Base “object-like” class for JAX which bundles data (as an immutable PyTree) and pure functions.

The dataclass fields in this class are categorized into three field types:

  • auxiliary: metadata that defines the type of the PyTree. They can be non-numerical data and should stay static.

  • child: children of this PyTree that hold the numerical data. These are typically jax arrays and sub-PyTrees. JAX computations act on this part.

  • ignored: fields that are not part of the PyTree.

Warning

Ignored fields are not preserved with a flatten/unflatten transform, which is implicitly used by many JAX transforms and functions.

You can explicitly declare the category of each field with ojax.aux(), ojax.child() and ojax.ignore(). Otherwise, it is inferred based on the annotated type of the field: subclasses of jax.Array and ojax.OTree are assumed to be child fields and the rest are assumed to be auxiliary fields. Example usage:

import jax
import ojax

class MyConv(ojax.OTree):
    out_channels: int  # inferred to be aux
    kernel_size: int = ojax.aux()  # declared to be aux
    weight: jax.numpy.ndarray  # inferred to be child
    bias: jax.Array = ojax.child(default=None)  # declared to be child

update(**kwargs) -> OTree_T

Create a new version of this OTree instance with updated children.

This method only updates the numerical data and keeps the OTree structure and the metadata intact. It is the intended way to update the content without changing the PyTree type. If you need to create a new OTree with different metadata or an altered structure, use ojax.new() instead.

Warning

None is a special empty PyTree container. Thus updating numerical fields to None or vice versa will change the PyTree structure and trigger an error.

Parameters:

**kwargs – Keyword arguments specifying new values to be updated for the corresponding child fields.

Returns:

New OTree instance with updated children.
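
The warning about None follows from a JAX convention: None is treated as a PyTree node with no leaves rather than as a leaf. The toy flattener below is an illustration only. It is not JAX's or OJAX's implementation, and `toy_flatten` is a hypothetical name. It mimics the convention for nested tuples to show why replacing a value with None changes the structure.

```python
def toy_flatten(tree):
    """Flatten nested tuples into (leaves, structure).

    Mimics the JAX convention that None is an empty container:
    it contributes a structure node but no leaves.
    """
    if tree is None:
        return [], "None"
    if isinstance(tree, tuple):
        leaves, parts = [], []
        for item in tree:
            sub_leaves, sub_structure = toy_flatten(item)
            leaves.extend(sub_leaves)
            parts.append(sub_structure)
        return leaves, "(" + ", ".join(parts) + ")"
    return [tree], "*"  # anything else is treated as a leaf


with_bias = toy_flatten((1.0, 2.0))
without_bias = toy_flatten((1.0, None))
same_structure = with_bias[1] == without_bias[1]
```

The two inputs produce a different number of leaves and a different structure description, which is the kind of mismatch the warning describes.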

tree_flatten() -> tuple[tuple, tuple[tuple[tuple[str, int], ...], tuple]]

Define the flatten behavior of OTree as a PyTree.

classmethod tree_unflatten(aux_data: tuple[tuple[tuple[str, int], ...], tuple], children: Sequence) -> OTree_T

Define the unflatten behavior of OTree as a PyTree.
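
As a rough picture of the flatten/unflatten protocol, the sketch below uses a hand-written class instead of OTree. `ToyNode` and its attributes are hypothetical, and its auxiliary data is a plain one-element tuple rather than the layout shown in the signatures above. It follows the JAX convention that flattening returns the children first and the static auxiliary data second, and it shows why anything left out of both is lost on a round trip.

```python
class ToyNode:
    """Hand-written stand-in for a PyTree node with three kinds of attributes."""

    def __init__(self, scale, values, note=None):
        self.scale = scale    # plays the role of an auxiliary field
        self.values = values  # plays the role of a child field
        self.note = note      # plays the role of an ignored field

    def tree_flatten(self):
        # Children come first, static auxiliary data second.
        return (self.values,), (self.scale,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (scale,) = aux_data
        (values,) = children
        return cls(scale, values)  # `note` was never flattened, so it resets


node = ToyNode(scale=2, values=[1.0, 2.0], note="debug info")
children, aux_data = node.tree_flatten()
rebuilt = ToyNode.tree_unflatten(aux_data, children)
```

After the round trip, `rebuilt` carries the same `scale` and `values` as `node`, while `note` falls back to its default. This mirrors the earlier warning about ignored fields.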

classmethod __infer_otree_field_type__(f: Field) -> FieldType

Infer the OJAX field type from a dataclasses.Field object.

When ojax.FieldType is unspecified (that is, when get_field_type() returns None), the annotated f.type is used for inference: subclasses of jax.Array and ojax.OTree are assumed to be child fields and the rest are assumed to be auxiliary fields. You can override this method in a subclass to change the inference logic.

Parameters:

f – a field of OTree.

Returns:

The inferred ojax.FieldType.
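
The inference rule described above can be sketched with the standard library alone. Everything here is a stand-in: `ToyArray` takes the place of jax.Array so the example runs without JAX, and `ToyFieldType` and `infer_field_type` are hypothetical names. This is not the OJAX implementation.

```python
import dataclasses
from enum import Enum


class ToyFieldType(Enum):
    AUX = "aux"
    CHILD = "child"


class ToyArray:
    """Stand-in for jax.Array so the sketch runs without JAX."""


def infer_field_type(f: dataclasses.Field) -> ToyFieldType:
    # Class annotations that subclass the "numerical" stand-in become children;
    # everything else (including non-class annotations) becomes auxiliary.
    if isinstance(f.type, type) and issubclass(f.type, ToyArray):
        return ToyFieldType.CHILD
    return ToyFieldType.AUX


@dataclasses.dataclass(frozen=True)
class ToyConv:
    out_channels: int
    weight: ToyArray


inferred = {f.name: infer_field_type(f) for f in dataclasses.fields(ToyConv)}
```

The `isinstance(f.type, type)` guard matters because an annotation can be a string or a typing construct, for which `issubclass` would raise.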

ojax.ignore(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None) -> Field

Declares an OTree field that is ignored by PyTree creation.

This function takes the same arguments as dataclasses.field and returns the same kind of object.

ojax.aux(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None) -> Field

Declares an OTree field that holds auxiliary data as part of PyTreeDef.

This function takes the same arguments as dataclasses.field and returns the same kind of object.

ojax.child(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None)

Declares an OTree field that holds a child PyTree node.

This function takes the same arguments as dataclasses.field and returns the same kind of object.

class ojax.FieldType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Choices for OTree field types.

AUX = 'aux'
CHILD = 'child'
IGNORE = 'ignore'

ojax.fields(otree: OTree | type, field_type: FieldType | None = None, infer: bool = True) -> tuple[Field, ...]

Convenience function extending dataclasses.fields() that can filter fields by OTree field type.

Parameters:
  • otree – the OTree instance or class to examine.

  • field_type – if not None, specifies the field type to filter the list of fields from OTree.

  • infer – determines if the field type should be inferred in case it is not available. Has no effect when field_type = None.

Returns:

A tuple of fields from the given OTree.

ojax.get_field_type(f: Field) -> FieldType | None

Retrieve the OJAX field type from a dataclasses.Field object.

Parameters:

f – a field of ojax.OTree.

Returns:

The ojax.FieldType if available, and None otherwise.
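
One way to picture how the declaring helpers, the field-type lookup and the filtered field listing fit together is as a thin layer over dataclasses metadata. The sketch below rests on that assumption. The source does not say how OJAX stores the field type, so the metadata key `TAG` and the helpers `tagged`, `get_tag` and `fields_with_tag` are all hypothetical, and the `infer` option is left out.

```python
import dataclasses

TAG = "toy_field_type"  # hypothetical metadata key, chosen for this sketch


def tagged(tag, **kwargs):
    """Forward all keyword arguments to dataclasses.field and attach a tag."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata[TAG] = tag
    return dataclasses.field(metadata=metadata, **kwargs)


def get_tag(f):
    """Return the declared tag, or None when the field was never tagged."""
    return f.metadata.get(TAG)


def fields_with_tag(cls_or_obj, tag=None):
    """Like dataclasses.fields(), optionally keeping only fields with `tag`."""
    all_fields = dataclasses.fields(cls_or_obj)
    if tag is None:
        return all_fields
    return tuple(f for f in all_fields if get_tag(f) == tag)


@dataclasses.dataclass(frozen=True)
class ToyLayer:
    width: int = tagged("aux")
    weight: list = tagged("child", default_factory=list)
    name: str = "layer"  # untagged


child_names = [f.name for f in fields_with_tag(ToyLayer, "child")]
declared = {f.name: get_tag(f) for f in fields_with_tag(ToyLayer)}
```

Because `tagged` only adds metadata, each field keeps the usual dataclasses.field behavior for defaults, `init`, `repr` and so on.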

ojax.new(pure_obj: PureClass_T, **kwargs) -> PureClass_T

Shallow copy-based alternative to dataclasses.replace().

This function bypasses __init__ when creating the new instance. This allows direct updates of init=False fields and avoids many “bad surprises” with custom __init__ functions.

Parameters:
  • pure_obj – An instance of PureClass to be updated.

  • **kwargs – Keyword arguments specifying new values to be updated for the corresponding fields.

Returns:

The updated instance.
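
The contrast with dataclasses.replace() can be sketched using only the standard library. `toy_new` below is a hypothetical stand-in that copies the object and overwrites attributes directly. It is meant to illustrate the shallow-copy idea, not to reproduce the actual ojax.new() implementation.

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True)
class Counter:
    start: int
    total: int = dataclasses.field(init=False)  # computed, not an __init__ argument

    def __post_init__(self):
        # Frozen dataclasses must bypass the blocked __setattr__ here.
        object.__setattr__(self, "total", self.start)


def toy_new(obj, **kwargs):
    """Shallow-copy `obj` and overwrite the given fields without calling __init__."""
    names = {f.name for f in dataclasses.fields(obj)}
    unknown = set(kwargs) - names
    if unknown:
        raise ValueError(f"unknown fields: {sorted(unknown)}")
    clone = copy.copy(obj)
    for name, value in kwargs.items():
        object.__setattr__(clone, name, value)
    return clone


c = Counter(start=1)

# dataclasses.replace refuses to touch init=False fields; the exception type
# is ValueError or TypeError depending on the Python version.
try:
    dataclasses.replace(c, total=10)
    replace_failed = False
except (TypeError, ValueError):
    replace_failed = True

updated = toy_new(c, total=10)
```

The original object is left untouched and `__post_init__` does not run again, so the overwritten `total` is not recomputed from `start`.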

class ojax.PureClass

“Record-type” base class with immutable and annotated dataclass fields.

Direct attribute assignment with “=” is disabled to encourage the immutable paradigm. Use the ojax.new function to create new instances with updated field values instead of in-place updates.

The .assign_ method is provided to initialize fields in custom __init__ methods. It is the low-level impure “dark magic” that normally should not be used by the end user in any other context.

assign_(**kwargs) -> None

Low-level in-place setting of PureClass instance fields.

This should only be used during custom instance creation and before the first usage of the created instance. It is typically used in custom __init__ methods to replace the disabled direct assignment with “=”.

Warning

End users should avoid using this method directly in other cases as it can easily break the immutable paradigm and cause potential bugs with JAX.

Parameters:

**kwargs – Keyword arguments specifying new values to be updated for the corresponding fields.
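
The pattern described above, where assignment with “=” is blocked and a low-level setter is used during construction, can be sketched as follows. `ToyPure` and `Point` are hypothetical classes written for illustration, and the way PureClass actually disables assignment may differ.

```python
class ToyPure:
    """Minimal record type that blocks `=` assignment on instances."""

    def __setattr__(self, name, value):
        raise AttributeError(
            f"cannot assign to {name!r}; instances are treated as immutable"
        )

    def assign_(self, **kwargs):
        # Low-level escape hatch, meant only for use inside __init__.
        for name, value in kwargs.items():
            object.__setattr__(self, name, value)


class Point(ToyPure):
    def __init__(self, x, y):
        # `self.x = x` would raise here, so initialise through assign_.
        self.assign_(x=x, y=y, norm2=x * x + y * y)


p = Point(3, 4)
try:
    p.x = 0
    assignment_blocked = False
except AttributeError:
    assignment_blocked = True
```

Calling `assign_` on an instance that is already in use would silently mutate it, which is the hazard the warning refers to.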

exception ojax.NoAnnotationWarning