ojax package
- class ojax.OTreeField(default, default_factory, init, repr, hash, compare, metadata, kw_only)
Abstract dataclasses.Field subclass for OTree fields.
Attributes (inherited from dataclasses.Field): name, type, default, default_factory, repr, hash, init, compare, metadata, kw_only.
- class ojax.Aux(default, default_factory, init, repr, hash, compare, metadata, kw_only)
Field subclass for OTree auxiliary data that belong to the PyTreeDef.
Attributes (inherited from dataclasses.Field): name, type, default, default_factory, repr, hash, init, compare, metadata, kw_only.
- class ojax.Child(default, default_factory, init, repr, hash, compare, metadata, kw_only)
Field subclass for a child PyTree node.
Attributes (inherited from dataclasses.Field): name, type, default, default_factory, repr, hash, init, compare, metadata, kw_only.
- class ojax.Ignore(default, default_factory, init, repr, hash, compare, metadata, kw_only)
Field subclass for fields that are ignored during PyTree creation.
Attributes (inherited from dataclasses.Field): name, type, default, default_factory, repr, hash, init, compare, metadata, kw_only.
- class ojax.Alien(default, default_factory, init, repr, hash, compare, metadata, kw_only)
Field subclass for fields that are incompatible with PyTree flatten/unflatten.
Attributes (inherited from dataclasses.Field): name, type, default, default_factory, repr, hash, init, compare, metadata, kw_only.
- ojax.fields(otree: OTree | type[OTree], field_type: type[OTreeField] = None, infer: bool = True) -> tuple[Field, ...]
Convenience function extending dataclasses.fields that can filter fields by OTree field type.
- Parameters:
otree – the OTree instance or class to examine.
field_type – if not None, only fields of this OTree field type are returned.
infer – whether to infer the field type of fields that do not declare one explicitly. Has no effect when field_type is None.
- Returns:
A tuple of fields from the given OTree.
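To illustrate what filtering by field type means, here is a minimal stdlib-only sketch; it is not ojax's implementation. It assumes the category is stored under a hypothetical "kind" metadata key (ojax uses Field subclasses instead), and `Array`, `Node`, `infer_kind` and `fields_of_kind` are hypothetical names standing in for jax.Array, ojax.OTree, the inference rule and ojax.fields.

```python
import dataclasses
from typing import Optional


class Array:
    """Hypothetical stand-in for jax.Array."""


class Node:
    """Hypothetical stand-in for ojax.OTree."""


def infer_kind(f: dataclasses.Field) -> str:
    """An explicit tag wins; otherwise Array/Node annotations mean 'child', the rest 'aux'."""
    if "kind" in f.metadata:
        return f.metadata["kind"]
    if isinstance(f.type, type) and issubclass(f.type, (Array, Node)):
        return "child"
    return "aux"


def fields_of_kind(obj, kind: Optional[str] = None, infer: bool = True) -> tuple:
    """Rough analogue of ojax.fields(otree, field_type, infer)."""
    all_fields = dataclasses.fields(obj)
    if kind is None:
        return all_fields  # no filtering: behaves like dataclasses.fields
    if infer:
        return tuple(f for f in all_fields if infer_kind(f) == kind)
    # Without inference, only explicitly tagged fields can match.
    return tuple(f for f in all_fields if f.metadata.get("kind") == kind)


@dataclasses.dataclass(frozen=True)
class Conv:
    out_channels: int  # untagged, inferred aux
    weight: Array      # untagged, inferred child
    bias: Optional[Array] = dataclasses.field(default=None, metadata={"kind": "child"})
    cache: Optional[dict] = dataclasses.field(default=None, metadata={"kind": "ignore"})


print([f.name for f in fields_of_kind(Conv, "child")])              # ['weight', 'bias']
print([f.name for f in fields_of_kind(Conv, "child", infer=False)])  # ['bias']
```

With infer=False only the explicitly tagged bias field matches, which mirrors the role of the infer parameter.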
- ojax.aux(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, **kwargs) -> Aux
Declares an OTree field that holds auxiliary data as part of the PyTreeDef.
This function has identical arguments to dataclasses.field and returns a field of type ojax.Aux.
- ojax.child(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, **kwargs) -> Child
Declares an OTree field that holds a child PyTree node.
This function has identical arguments to dataclasses.field and returns a field of type ojax.Child.
- ojax.ignore(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, **kwargs) -> Ignore
Declares an OTree field that is ignored by PyTree creation.
This function has identical arguments to dataclasses.field and returns a field of type ojax.Ignore.
- ojax.alien(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, **kwargs) -> Alien
Declares an OTree field that is not part of the PyTree and makes the flatten / unflatten operations fail.
This function has identical arguments to dataclasses.field and returns a field of type ojax.Alien.
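The declaration helpers can be pictured as thin wrappers that forward every dataclasses.field argument and only add a category marker. The sketch below assumes the marker lives in metadata under a hypothetical "kind" key, whereas ojax returns instances of Aux, Child, Ignore and Alien; the helper names mirror ojax's but are local, hypothetical functions.

```python
import dataclasses


def _tagged(kind: str, **kwargs) -> dataclasses.Field:
    """Create a dataclasses.Field whose metadata records the field category."""
    metadata = dict(kwargs.pop("metadata", None) or {})  # keep any user metadata
    metadata["kind"] = kind
    return dataclasses.field(metadata=metadata, **kwargs)


# Hypothetical counterparts of ojax.aux / child / ignore / alien.
def aux(**kwargs):
    return _tagged("aux", **kwargs)


def child(**kwargs):
    return _tagged("child", **kwargs)


def ignore(**kwargs):
    return _tagged("ignore", **kwargs)


def alien(**kwargs):
    return _tagged("alien", **kwargs)


@dataclasses.dataclass(frozen=True)
class Layer:
    name: str = aux(default="dense")
    params: tuple = child(default_factory=tuple)
    # All dataclasses.field arguments (default_factory, repr, metadata) pass through.
    cache: dict = ignore(default_factory=dict, repr=False, metadata={"note": "scratch"})


print(Layer())  # Layer(name='dense', params=())
print({f.name: f.metadata["kind"] for f in dataclasses.fields(Layer)})
# {'name': 'aux', 'params': 'child', 'cache': 'ignore'}
```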
- class ojax.OTree
Base “object-like” class for JAX that bundles data (as an immutable PyTree) with pure functions.
The dataclass fields in this class are categorized into three field types:
- auxiliary: metadata that defines the type of the PyTree. These can be non-numerical data and should stay static.
- child: children of this PyTree that hold the numerical data, typically JAX arrays and sub-PyTrees. JAX computations act on this part.
- ignored: fields that are not part of the PyTree.
Warning: Ignored fields are not preserved across a flatten/unflatten round trip, which many JAX transforms and functions perform implicitly.
You can explicitly declare the category of each field with ojax.aux(), ojax.child() and ojax.ignore(). Otherwise, it is inferred from the annotated type of the field: subclasses of jax.Array and ojax.OTree are assumed to be child fields and the rest are assumed to be auxiliary fields. Example usage:

    import jax
    import ojax

    class MyConv(ojax.OTree):
        out_channels: int  # inferred to be aux
        kernel_size: int = ojax.aux()  # declared to be aux
        weight: jax.numpy.ndarray  # inferred to be child
        bias: jax.Array = ojax.child(default=None)  # declared child
- update(**kwargs) -> OTree_T
Create a new version of this OTree instance with updated children.
This method only updates the numerical data and keeps the OTree structure and the metadata intact. It is the intended way to update the content without changing the PyTree type. If you need to create a new OTree with different metadata or an altered structure, use ojax.new() instead.
Warning: None is a special empty PyTree container. Thus updating a child field from None to another value, or vice versa, changes the PyTree structure and triggers an error.
- Parameters:
**kwargs – Keyword arguments specifying new values for the corresponding child fields.
- Returns:
New OTree instance with updated children.
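A stdlib-only sketch of the two rules stated for update(): only child fields receive new values, and switching a child between None and a value is rejected because it would change the tree structure. The Linear class, the "kind" metadata key and the explicit checks are assumptions of this sketch; ojax's own update() is not shown here, and rejecting non-child names with ValueError is a choice made for illustration.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class Linear:
    features: int = dataclasses.field(metadata={"kind": "aux"})
    weight: tuple = dataclasses.field(metadata={"kind": "child"})
    bias: Optional[tuple] = dataclasses.field(default=None, metadata={"kind": "child"})


def update(obj, **kwargs):
    """Hypothetical analogue of OTree.update: replace child values only."""
    children = {f.name for f in dataclasses.fields(obj) if f.metadata.get("kind") == "child"}
    for name, value in kwargs.items():
        if name not in children:
            raise ValueError(f"{name!r} is not a child field; use new() instead")
        # None is an empty container, so None <-> value changes the structure.
        if (getattr(obj, name) is None) != (value is None):
            raise ValueError(f"updating {name!r} would change the tree structure")
    return dataclasses.replace(obj, **kwargs)


layer = Linear(features=2, weight=(0.5, -0.5))
trained = update(layer, weight=(0.4, -0.6))
print(trained)       # Linear(features=2, weight=(0.4, -0.6), bias=None)
print(layer.weight)  # (0.5, -0.5): the original instance is unchanged
```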
- tree_flatten() -> tuple[tuple, tuple[tuple[tuple[str, int], ...], tuple]]
Define the flatten behavior of OTree as a PyTree.
- classmethod tree_unflatten(aux_data: tuple[tuple[tuple[str, int], ...], tuple], children: Sequence) -> OTree_T
Define the unflatten behavior of OTree as a PyTree.
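JAX's pytree protocol pairs a tree_flatten method returning (children, aux_data) with a classmethod tree_unflatten(aux_data, children), as in the signatures above. The sketch below mimics that round trip with plain functions to show why ignored fields are lost; the "kind" metadata key, the Model class and the aux_data layout are hypothetical and differ from ojax's actual aux_data type.

```python
import dataclasses
from typing import Optional


def flatten(obj):
    """Split obj into (children, aux_data); ignored fields are dropped."""
    kinds = {f.name: f.metadata.get("kind", "aux") for f in dataclasses.fields(obj)}
    children = tuple(getattr(obj, n) for n, k in kinds.items() if k == "child")
    aux = tuple((n, getattr(obj, n)) for n, k in kinds.items() if k == "aux")
    return children, (type(obj), aux)


def unflatten(aux_data, children):
    """Rebuild an instance; ignored fields fall back to their defaults."""
    cls, aux = aux_data
    child_names = [f.name for f in dataclasses.fields(cls) if f.metadata.get("kind") == "child"]
    return cls(**dict(aux), **dict(zip(child_names, children)))


@dataclasses.dataclass(frozen=True)
class Model:
    name: str  # untagged: treated as aux
    params: tuple = dataclasses.field(metadata={"kind": "child"})
    cache: Optional[dict] = dataclasses.field(default=None, metadata={"kind": "ignore"})


model = Model("mlp", (1.0, 2.0), cache={"hits": 3})
children, aux_data = flatten(model)
restored = unflatten(aux_data, children)
print(children)        # ((1.0, 2.0),)
print(restored.cache)  # None: the ignored field did not survive the round trip
```

Because the ignored cache appears in neither children nor aux_data, unflatten can only fall back to the field default.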
- classmethod __infer_otree_field_type__(f: Field) -> type[OTreeField]
Infer the OJAX field type from a dataclasses.Field object.
When f does not have a specified OTree field type (that is, it is not an instance of ojax.OTreeField), the annotated f.type is used for inference: subclasses of jax.Array and ojax.OTree are assumed to be child fields and the rest are assumed to be auxiliary fields. You can override this method in a subclass to change the inference logic.
- Parameters:
f – a field of OTree.
- Returns:
The inferred ojax.OTreeField subclass.
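Overriding an inference hook in a subclass can be sketched with plain dataclasses. Here `infer_kind` plays the role of __infer_otree_field_type__ but returns a string instead of an OTreeField subclass; `Array`, `Base`, `FloatChildren`, `Plain` and `Tuned` are hypothetical names, and treating float annotations as children is only an example rule.

```python
import dataclasses


class Array:
    """Hypothetical stand-in for jax.Array."""


class Base:
    @classmethod
    def infer_kind(cls, f: dataclasses.Field) -> str:
        """Default rule: Array-annotated fields are children, the rest aux."""
        if isinstance(f.type, type) and issubclass(f.type, Array):
            return "child"
        return "aux"

    @classmethod
    def kinds(cls) -> dict:
        return {f.name: cls.infer_kind(f) for f in dataclasses.fields(cls)}


class FloatChildren(Base):
    @classmethod
    def infer_kind(cls, f: dataclasses.Field) -> str:
        # Custom rule: plain float fields are treated as children too.
        if f.type is float:
            return "child"
        return super().infer_kind(f)  # fall back to the default rule


@dataclasses.dataclass
class Plain(Base):
    w: Array
    lr: float


@dataclasses.dataclass
class Tuned(FloatChildren):
    w: Array
    lr: float


print(Plain.kinds())  # {'w': 'child', 'lr': 'aux'}
print(Tuned.kinds())  # {'w': 'child', 'lr': 'child'}
```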
- ojax.new(pure_obj: PureClass_T, **kwargs) -> PureClass_T
Shallow copy-based alternative to dataclasses.replace().
Unlike dataclasses.replace(), this function does not create the new instance through another __init__ call. It therefore allows direct updates of init=False fields and avoids many “bad surprises” with custom __init__ functions.
- Parameters:
pure_obj – An instance of PureClass to be updated.
**kwargs – Keyword arguments specifying new values for the corresponding fields.
- Returns:
The updated instance.
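The difference from dataclasses.replace() shows up with init=False fields: replace() rebuilds the object through __init__ and refuses such fields, while a copy-based update does not call __init__ at all. A minimal sketch using copy.copy and object.__setattr__; the new function here is a hypothetical stand-in, not ojax.new itself.

```python
import copy
import dataclasses


def new(obj, **kwargs):
    """Shallow-copy obj, then overwrite fields without calling __init__."""
    names = {f.name for f in dataclasses.fields(obj)}
    unknown = set(kwargs) - names
    if unknown:
        raise TypeError(f"unknown fields: {sorted(unknown)}")
    clone = copy.copy(obj)  # copy.copy does not call __init__ or __post_init__
    for name, value in kwargs.items():
        object.__setattr__(clone, name, value)  # sidesteps frozen=True
    return clone


@dataclasses.dataclass(frozen=True)
class Scaled:
    x: float
    doubled: float = dataclasses.field(init=False)

    def __post_init__(self):
        object.__setattr__(self, "doubled", 2 * self.x)


s = Scaled(1.0)
t = new(s, doubled=5.0)  # the init=False field is set directly
print(s, t)  # Scaled(x=1.0, doubled=2.0) Scaled(x=1.0, doubled=5.0)
```

The flip side, in this sketch, is that __post_init__ does not run again, so derived fields such as doubled are not recomputed when x is changed through new.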
- class ojax.PureClass
“Record-type” base class with immutable and annotated dataclass fields.
Direct attribute assignment with “=” is disabled to encourage the immutable paradigm. Use the ojax.new function to create new instances with updated field values instead of updating fields in place.
The .assign_ method is provided to initialize fields in custom __init__ methods. It is low-level, impure “dark magic” that normally should not be used by the end user in any other context.
- assign_(**kwargs) -> None
Low-level in-place setting of PureClass instance fields.
This should only be used during custom instance creation, before the created instance is first used. It is typically used in custom __init__ methods to replace the disabled direct assignment with “=”.
Warning: End users should avoid using this method in other cases, as it can easily break the immutable paradigm and cause bugs with JAX.
- Parameters:
**kwargs – Keyword arguments specifying new values for the corresponding fields.
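The assignment rules of PureClass can be sketched with a __setattr__ override plus an escape hatch that calls object.__setattr__. `Pure` and `Point` are hypothetical names and this is not ojax's implementation; it only demonstrates the pattern of a disabled “=” combined with assign_ inside a custom __init__.

```python
import dataclasses


class Pure:
    """Hypothetical miniature of the PureClass idea."""

    def __setattr__(self, name, value):
        # Direct "obj.field = value" is disabled.
        raise dataclasses.FrozenInstanceError(f"cannot assign to field {name!r}")

    def assign_(self, **kwargs) -> None:
        # Low-level escape hatch, meant for custom __init__ methods only.
        for name, value in kwargs.items():
            object.__setattr__(self, name, value)


@dataclasses.dataclass(init=False)
class Point(Pure):
    x: float
    y: float

    def __init__(self, x, y=0.0):
        # "self.x = x" would raise here, so assign_ is used instead.
        self.assign_(x=float(x), y=float(y))


p = Point(1)
print(p)  # Point(x=1.0, y=0.0)
```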