diff --git a/doc/conf.py b/doc/conf.py index 7bd8b3030701fbe4393a068cff997dea0c7211df..96130ca4a3b32adc496db9cbae054a3e773b6749 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -65,4 +65,5 @@ intersphinx_mapping = { 'http://documen.tician.de/pymbolic/': None, 'http://documen.tician.de/loopy/': None, 'http://documen.tician.de/sumpy/': None, + 'http://documen.tician.de/islpy/': None, } diff --git a/doc/design.rst b/doc/design.rst index 2309491944833a1f3b4396281daf91e51096c985..fd50da9160a4bc4b8cb7531a38a31d0dc1d7277d 100644 --- a/doc/design.rst +++ b/doc/design.rst @@ -119,6 +119,8 @@ Reserved Identifiers - ``_pt_shp``: Used to automatically generate identifiers used in data-dependent shapes. + - ``_pt_out``: The default name of an unnamed output argument + - Identifiers used in index lambdas are also reserved. These include: - Identifiers matching the regular expression ``_[0-9]+``. They are used diff --git a/doc/reference.rst b/doc/reference.rst index 522bdd4ba3f3e69cb197157e78801d9be06f42c9..43f655519a9e240bafef04372771f415ffed42a3 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -2,3 +2,8 @@ Reference ========= .. automodule:: pytato.array +.. automodule:: pytato.scalar_expr +.. automodule:: pytato.transform +.. automodule:: pytato.program +.. automodule:: pytato.target +.. automodule:: pytato.codegen diff --git a/pytato/__init__.py b/pytato/__init__.py index 4cef6a4a58a5fbacf76340cefe8574fdb45446de..123041a3844465c966d95e91840c27307f225b0a 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -29,5 +29,13 @@ from pytato.array import ( DottedName, Placeholder, make_placeholder, ) -__all__ = ("DottedName", "Namespace", "Array", "DictOfNamedArrays", - "Tag", "UniqueTag", "Placeholder", "make_placeholder") +from pytato.codegen import generate_loopy +from pytato.target import Target, PyOpenCLTarget + +__all__ = ( + "DottedName", "Namespace", "Array", "DictOfNamedArrays", + "Tag", "UniqueTag", "Placeholder", "make_placeholder", + + "generate_loopy", + "Target", "PyOpenCLTarget", +) diff --git a/pytato/array.py b/pytato/array.py index 42041060ff83432a4867c81172b529326741005f..e9bea177a3b664d41670818de7612d9743782a35 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -40,24 +40,31 @@ __doc__ = """ Array Interface --------------- -.. autoclass :: Namespace -.. autoclass :: Array -.. autoclass :: Tag -.. autoclass :: UniqueTag -.. autoclass :: DictOfNamedArrays +.. autoclass:: Namespace +.. autoclass:: Array +.. autoclass:: Tag +.. autoclass:: UniqueTag +.. autoclass:: DictOfNamedArrays Supporting Functionality ------------------------ -.. autoclass :: DottedName +.. autoclass:: DottedName .. currentmodule:: pytato.array +Concrete Array Data +------------------- +.. autoclass:: DataInterface + Pre-Defined Tags ---------------- .. autoclass:: ImplementAs .. autoclass:: CountNamed +.. autoclass:: ImplStored +.. autoclass:: ImplInlined +.. autoclass:: ImplDefault Built-in Expression Nodes ------------------------- @@ -85,15 +92,21 @@ Node constructors such as :class:`Placeholder.__init__` and # }}} +from functools import partialmethod +from numbers import Number +import operator +from dataclasses import dataclass +from typing import ( + Optional, ClassVar, Dict, Any, Mapping, Iterator, Tuple, Union, + FrozenSet, Protocol) + import numpy as np import pymbolic.primitives as prim +from pytools import is_single_valued, memoize_method + import pytato.scalar_expr as scalar_expr from pytato.scalar_expr import ScalarExpression -from dataclasses import dataclass -from pytools import is_single_valued -from typing import Optional, Dict, Any, Mapping, Iterator, Tuple, Union, FrozenSet - # {{{ dotted name @@ -138,7 +151,7 @@ class DottedName: # {{{ namespace -class Namespace: +class Namespace(Mapping[str, "Array"]): # Possible future extension: .parent attribute r""" Represents a mapping from :term:`identifier` strings to @@ -149,14 +162,16 @@ class Namespace: .. automethod:: __contains__ .. automethod:: __getitem__ .. automethod:: __iter__ + .. automethod:: __len__ .. automethod:: assign + .. automethod:: copy .. automethod:: ref """ def __init__(self) -> None: - self._symbol_table: Dict[str, Optional[Array]] = {} + self._symbol_table: Dict[str, Array] = {} - def __contains__(self, name: str) -> bool: + def __contains__(self, name: object) -> bool: return name in self._symbol_table def __getitem__(self, name: str) -> Array: @@ -168,13 +183,17 @@ class Namespace: def __iter__(self) -> Iterator[str]: return iter(self._symbol_table) - def assign(self, name: str, - value: Optional[Array]) -> str: + def __len__(self) -> int: + return len(self._symbol_table) + + def copy(self) -> Namespace: + raise NotImplementedError + + def assign(self, name: str, value: Array) -> str: """Declare a new array. :param name: a Python identifier - :param value: the array object, or None if the assignment is to - just reserve a name + :param value: the array object :returns: *name* """ @@ -239,6 +258,7 @@ class UniqueTag(Tag): Only one instance of this type of tag may be assigned to a single tagged object. """ + pass TagsType = FrozenSet[Tag] @@ -285,12 +305,10 @@ def normalize_shape( :param ns: if a namespace is given, extra checks are performed to ensure that expressions are well-defined. """ - from pytato.scalar_expr import parse - def normalize_shape_component( s: ConvertibleToShapeComponent) -> ScalarExpression: if isinstance(s, str): - s = parse(s) + s = scalar_expr.parse(s) if isinstance(s, int): if s < 0: @@ -303,7 +321,7 @@ def normalize_shape( return s if isinstance(shape, str): - shape = parse(shape) + shape = scalar_expr.parse(shape) from numbers import Number if isinstance(shape, (Number, prim.Expression)): @@ -375,15 +393,19 @@ class Array: .. attribute:: ndim """ + mapper_method: ClassVar[str] + # A tuple of field names. Fields must be equality comparable and + # hashable. Dicts of hashable keys and values are also permitted. + fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") - def __init__(self, namespace: Namespace, - tags: Optional[TagsType] = None): + def __init__(self, + namespace: Namespace, + tags: Optional[TagsType] = None): if tags is None: tags = frozenset() self.namespace = namespace self.tags = tags - self.dtype: np.dtype = np.float64 # FIXME def copy(self, **kwargs: Any) -> Array: raise NotImplementedError @@ -392,6 +414,10 @@ class Array: def shape(self) -> ShapeType: raise NotImplementedError + @property + def dtype(self) -> np.dtype: + raise NotImplementedError + def named(self, name: str) -> Array: return self.namespace.ref(self.namespace.assign(name, self)) @@ -409,17 +435,98 @@ class Array: return self.copy(tags=self.tags | frozenset([tag])) def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array: - new_tags = tuple( - t for t in self.tags - if t != tag) + new_tags = tuple(t for t in self.tags if t != tag) if verify_existence and len(new_tags) == len(self.tags): raise ValueError(f"tag '{tag}' was not present") return self.copy(tags=new_tags) - # TODO: - # - codegen interface + @memoize_method + def __hash__(self) -> int: + attrs = [] + for field in self.fields: + attr = getattr(self, field) + if isinstance(attr, dict): + attr = frozenset(attr.items()) + attrs.append(attr) + return hash(tuple(attrs)) + + def __eq__(self, other: Any) -> bool: + if self is other: + return True + return ( + isinstance(other, type(self)) + and self.namespace is other.namespace + and all( + getattr(self, field) == getattr(other, field) + for field in self.fields)) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _binary_op(self, + op: Any, + other: Union[Array, Number], + reverse: bool = False) -> Array: + if isinstance(other, Number): + # TODO + raise NotImplementedError + else: + if self.shape != other.shape: + raise NotImplementedError("broadcasting") + + dtype = np.result_type(self.dtype, other.dtype) + + # FIXME: If either *self* or *other* is an IndexLambda, its expression + # could be folded into the output, producing a fused result. + if self.shape == (): + expr = op(prim.Variable("_in0"), prim.Variable("_in1")) + else: + indices = tuple( + prim.Variable(f"_{i}") for i in range(self.ndim)) + expr = op( + prim.Variable("_in0")[indices], + prim.Variable("_in1")[indices]) + + first, second = self, other + if reverse: + first, second = second, first + + bindings = dict(_in0=first, _in1=second) + + return IndexLambda(self.namespace, + expr, + shape=self.shape, + dtype=dtype, + bindings=bindings) + + __mul__ = partialmethod(_binary_op, operator.mul) + __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) + + +class _SuppliedShapeAndDtypeMixin(object): + """A mixin class for when an array must store its own *shape* and *dtype*, + rather than when it can derive them easily from inputs. + """ + + def __init__(self, + namespace: Namespace, + shape: ShapeType, + dtype: np.dtype, + **kwargs: Any): + # https://github.com/python/mypy/issues/5887 + super().__init__(namespace, **kwargs) # type: ignore + self._shape = shape + self._dtype = dtype + + @property + def shape(self) -> ShapeType: + return self._shape + + @property + def dtype(self) -> np.dtype: + return self._dtype # }}} @@ -512,7 +619,7 @@ class DictOfNamedArrays(Mapping[str, Array]): # {{{ index lambda -class IndexLambda(Array): +class IndexLambda(_SuppliedShapeAndDtypeMixin, Array): """ .. attribute:: expr @@ -537,33 +644,26 @@ class IndexLambda(Array): .. automethod:: is_reference """ - - # TODO: write make_index_lambda() that does dtype inference - - def __init__( - self, namespace: Namespace, expr: prim.Expression, - shape: ShapeType, dtype: np.dtype, + fields = Array.fields + ("expr", "bindings") + mapper_method = "map_index_lambda" + + def __init__(self, + namespace: Namespace, + expr: prim.Expression, + shape: ShapeType, + dtype: np.dtype, bindings: Optional[Dict[str, Array]] = None, tags: Optional[TagsType] = None): if bindings is None: bindings = {} - super().__init__(namespace, tags=tags) + super().__init__(namespace, shape=shape, dtype=dtype, tags=tags) - self._shape = shape - self._dtype = dtype self.expr = expr self.bindings = bindings - @property - def shape(self) -> ShapeType: - return self._shape - - @property - def dtype(self) -> np.dtype: - return self._dtype - + @memoize_method def is_reference(self) -> bool: # FIXME: Do we want a specific 'reference' node to make all this # checking unnecessary? @@ -617,6 +717,22 @@ class Reshape(Array): # {{{ data wrapper +class DataInterface(Protocol): + """A protocol specifying the minimal interface requirements for concrete + array data supported by :class:`DataWrapper`. + + See :class:`typing.Protocol` for more information about protocols. + + Code generation targets may impose additional restrictions on the kinds of + concrete array data they support. + + .. attribute:: shape + .. attribute:: dtype + """ + shape: ShapeType + dtype: np.dtype + + class DataWrapper(Array): # TODO: Name? """ @@ -635,15 +751,15 @@ class DataWrapper(Array): this array may not be updated in-place. """ - # TODO: not really Any data - def __init__(self, namespace: Namespace, data: Any, - tags: Optional[TagsType] = None): - super().__init__(namespace, tags) - + def __init__(self, + namespace: Namespace, + data: DataInterface, + tags: Optional[TagsType] = None): + super().__init__(namespace, tags=tags) self.data = data @property - def shape(self) -> Any: # FIXME + def shape(self) -> ShapeType: return self.data.shape @property @@ -655,8 +771,8 @@ class DataWrapper(Array): # {{{ placeholder -class Placeholder(Array): - """ +class Placeholder(_SuppliedShapeAndDtypeMixin, Array): + r""" A named placeholder for an array whose concrete value is supplied by the user during evaluation. @@ -664,32 +780,41 @@ class Placeholder(Array): The name by which a value is supplied for the placeholder once computation begins. + The name is also implicitly :meth:`~Namespace.assign`\ ed + in the :class:`Namespace`. .. note:: - :attr:`name` is not a :term:`namespace name`. In fact, - it is prohibited from being one. (This has to be the case: Suppose a - :class:`Placeholder` is :meth:`~Array.tagged`, would the namespace name - refer to the tagged or the untagged version?) + Creating multiple instances of a :class:`Placeholder` with the same name + and within the same :class:`Namespace` is not allowed. """ + mapper_method = "map_placeholder" + fields = Array.fields + ("name",) + + def __init__(self, + namespace: Namespace, + name: str, + shape: ShapeType, + dtype: np.dtype, + tags: Optional[TagsType] = None): + if name is None: + raise ValueError("Must have explicit name") - def __init__(self, namespace: Namespace, - name: str, shape: ShapeType, - tags: Optional[TagsType] = None): - - # Reserve the name, prevent others from using it. - namespace.assign(name, None) + # Publish our name to the namespace + namespace.assign(name, self) - super().__init__(namespace=namespace, tags=tags) + super().__init__(namespace=namespace, + shape=shape, + dtype=dtype, + tags=tags) self.name = name - self._shape = shape - @property - def shape(self) -> ShapeType: - # Matt added this to make Pylint happy. - # Not tied to this, open for discussion about how to implement this. - return self._shape + def tagged(self, tag: Tag) -> Array: + raise ValueError("Cannot modify tags") + + def without_tag(self, tag: Tag, verify_existence: bool = True) -> Array: + raise ValueError("Cannot modify tags") # }}} @@ -705,13 +830,13 @@ class LoopyFunction(DictOfNamedArrays): name. """ + # }}} # {{{ end-user-facing -def make_dict_of_named_arrays( - data: Dict[str, Array]) -> DictOfNamedArrays: +def make_dict_of_named_arrays(data: Dict[str, Array]) -> DictOfNamedArrays: """Make a :class:`DictOfNamedArrays` object and ensure that all arrays share the same namespace. @@ -724,22 +849,27 @@ def make_dict_of_named_arrays( def make_placeholder(namespace: Namespace, - name: str, - shape: ConvertibleToShape, - tags: Optional[TagsType] = None - ) -> Placeholder: + name: str, + shape: ConvertibleToShape, + dtype: np.dtype, + tags: Optional[TagsType] = None) -> Placeholder: """Make a :class:`Placeholder` object. - :param namespace: namespace of the placeholder array - :param shape: shape of the placeholder array - :param tags: implementation tags + :param namespace: namespace of the placeholder array + :param name: name of the placeholder array + :param shape: shape of the placeholder array + :param dtype: dtype of the placeholder array + :param tags: implementation tags """ + if name is None: + raise ValueError("Placeholder instances must have a name") + if not str.isidentifier(name): - raise ValueError(f"{name} is not a Python identifier") + raise ValueError(f"'{name}' is not a Python identifier") shape = normalize_shape(shape, namespace) - return Placeholder(namespace, name, shape, tags) + return Placeholder(namespace, name, shape, dtype, tags) # }}} diff --git a/pytato/codegen.py b/pytato/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7a7cb36d66ca0e98917f0dfee4cef3b7618688 --- /dev/null +++ b/pytato/codegen.py @@ -0,0 +1,463 @@ +from __future__ import annotations + +__copyright__ = """Copyright (C) 2020 Matt Wala""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import dataclasses +from typing import (Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set) + +import islpy as isl +import loopy as lp +import pymbolic.primitives as prim +import pytools + +from pytato.array import ( + Array, DictOfNamedArrays, Placeholder, ShapeType, IndexLambda) +from pytato.program import BoundProgram +from pytato.target import Target, PyOpenCLTarget +import pytato.scalar_expr as scalar_expr +from pytato.scalar_expr import ScalarExpression +import pytato.transform + + +__doc__ = """ +Generating Code +--------------- + +.. currentmodule:: pytato + +.. autofunction:: generate_loopy + +Code Generation Internals +------------------------- + +.. currentmodule:: pytato.codegen + +.. autoclass:: LoopyExpressionContext +.. autoclass:: ImplementedResult +.. autoclass:: StoredResult +.. autoclass:: InlinedResult +.. autoclass:: SubstitutionRuleResult + +.. autoclass:: CodeGenState +.. autoclass:: CodeGenMapper + +.. autoclass:: InlinedExpressionGenMapper + +.. autofunction:: domain_for_shape +.. autofunction:: add_output + +""" + + +# {{{ generated array expressions + +# SymbolicIndex and ShapeType are semantically distinct but identical at the +# type level. +SymbolicIndex = ShapeType +ReductionBounds = Dict[str, Tuple[ScalarExpression, ScalarExpression]] + + +@dataclasses.dataclass(init=True, repr=False, eq=False) +class LoopyExpressionContext(object): + """Mutable state used while generating :mod:`loopy` expressions. + Wraps :class:`CodeGenState` with more expression-specific information. + + This data is passed through :class:`InlinedExpressionGenMapper` via arguments, + and is also used by :meth:`ImplementedResult.to_loopy_expression` to + retrieve contextual data. + + .. attribute:: state + + The :class:`CodeGenState`. + + .. attribute:: local_namespace + + A (read-only) local name mapping used for name lookup when generating + code. + + .. attribute:: depends_on + + The set of statement IDs that need to be included in + :attr:`loopy.InstructionBase.depends_on`. + + .. attribute:: reduction_bounds + + A mapping from inames to reduction bounds in the expression. + + .. automethod:: update_depends_on + .. automethod:: lookup + + """ + state: CodeGenState + _depends_on: FrozenSet[str] = \ + dataclasses.field(default_factory=frozenset) + local_namespace: Mapping[str, Array] = \ + dataclasses.field(default_factory=dict) + reduction_bounds: ReductionBounds = \ + dataclasses.field(default_factory=dict) + + def lookup(self, name: str) -> Array: + try: + return self.local_namespace[name] + except KeyError: + return self.state.namespace[name] + + @property + def depends_on(self) -> FrozenSet[str]: + return self._depends_on + + def update_depends_on(self, other: FrozenSet[str]) -> None: + self._depends_on = self._depends_on | other + + +class ImplementedResult(object): + """Generated code for a node in the computation graph (i.e., an array + expression). + + .. attribute:: array + + The :class:`pytato.Array` associated with this code. + + .. automethod:: to_loopy_expression + """ + def __init__(self, array: Array): + self.array = array + + def to_loopy_expression(self, indices: SymbolicIndex, + expr_context: LoopyExpressionContext) -> ScalarExpression: + """Return a :mod:`loopy` expression for this result.""" + raise NotImplementedError + + +class StoredResult(ImplementedResult): + """An array expression generated as a :mod:`loopy` array. + + See also: :class:`pytato.array.ImplStored`. + """ + def __init__(self, name: str, array: Array): + super().__init__(array) + self.name = name + + # TODO: Handle dependencies. + def to_loopy_expression(self, indices: SymbolicIndex, + expr_context: LoopyExpressionContext) -> ScalarExpression: + if indices == (): + return prim.Variable(self.name) + else: + return prim.Variable(self.name)[indices] + + +class InlinedResult(ImplementedResult): + """An array expression generated as a :mod:`loopy` expression containing inlined + sub-expressions. + + See also: :class:`pytato.array.ImplInlined`. + """ + def __init__(self, expr: ScalarExpression, array: Array): + super().__init__(array) + self.expr = expr + + # TODO: Handle dependencies and reduction domains. + def to_loopy_expression(self, indices: SymbolicIndex, + expr_context: LoopyExpressionContext) -> ScalarExpression: + return scalar_expr.substitute( + self.expr, + {f"_{d}": i for d, i in zip(range(self.array.ndim), indices)}) + + +class SubstitutionRuleResult(ImplementedResult): + # TODO: implement + pass + +# }}} + + +# {{{ codegen + +@dataclasses.dataclass(init=True, repr=False, eq=False) +class CodeGenState: + """A container for data kept by :class:`CodeGenMapper`. + + .. attribute:: namespace + + The (global) namespace + + .. attribute:: kernel + + The partial :class:`loopy.LoopKernel` being built. + + .. attribute:: results + + A mapping from :class:`pytato.Array` instances to + instances of :class:`ImplementedResult`. + + .. attribute:: var_name_gen + .. attribute:: insn_id_gen + + .. automethod:: update_kernel + """ + namespace: Mapping[str, Array] + _kernel: lp.LoopKernel + results: Dict[Array, ImplementedResult] + + var_name_gen: pytools.UniqueNameGenerator = dataclasses.field(init=False) + insn_id_gen: pytools.UniqueNameGenerator = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.var_name_gen = self._kernel.get_var_name_generator() + self.insn_id_gen = self._kernel.get_instruction_id_generator() + + @property + def kernel(self) -> lp.LoopKernel: + return self._kernel + + def update_kernel(self, kernel: lp.LoopKernel) -> None: + self._kernel = kernel + + +class CodeGenMapper(pytato.transform.Mapper): + """A mapper for generating code for nodes in the computation graph. + """ + exprgen_mapper: InlinedExpressionGenMapper + + def __init__(self) -> None: + self.exprgen_mapper = InlinedExpressionGenMapper(self) + + def map_placeholder(self, expr: Placeholder, + state: CodeGenState) -> ImplementedResult: + if expr in state.results: + return state.results[expr] + + arg = lp.GlobalArg(expr.name, + shape=expr.shape, + dtype=expr.dtype, + order="C") + kernel = state.kernel.copy(args=state.kernel.args + [arg]) + state.update_kernel(kernel) + + result = StoredResult(expr.name, expr) + state.results[expr] = result + return result + + def map_index_lambda(self, expr: IndexLambda, + state: CodeGenState) -> ImplementedResult: + if expr in state.results: + return state.results[expr] + + # TODO: Respect tags. + + expr_context = LoopyExpressionContext(state, + local_namespace=expr.bindings) + loopy_expr = self.exprgen_mapper(expr.expr, expr_context) + + result = InlinedResult(loopy_expr, expr) + state.results[expr] = result + return result + +# }}} + + +# {{{ inlined expression gen mapper + +class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): + """A mapper for generating :mod:`loopy` expressions with inlined + sub-expressions. + + The inputs to this mapper are scalar expression as found in + :class:`pytato.array.IndexLambda`, or expressions that are + compatible (e.g., shape expressions). + + The outputs of this mapper are scalar expressions suitable for wrapping in + :class:`InlinedResult`. + """ + codegen_mapper: CodeGenMapper + + def __init__(self, codegen_mapper: CodeGenMapper): + self.codegen_mapper = codegen_mapper + + def __call__(self, expr: ScalarExpression, + expr_context: LoopyExpressionContext) -> ScalarExpression: + return self.rec(expr, expr_context) + + def map_subscript(self, expr: prim.Subscript, + expr_context: LoopyExpressionContext) -> ScalarExpression: + assert isinstance(expr.aggregate, prim.Variable) + result: ImplementedResult = self.codegen_mapper( + expr_context.lookup(expr.aggregate.name), expr_context.state) + return result.to_loopy_expression(expr.index, expr_context) + + # TODO: map_reduction() + + def map_variable(self, expr: prim.Variable, + expr_context: LoopyExpressionContext) -> ScalarExpression: + result: ImplementedResult = self.codegen_mapper( + expr_context.lookup(expr.name), + expr_context.state) + return result.to_loopy_expression((), expr_context) + +# }}} + + +# {{{ utils + +def domain_for_shape( + dim_names: Tuple[str, ...], shape: ShapeType) -> isl.BasicSet: + """Create an :class:`islpy.BasicSet` that expresses an appropriate index domain + for an array of (potentially symbolic) shape *shape*. + + :param dim_names: A tuple of strings, the names of the axes. These become set + dimensions in the returned domain. + + :param shape: A tuple of constant or quasi-affine :mod:`pymbolic` + expressions. The variables in these expressions become parameter + dimensions in the returned set. Must have the same length as + *dim_names*. + """ + assert len(dim_names) == len(shape) + + # Collect parameters. + param_names_set: Set[str] = set() + for sdep in map(scalar_expr.get_dependencies, shape): + param_names_set |= sdep + + set_names = sorted(dim_names) + param_names = sorted(param_names_set) + + # Build domain. + dom = isl.BasicSet.universe( + isl.Space.create_from_names(isl.DEFAULT_CONTEXT, + set=set_names, + params=param_names)) + + # Add constraints. + from loopy.symbolic import aff_from_expr + affs = isl.affs_from_space(dom.space) + + for iname, dim in zip(dim_names, shape): + dom &= affs[0].le_set(affs[iname]) + dom &= affs[iname].lt_set(aff_from_expr(dom.space, dim)) + + dom, = dom.get_basic_sets() + + return dom + + +def add_output(name: str, expr: Array, state: CodeGenState, + mapper: CodeGenMapper) -> None: + """Add an output argument to the kernel. + """ + # FIXE: Scalar outputs are not supported yet. + assert expr.shape != () + + result = mapper(expr, state) + + inames = tuple( + state.var_name_gen(f"{name}_dim{d}") + for d in range(expr.ndim)) + domain = domain_for_shape(inames, expr.shape) + + arg = lp.GlobalArg(name, + shape=expr.shape, + dtype=expr.dtype, + order="C", + is_output_only=True) + + indices = tuple(prim.Variable(iname) for iname in inames) + expr_context = LoopyExpressionContext(state) + copy_expr = result.to_loopy_expression(indices, expr_context) + + # TODO: Contextual data not supported yet. + assert not expr_context.reduction_bounds + assert not expr_context.depends_on + + from loopy.kernel.instruction import make_assignment + insn = make_assignment((prim.Variable(name)[indices],), + copy_expr, + id=state.insn_id_gen(f"{name}_copy"), + within_inames=frozenset(inames), + depends_on=expr_context.depends_on) + + kernel = state.kernel + kernel = kernel.copy(args=kernel.args + [arg], + instructions=kernel.instructions + [insn], + domains=kernel.domains + [domain]) + state.update_kernel(kernel) + +# }}} + + +def generate_loopy(result: Union[Array, DictOfNamedArrays], + target: Optional[Target] = None, + options: Optional[lp.Options] = None) -> BoundProgram: + r"""Code generation entry point. + + :param result: Outputs of the computation. + :param target: Code generation target. + :param options: Code generation options for the kernel. + :returns: A wrapped generated :mod:`loopy` kernel + """ + # {{{ get namespace and outputs + + outputs: DictOfNamedArrays + + if isinstance(result, Array): + outputs = DictOfNamedArrays({"_pt_out": result}) + namespace = outputs.namespace + else: + assert isinstance(result, DictOfNamedArrays) + outputs = result + + namespace = outputs.namespace + del result + + # }}} + + if target is None: + target = PyOpenCLTarget() + + # Set up codegen state. + kernel = lp.make_kernel("{:}", [], + target=target.get_loopy_target(), + options=options, + lang_version=lp.MOST_RECENT_LANGUAGE_VERSION) + + state = CodeGenState(namespace=namespace, + _kernel=kernel, + results=dict()) + + # Reserve names of input and output arguments. + for val in namespace.values(): + if isinstance(val, Placeholder): + state.var_name_gen.add_name(val.name) + state.var_name_gen.add_names(outputs) + + # Generate code for graph nodes. + mapper = CodeGenMapper() + for name, val in namespace.items(): + _ = mapper(val, state) + + # Generate code for outputs. + for name, expr in outputs.items(): + add_output(name, expr, state, mapper) + + return target.bind_program(program=state.kernel, bound_arguments=dict()) diff --git a/pytato/mapper.py b/pytato/mapper.py deleted file mode 100644 index afeec20a61296a76f0b5622ec6b1b2513a73d179..0000000000000000000000000000000000000000 --- a/pytato/mapper.py +++ /dev/null @@ -1,56 +0,0 @@ -__copyright__ = """ -Copyright (C) 2020 Andreas Kloeckner -Copyright (C) 2020 Matt Wala -Copyright (C) 2020 Xiaoyu Wei -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -from typing import Any - - -class Mapper: - pass - - -class IdentityMapper: - """ - Graph transformations subclass this - """ - pass - - -class StringifyMapper: - pass - - -class ToLoopyMapper: - pass - - # {{{ - - def _stringify(self) -> str: - pass - - def _generate_code(self) -> Any: - pass - - # }}} diff --git a/pytato/program.py b/pytato/program.py new file mode 100644 index 0000000000000000000000000000000000000000..ded19094ad9db8fff75ce722b7a5aacc37777e32 --- /dev/null +++ b/pytato/program.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +__copyright__ = """Copyright (C) 2020 Matt Wala""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +__doc__ = """ +.. currentmodule:: pytato.program + +Generated Executable Programs +----------------------------- + +.. autoclass:: BoundProgram +.. autoclass:: BoundPyOpenCLProgram +""" + +from dataclasses import dataclass +import typing +from typing import Any, Mapping, Optional + + +if typing.TYPE_CHECKING: + # Imports skipped for efficiency. FIXME: Neither of these work as type + # stubs are not present. Types are here only as documentation. + import pyopencl as cl + import loopy as lp + # Imports skipped to avoid circular dependencies. + import pytato.target + + +@dataclass(init=True, repr=False, eq=False) +class BoundProgram: + """A wrapper around a :mod:`loopy` kernel for execution. + + .. attribute:: program + + The underlying :class:`loopy.LoopKernel`. + + .. attribute:: target + + The code generation target. + + .. attribute:: bound_arguments + + A map from names to pre-bound kernel arguments. + + .. automethod:: __call__ + """ + + program: "lp.LoopKernel" + bound_arguments: Mapping[str, Any] + target: "pytato.target.Target" + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + +@dataclass(init=True, repr=False, eq=False) +class BoundPyOpenCLProgram(BoundProgram): + """A wrapper around a :mod:`loopy` kernel for execution with :mod:`pyopencl`. + + .. attribute:: queue + + A :mod:`pyopencl` command queue. + + .. automethod:: __call__ + """ + queue: Optional["cl.CommandQueue"] + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Convenience function for launching a :mod:`pyopencl` computation.""" + if not self.queue: + raise ValueError("queue must be specified") + + updated_kwargs = dict(self.bound_arguments) + updated_kwargs.update(kwargs) + return self.program(self.queue, *args, **updated_kwargs) + +# vim: foldmethod=marker diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index c6759845b7b83a53c4e253e481bcdf96f7a6f031..a09e9db9eddb5563c36ff1348982301039bb0f89 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -24,12 +24,38 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from pymbolic.mapper import WalkMapper as WalkMapperBase +from numbers import Number +from typing import Any, Union, Mapping, FrozenSet + +from pymbolic.mapper import (WalkMapper as WalkMapperBase, IdentityMapper as + IdentityMapperBase) +from pymbolic.mapper.substitutor import (SubstitutionMapper as + SubstitutionMapperBase) +from pymbolic.mapper.dependency import (DependencyMapper as + DependencyMapperBase) import pymbolic.primitives as prim -from numbers import Number -from typing import Union +__doc__ = """ +.. currentmodule:: pytato.scalar_expr + +Scalar Expressions +------------------ +.. data:: ScalarExpression + + A :class:`type` for scalar-valued symbolic expressions. Expressions are + composable and manipulable via :mod:`pymbolic`. + + Concretely, this is an alias for + ``Union[Number, pymbolic.primitives.Expression]``. + +.. autofunction:: parse +.. autofunction:: get_dependencies +.. autofunction:: substitute + +""" + +# {{{ scalar expressions ScalarExpression = Union[Number, prim.Expression] @@ -38,9 +64,52 @@ def parse(s: str) -> ScalarExpression: from pymbolic.parser import Parser return Parser()(s) +# }}} + + +# {{{ mapper classes class WalkMapper(WalkMapperBase): pass +class IdentityMapper(IdentityMapperBase): + pass + + +class SubstitutionMapper(SubstitutionMapperBase): + pass + + +class DependencyMapper(DependencyMapperBase): + pass + +# }}} + + +# {{{ mapper frontends + +def get_dependencies(expression: Any) -> FrozenSet[str]: + """Return the set of variable names in an expression. + + :param expression: A scalar expression, or an expression derived from such + (e.g., a tuple of scalar expressions) + """ + mapper = DependencyMapper(composite_leaves=False) + return frozenset(dep.name for dep in mapper(expression)) + + +def substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> Any: + """Perform variable substitution in an expression. + + :param expression: A scalar expression, or an expression derived from such + (e.g., a tuple of scalar expressions) + :param variable_assigments: A mapping from variable names to substitutions + """ + from pymbolic.mapper.substitutor import make_subst_func + return SubstitutionMapper(make_subst_func(variable_assigments))(expression) + +# }}} + + # vim: foldmethod=marker diff --git a/pytato/target.py b/pytato/target.py new file mode 100644 index 0000000000000000000000000000000000000000..0242c0d07475fe4dad1e6f394c00a660ad170358 --- /dev/null +++ b/pytato/target.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +__copyright__ = """Copyright (C) 2020 Matt Wala""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +__doc__ = """ +.. currentmodule:: pytato.target + +Code Generation Targets +----------------------- + +.. autoclass:: Target +.. autoclass:: PyOpenCLTarget +""" + +import typing +from typing import Any, Mapping, Optional + +from pytato.program import BoundProgram, BoundPyOpenCLProgram + + +if typing.TYPE_CHECKING: + # Skip imports for efficiency. FIXME: Neither of these work as type stubs + # are not present. Types are here only as documentation. + import pyopencl as cl + import loopy as lp + + +class Target: + """An abstract code generation target. + + .. automethod:: get_loopy_target + .. automethod:: bind_program + """ + + def get_loopy_target(self) -> "lp.TargetBase": + """Return the corresponding :mod:`loopy` target.""" + raise NotImplementedError + + def bind_program(self, program: "lp.LoopKernel", + bound_arguments: Mapping[str, Any]) -> BoundProgram: + """Create a :class:`pytato.program.BoundProgram` for this code generation target. + + :param program: the :mod:`loopy` kernel + :param bound_arguments: a mapping from argument names to outputs + """ + raise NotImplementedError + + +class PyOpenCLTarget(Target): + """A :mod:`pyopencl` code generation target. + + .. attribute:: queue + + The :mod:`pyopencl` command queue, or *None*. + """ + + def __init__(self, queue: Optional["cl.CommandQueue"] = None): + self.queue = queue + + def get_loopy_target(self) -> "lp.PyOpenCLTarget": + import loopy as lp + device = None + if self.queue is not None: + device = self.queue.device + return lp.PyOpenCLTarget(device) + + def bind_program(self, program: "lp.LoopKernel", + bound_arguments: Mapping[str, Any]) -> BoundProgram: + return BoundPyOpenCLProgram(program=program, + queue=self.queue, + bound_arguments=bound_arguments, + target=self) + +# vim: foldmethod=marker diff --git a/pytato/transform.py b/pytato/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..13546653188a6da64a5f65b1ef3ed68832e115b3 --- /dev/null +++ b/pytato/transform.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +__copyright__ = """ +Copyright (C) 2020 Matt Wala +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Any, Callable + +from pytato.array import Array + +__doc__ = """ +.. currentmodule:: pytato.transform + +Transforming Computations +------------------------- + +.. autoclass:: Mapper + +""" + + +# {{{ mapper classes + +class UnsupportedArrayError(ValueError): + pass + + +class Mapper: + def handle_unsupported_array(self, expr: Array, *args: Any, + **kwargs: Any) -> Any: + """Mapper method that is invoked for + :class:`pytato.Array` subclasses for which a mapper + method does not exist in this mapper. + """ + raise UnsupportedArrayError("%s cannot handle expressions of type %s" + % (type(self).__name__, type(expr))) + + def map_foreign(self, expr: Any, *args: Any, **kwargs: Any) -> Any: + raise ValueError("%s encountered invalid foreign object: %s" + % (type(self).__name__, repr(expr))) + + def __call__(self, expr: Array, *args: Any, **kwargs: Any) -> Any: + method: Callable[..., Array] + + try: + method = getattr(self, expr.mapper_method) + except AttributeError: + if isinstance(expr, Array): + return self.handle_unsupported_array(expr, *args, **kwargs) + else: + return self.map_foreign(expr, *args, **kwargs) + + return method(expr, *args, **kwargs) + + rec = __call__ + +# }}} + + +# vim: foldmethod=marker diff --git a/requirements.txt b/requirements.txt index 22617c7cad00812871d27675ccf9fbf0e92e3afc..89007665330e9389aeb9c683b778215e3135b1b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ +git+https://github.com/inducer/pytools.git git+https://github.com/inducer/loopy.git diff --git a/setup.cfg b/setup.cfg index b975cd3a4f878d7de702bd97d6a86d6e65b9ab4b..b0d0924987b6cad885c759657f6fbc171aa59c4f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,23 +2,23 @@ ignore = E126,E127,E128,E123,E226,E241,E242,E265,N802,W503,E402,N814,N817,W504 max-line-length=85 -[mypy-pytato.scalar_expr] +[mypy-pytato.transform] disallow_subclassing_any = False -[mypy-pymbolic] -ignore_missing_imports = True +[mypy-pytato.scalar_expr] +disallow_subclassing_any = False -[mypy-pymbolic.primitives] +[mypy-islpy] ignore_missing_imports = True -[mypy-pymbolic.mapper] +[mypy-loopy.*] ignore_missing_imports = True -[mypy-pymbolic.parser] +[mypy-numpy] ignore_missing_imports = True -[mypy-pytools] +[mypy-pymbolic.*] ignore_missing_imports = True -[mypy-numpy] +[mypy-pyopencl] ignore_missing_imports = True diff --git a/test/test_codegen.py b/test/test_codegen.py new file mode 100755 index 0000000000000000000000000000000000000000..0f26bc0c1881a2b5e2b45844004fa9ed545bfafd --- /dev/null +++ b/test/test_codegen.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python + +__copyright__ = "Copyright (C) 2020 Andreas Kloeckner" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import sys + +import loopy as lp +import numpy as np +import pyopencl as cl +import pyopencl.array as cl_array # noqa +import pyopencl.cltypes as cltypes # noqa +import pyopencl.tools as cl_tools # noqa +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) +import pytest # noqa + +import pytato as pt + + +def test_basic_codegen(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + namespace = pt.Namespace() + x = pt.Placeholder(namespace, "x", (5,), np.int) + prog = pt.generate_loopy(x * x, target=pt.PyOpenCLTarget(queue)) + x_in = np.array([1, 2, 3, 4, 5]) + _, (out,) = prog(x=x_in) + assert (out == x_in * x_in).all() + + +def test_codegen_with_DictOfNamedArrays(ctx_factory): # noqa + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + namespace = pt.Namespace() + x = pt.Placeholder(namespace, "x", (5,), np.int) + y = pt.Placeholder(namespace, "y", (5,), np.int) + x_in = np.array([1, 2, 3, 4, 5]) + y_in = np.array([6, 7, 8, 9, 10]) + + result = pt.DictOfNamedArrays(dict(x_out=x, y_out=y)) + + # Without return_dict. + prog = pt.generate_loopy(result, target=pt.PyOpenCLTarget(queue)) + _, (x_out, y_out) = prog(x=x_in, y=y_in) + assert (x_out == x_in).all() + assert (y_out == y_in).all() + + # With return_dict. + prog = pt.generate_loopy(result, + target=pt.PyOpenCLTarget(queue), + options=lp.Options(return_dict=True)) + + _, outputs = prog(x=x_in, y=y_in) + assert (outputs["x_out"] == x_in).all() + assert (outputs["y_out"] == y_in).all() + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: filetype=pyopencl:fdm=marker