ojax package
- class ojax.OTree
Base “object-like” class for JAX which bundles data (as an immutable PyTree) and pure functions.
The dataclass fields in this class are categorized into three field types:
auxiliary: metadata that defines the type of the PyTree. They can be non-numerical data and should stay static.
child: children of this PyTree that hold the numerical data. These are typically jax arrays and sub-PyTrees. JAX computations act on this part.
ignored: fields that are not part of the PyTree.
Warning
Ignored fields are not preserved with a flatten/unflatten transform, which is implicitly used by many JAX transforms and functions.
You can explicitly declare the category of each field with ojax.aux(), ojax.child() and ojax.ignore(). Otherwise, it is inferred from the annotated type of the field: subclasses of jax.Array and ojax.OTree are assumed to be child fields and the rest are assumed to be auxiliary fields.
Example usage:
    class MyConv(ojax.OTree):
        out_channels: int  # inferred to be aux
        kernel_size: int = ojax.aux()  # declared to be aux
        weight: jax.numpy.ndarray  # inferred to be child
        bias: jax.Array = ojax.child(default=None)  # declared child
- update(**kwargs) -> OTree_T
Create a new version of this OTree instance with updated children.
This method only updates the numerical data and keeps the OTree structure and the metadata intact. It is the intended way to update the content without changing the PyTree type. If you need to create a new OTree with different metadata or an altered structure, use ojax.new() instead.
Warning
None is a special empty PyTree container. Thus updating numerical fields to None or vice versa will change the PyTree structure and trigger an error.
- Parameters:
**kwargs – Keyword arguments specifying new values to be updated for the corresponding child fields.
- Returns:
New OTree instance with updated children.
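The warning about None can be checked directly with plain JAX. The snippet below is a minimal sketch that uses ordinary dicts rather than an OTree (the dict contents are made up for illustration): it shows that a None entry contributes no leaf, so swapping a value for None changes the tree structure.

```python
import jax

# Two hypothetical parameter containers that differ only in whether "bias" is None.
with_bias = {"weight": 1.0, "bias": 2.0}
without_bias = {"weight": 1.0, "bias": None}

# None is treated as an empty PyTree node, so it produces no leaf.
leaves_with = jax.tree_util.tree_leaves(with_bias)
leaves_without = jax.tree_util.tree_leaves(without_bias)

# The tree structures therefore differ, which is why replacing a
# numerical field with None (or the reverse) alters the PyTree type.
same_structure = (
    jax.tree_util.tree_structure(with_bias)
    == jax.tree_util.tree_structure(without_bias)
)

print(len(leaves_with), len(leaves_without), same_structure)
```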
- tree_flatten() -> tuple[tuple, tuple[tuple[tuple[str, int], ...], tuple]]
Define the flatten behavior of OTree as a PyTree.
- classmethod tree_unflatten(aux_data: tuple[tuple[tuple[str, int], ...], tuple], children: Sequence) -> OTree_T
Define the unflatten behavior of OTree as a PyTree.
- classmethod __infer_otree_field_type__(f: Field) -> FieldType
Infer the OJAX field type from a dataclasses.Field object.
When ojax.FieldType is unspecified (when get_field_type() returns None), the annotated f.type is used for inference: subclasses of jax.Array and ojax.OTree are assumed to be child fields and the rest are assumed to be auxiliary fields. You can override this method through subclass inheritance to change the inference logic.
- Parameters:
f – a field of OTree.
- Returns:
The inferred ojax.FieldType.
- ojax.ignore(*, default=<dataclasses._MISSING_TYPE object>, default_factory=<dataclasses._MISSING_TYPE object>, init=True, repr=True, hash=None, compare=True, metadata=None) -> Field
Declares an OTree field that is ignored by PyTree creation.
This function has the same arguments and return value as the dataclasses.field function.
- ojax.aux(*, default=<dataclasses._MISSING_TYPE object>, default_factory=<dataclasses._MISSING_TYPE object>, init=True, repr=True, hash=None, compare=True, metadata=None) -> Field
Declares an OTree field that holds auxiliary data as part of PyTreeDef.
This function has the same arguments and return value as the dataclasses.field function.
- ojax.child(*, default=<dataclasses._MISSING_TYPE object>, default_factory=<dataclasses._MISSING_TYPE object>, init=True, repr=True, hash=None, compare=True, metadata=None)
Declares an OTree field that holds a child PyTree node.
This function has the same arguments and return value as the dataclasses.field function.
- class ojax.FieldType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Choices for OTree field types.
- AUX = 'aux'
- CHILD = 'child'
- IGNORE = 'ignore'
- ojax.fields(otree: OTree | type, field_type: FieldType | None = None, infer: bool = True) -> tuple[Field, ...]
Convenience function extending dataclasses.fields() that can filter fields by OTree field type.
- Parameters:
otree – the OTree instance or class to examine.
field_type – if not None, specifies the field type used to filter the list of fields from the OTree.
infer – determines whether the field type should be inferred when it is not available. Has no effect when field_type = None.
- Returns:
A tuple of fields from the given OTree.
- ojax.get_field_type(f: Field) -> FieldType | None
Retrieve the OJAX field type from a dataclasses.Field object.
- Parameters:
f – a field of ojax.OTree.
- Returns:
The ojax.FieldType if available, and None otherwise.
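To make the idea of tagging and filtering fields by type concrete, here is a minimal standard-library sketch. It assumes the field type is stored in the dataclasses metadata mapping; that storage choice, the metadata key, and the names Kind, tagged, read_field_type and fields_of_type are all hypothetical and are not taken from the ojax source. Type inference for undeclared fields is left out.

```python
import dataclasses
import enum


class Kind(enum.Enum):
    # Hypothetical stand-in mirroring the three documented choices.
    AUX = "aux"
    CHILD = "child"
    IGNORE = "ignore"


_KEY = "field_kind"  # hypothetical metadata key, assumed for this sketch


def tagged(kind, **kwargs):
    """Create a dataclass field that carries a field-type tag in its metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata[_KEY] = kind
    return dataclasses.field(metadata=metadata, **kwargs)


def read_field_type(f):
    """Return the declared tag, or None if the field was never tagged."""
    return f.metadata.get(_KEY)


def fields_of_type(cls_or_obj, kind=None):
    """Like dataclasses.fields(), optionally filtered by declared tag."""
    all_fields = dataclasses.fields(cls_or_obj)
    if kind is None:
        return all_fields
    return tuple(f for f in all_fields if read_field_type(f) is kind)


@dataclasses.dataclass(frozen=True)
class Layer:
    width: int = tagged(Kind.AUX)
    scale: float = tagged(Kind.CHILD, default=1.0)
    note: str = "untagged"  # no declaration, so no tag is stored


child_names = tuple(f.name for f in fields_of_type(Layer, Kind.CHILD))
untagged = read_field_type(fields_of_type(Layer)[2])
print(child_names, untagged)
```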
- ojax.new(pure_obj: PureClass_T, **kwargs) -> PureClass_T
Shallow copy-based alternative to dataclasses.replace().
This function circumvents the instance creation with another __init__ call. It allows direct updates of init=False fields and avoids many “bad surprises” for custom __init__ functions.
- Parameters:
pure_obj – An instance of PureClass to be updated.
**kwargs – Keyword arguments specifying new values to be updated for the corresponding fields.
- Returns:
The updated instance.
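The difference from dataclasses.replace() can be illustrated with a standard-library sketch. The helper shallow_new and the Counter class below are hypothetical and only mimic the described behaviour (shallow copy, no second __init__ call); they are not the ojax implementation.

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True)
class Counter:
    start: int
    # init=False: not accepted by __init__, filled in by __post_init__.
    count: int = dataclasses.field(default=0, init=False)

    def __post_init__(self):
        object.__setattr__(self, "count", self.start)


def shallow_new(obj, **kwargs):
    """Hypothetical copy-based update: copy the instance, then overwrite fields."""
    names = {f.name for f in dataclasses.fields(obj)}
    unknown = set(kwargs) - names
    if unknown:
        raise ValueError(f"unknown fields: {sorted(unknown)}")
    result = copy.copy(obj)  # shallow copy; __init__ is not called again
    for name, value in kwargs.items():
        object.__setattr__(result, name, value)  # bypass the frozen check
    return result


c = Counter(start=3)
c2 = shallow_new(c, count=10)  # works even though count has init=False

# dataclasses.replace() rejects init=False fields (ValueError on older
# Python versions, TypeError on newer ones).
try:
    dataclasses.replace(c, count=10)
    replace_failed = False
except (TypeError, ValueError):
    replace_failed = True

print(c.count, c2.count, replace_failed)
```

The original instance is left untouched, which keeps the update in line with the immutable style the class encourages.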
- class ojax.PureClass
“Record-type” base class with immutable and annotated dataclass fields.
Direct attribute assignment with “=” is disabled to encourage the immutable paradigm. Use the ojax.new function to create new instances with updated field values instead of in-place updates.
The .assign_ method is provided to initialize fields in custom __init__ methods. It is the low-level impure “dark magic” that normally should not be used by the end user in any other context.
- assign_(**kwargs) -> None
Low-level in-place setting of PureClass instance fields.
This should only be used during custom instance creation and before the first usage of the created instance. It is typically used in custom __init__ methods to replace the disabled direct assignment with “=”.
Warning
End users should avoid using this method directly in other cases as it can easily break the immutable paradigm and cause potential bugs with JAX.
- Parameters:
**kwargs – Keyword arguments specifying new values to be updated for the corresponding fields.
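The pattern described here, disabled “=” assignment plus a low-level setter reserved for __init__, can be sketched in plain Python. The Record and Point classes below are hypothetical illustrations of that pattern and make no claim about how ojax implements it.

```python
class Record:
    """Hypothetical base class: direct assignment is blocked."""

    def __setattr__(self, name, value):
        raise AttributeError(
            f"cannot assign to {name!r}; create a new instance instead"
        )

    def assign_(self, **kwargs):
        # Low-level in-place setter, meant only for use during construction.
        for name, value in kwargs.items():
            object.__setattr__(self, name, value)


class Point(Record):
    def __init__(self, x, y):
        # "self.x = x" would raise, so the fields are set through assign_.
        self.assign_(x=x, y=y, norm_sq=x * x + y * y)


p = Point(3, 4)

try:
    p.x = 0
    assignment_blocked = False
except AttributeError:
    assignment_blocked = True

print(p.x, p.y, p.norm_sq, assignment_blocked)
```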