diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index d630b3763ffa0aed06c1e5820cf2e4922408790b..7c114d2651a5a6bf10f088f04a28adb84c689e22 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -24,7 +24,6 @@ THE SOFTWARE. from abc import ABC, abstractmethod import sys -import dataclasses import islpy as isl import loopy as lp import pytools @@ -53,6 +52,7 @@ from pytato.tags import (ImplStored, ImplInlined, Named, PrefixNamed, from pytools.tag import Tag import pytato.reductions as red from pytato.codegen import _generate_name_for_temp +import attrs # set in doc/conf.py if getattr(sys, "_BUILDING_SPHINX_DOCS", False): @@ -76,6 +76,11 @@ __doc__ = """ .. autofunction:: add_store .. autofunction:: normalize_outputs .. autofunction:: get_initial_codegen_state + +.. class:: ReductionBounds + + A mapping from reduction inames to a tuple ``(lower_bound, upper_bound)``, + considered half-open. """ @@ -102,7 +107,7 @@ ReductionBounds = Mapping[str, Tuple[ScalarExpression, ScalarExpression]] # {{{ LoopyExpressionContexts -@dataclasses.dataclass(init=True, repr=False, eq=False) +@attrs.define(init=True, repr=False, eq=False) class PersistentExpressionContext(object): """ Mutable state used while generating :mod:`loopy` expressions for a @@ -127,7 +132,7 @@ class PersistentExpressionContext(object): """ state: CodeGenState _depends_on: FrozenSet[str] = \ - dataclasses.field(default_factory=frozenset) + attrs.field(factory=frozenset) @property def depends_on(self) -> FrozenSet[str]: @@ -137,7 +142,7 @@ class PersistentExpressionContext(object): self._depends_on = self._depends_on | other -@dataclasses.dataclass(frozen=True) +@attrs.define(frozen=True) class LocalExpressionContext: """ Records context being to be conveyed from a parent expression to its @@ -262,7 +267,7 @@ class InlinedResult(ImplementedResult): # {{{ SubstitutionRuleResult -@dataclasses.dataclass(frozen=True, eq=True) +@attrs.define(frozen=True, eq=True) class SubstitutionRuleResult(ImplementedResult): """ An array expression generated as a @@ -286,7 +291,7 @@ class SubstitutionRuleResult(ImplementedResult): # {{{ codegen state -@dataclasses.dataclass(init=True, repr=False, eq=False) +@attrs.define(init=True, repr=False, eq=False) class CodeGenState: """A container for data kept by :class:`CodeGenMapper`. @@ -308,10 +313,10 @@ class CodeGenState: _t_unit: lp.TranslationUnit results: Dict[Array, ImplementedResult] - var_name_gen: pytools.UniqueNameGenerator = dataclasses.field(init=False) - insn_id_gen: pytools.UniqueNameGenerator = dataclasses.field(init=False) + var_name_gen: pytools.UniqueNameGenerator = attrs.field(init=False) + insn_id_gen: pytools.UniqueNameGenerator = attrs.field(init=False) - def __post_init__(self) -> None: + def __attrs_post_init__(self) -> None: self.var_name_gen = self._t_unit.default_entrypoint.get_var_name_generator() self.insn_id_gen = ( self._t_unit.default_entrypoint.get_instruction_id_generator()) @@ -955,8 +960,7 @@ def get_initial_codegen_state(target: LoopyTarget, options=options, lang_version=lp.MOST_RECENT_LANGUAGE_VERSION) - return CodeGenState(_t_unit=kernel, - results={}) + return CodeGenState(t_unit=kernel, results={}) # {{{ generate_loopy