Source code for ojax.pureclass
"""Customized frozen dataclass for immutable computation."""
from __future__ import annotations

import copy
import functools
import warnings
from dataclasses import dataclass, fields
from typing import TypeVar

from typing_extensions import dataclass_transform
[docs]
class NoAnnotationWarning(UserWarning):
    """Warning category for class attributes that lack a type annotation.

    Non-annotated class attributes are not turned into dataclass fields, which
    is easy to overlook when writing a ``PureClass`` subclass.
    """
def _is_magic_name(s: str) -> bool:
return s.startswith("__") and s.endswith("__")
# get non-property, non-magical and non-callable class variables
def _get_class_vars(cls: type) -> list[str]:
    """Return the names of plain data attributes defined directly on ``cls``.

    Only the class's own ``__dict__`` is inspected, so inherited attributes are
    not reported. A name is skipped when any of the following holds:

    * the attribute, fetched through the class, is callable (e.g. methods);
    * it is a dunder name such as ``__module__`` or ``__doc__``;
    * its value is a ``property`` or a ``functools.cached_property``. Both are
      computed attributes, not class data, and must not be annotated as
      dataclass fields.

    Args:
        cls: The class to inspect.

    Returns:
        Attribute names in ``cls.__dict__`` insertion order.
    """
    return [
        name
        for name, value in cls.__dict__.items()
        if not (
            # getattr (not the raw value) so descriptors such as classmethod
            # are resolved to their callable form first
            callable(getattr(cls, name))
            or _is_magic_name(name)
            # cached_property is neither a property subclass nor callable, so
            # without this check it would be flagged as a non-annotated field
            or isinstance(value, (property, functools.cached_property))
        )
    ]
# warn user about non-annotated class variables ambiguous for dataclasses
def _warn_no_anno_class_attrs(cls: type, stacklevel: int = 3) -> None:
    """Emit ``NoAnnotationWarning`` when ``cls`` has non-annotated attributes.

    ``dataclasses`` only turns annotated class attributes into fields, so a
    plain ``x = 1`` in a subclass is silently ignored as a field. This helper
    checks the names returned by ``_get_class_vars`` against the class's
    ``__annotations__``. It warns once, listing every name that lacks an
    annotation, and does nothing when there are none.

    Args:
        cls: The class being checked.
        stacklevel: Forwarded to ``warnings.warn``. The default of 3 skips
            this helper and its direct caller. When this helper is called from
            ``PureClass.__init_subclass__`` during a ``class`` statement, the
            warning is therefore attributed to the user's class definition
            rather than to this module. That assumes no other Python-level
            frame, such as a metaclass ``__new__``, sits in between.
    """
    anno = cls.__annotations__
    no_annos = tuple(n for n in _get_class_vars(cls) if n not in anno)
    if not no_annos:
        return
    warnings.warn(
        "Non-annotated class attributes are ignored by dataclass "
        f"{cls.__name__}: {no_annos}. Consider adding annotations and "
        "declaring class variables explicitly with typing.ClassVar instead.",
        NoAnnotationWarning,
        stacklevel=stacklevel,
    )
[docs]
@dataclass_transform(frozen_default=True)
@dataclass(frozen=True, init=False)
class PureClass:
""" "Record-type" base class with immutable and annotated dataclass fields.
Direct attribute assignment with "=" is disabled to encourage the immutable
paradigm. Use the ``ojax.new`` function to create new instances with
updated field values instead of in-place updates.
The ``.assign_`` method is provided to initialize fields in custom
``__init__`` methods. It is the low-level impure "dark magic" that normally
should not be used by the end user in any other context.
"""
def __init_subclass__(cls, **kwargs):
"""Make each subclass a frozen dataclass."""
# warn user about potential missing annotation error
_warn_no_anno_class_attrs(cls)
return dataclass(frozen=True, init=False, **kwargs)(cls)
[docs]
def assign_(self, **kwargs) -> None:
"""Low-level in-place setting of PureClass instance fields.
This should only be used during custom instance creation and before the
first usage of the created instance. It is typically used in custom
__init__ methods to replace the disabled direct assignment with "=".
.. warning::
End users should avoid using this method directly in other cases as
it can easily break the immutable paradigm and cause potential bugs
with JAX.
Args:
**kwargs: Keyword arguments specifying new values to be updated for
the corresponding fields.
"""
field_names = set(f.name for f in fields(self))
arg_names = set(kwargs.keys())
if not arg_names.issubset(field_names):
raise ValueError(
f"Unrecognized fields: {arg_names.difference(field_names)}"
)
for k, v in kwargs.items():
object.__setattr__(self, k, v)
# Type variable bound to PureClass, used by ``new`` so that the returned object
# is typed as the same (sub)class as the one passed in.
PureClass_T = TypeVar("PureClass_T", bound=PureClass)
[docs]
def new(pure_obj: PureClass_T, **kwargs) -> PureClass_T:
    """Return a shallow copy of ``pure_obj`` with some fields replaced.

    This is an alternative to ``dataclasses.replace()`` that never runs
    ``__init__`` again. The object is duplicated with ``copy.copy`` and the
    requested values are then written onto the duplicate via its ``assign_``
    method. As a result, ``init=False`` fields can be updated directly, and
    custom ``__init__`` logic is not re-executed.

    Args:
        pure_obj: The ``PureClass`` instance to start from. It is not modified
            itself, but because the copy is shallow, any mutable field values
            are shared with the result.
        **kwargs: Replacement values keyed by field name.

    Returns:
        The updated duplicate, of the same type as ``pure_obj``.
    """
    duplicate = copy.copy(pure_obj)
    duplicate.assign_(**kwargs)
    return duplicate