From f0dd15a1e2e82fe4b92ce8c554e72dfd665e0240 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner
Date: Wed, 12 Feb 2025 10:23:20 -0600
Subject: [PATCH] Fix ispc streaming store generation

---
 loopy/target/ispc.py | 151 +++++++++++++++++++++++--------------------
 1 file changed, 80 insertions(+), 71 deletions(-)

diff --git a/loopy/target/ispc.py b/loopy/target/ispc.py
index 096cb2cd..34a88328 100644
--- a/loopy/target/ispc.py
+++ b/loopy/target/ispc.py
@@ -29,16 +29,26 @@ from functools import reduce
 from typing import TYPE_CHECKING, Iterable, Sequence, cast
 
 import numpy as np
+from typing_extensions import Never
 
 import pymbolic.primitives as p
 from cgen import Collection, Const, Declarator, Generable
 from pymbolic import var
 from pymbolic.mapper.stringifier import PREC_NONE
+from pymbolic.mapper.substitutor import make_subst_func
 from pytools import memoize_method
 
 from loopy.diagnostic import LoopyError
-from loopy.kernel.data import AddressSpace, ArrayArg, TemporaryVariable
-from loopy.symbolic import CombineMapper, Literal
+from loopy.kernel.data import AddressSpace, ArrayArg, LocalInameTag, TemporaryVariable
+from loopy.symbolic import (
+    CoefficientCollector,
+    CombineMapper,
+    GroupHardwareAxisIndex,
+    Literal,
+    LocalHardwareAxisIndex,
+    SubstitutionMapper,
+    flatten,
+)
 from loopy.target.c import CFamilyASTBuilder, CFamilyTarget
 from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
 
@@ -46,28 +56,45 @@ from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
 if TYPE_CHECKING:
     from loopy.codegen import CodeGenerationState
     from loopy.codegen.result import CodeGenerationResult
+    from loopy.kernel import LoopKernel
+    from loopy.kernel.instruction import Assignment
     from loopy.schedule import CallKernel
     from loopy.types import LoopyType
     from loopy.typing import Expression
 
 
 class IsVaryingMapper(CombineMapper[bool, []]):
+    # FIXME: Update this if/when ispc reduction support is added.
+
+    def __init__(self, kernel: LoopKernel) -> None:
+        self.kernel = kernel
+        super().__init__()
+
     def combine(self, values: Iterable[bool]) -> bool:
         return reduce(operator.or_, values, False)
 
     def map_constant(self, expr):
         return False
 
-    def map_group_hw_index(self, expr):
-        return False
-
-    def map_local_hw_index(self, expr):
-        if expr.axis == 0:
-            return True
-        else:
-            raise LoopyError("ISPC only supports one local axis")
+    def map_group_hw_index(self, expr: GroupHardwareAxisIndex) -> Never:
+        # These only exist for a brief blip in time inside the expr-to-cexpr
+        # mapper. We should never see them.
+        raise AssertionError("group hardware axis index outside expr-to-cexpr")
+
+    def map_local_hw_index(self, expr: LocalHardwareAxisIndex) -> Never:
+        # These only exist for a brief blip in time inside the expr-to-cexpr
+        # mapper. We should never see them.
+        raise AssertionError("local hardware axis index outside expr-to-cexpr")
+
+    def map_variable(self, expr: p.Variable) -> bool:
+        iname = self.kernel.inames.get(expr.name)
+        if iname is not None:
+            ltags = iname.tags_of_type(LocalInameTag)
+            if ltags:
+                ltag, = ltags
+                assert ltag.axis == 0, "ISPC only supports one local axis"
+                return True
 
-    def map_variable(self, expr):
         return False
 
 
@@ -127,8 +154,7 @@ class ExprToISPCExprMapper(ExpressionToCExpressionMapper):
                 return expr
 
         else:
-            return super().map_variable(
-                    expr, type_context)
+            return super().map_variable(expr, type_context)
 
     def map_subscript(self, expr, type_context):
         from loopy.kernel.data import TemporaryVariable
@@ -175,8 +201,8 @@ class ExprToISPCExprMapper(ExpressionToCExpressionMapper):
         else:
             actual_type = self.infer_type(expr)
             if actual_type != needed_type:
-                # FIXME: problematic: quadratic complexity
-                is_varying = IsVaryingMapper()(expr)
+                # FIXME: problematic: potential quadratic complexity
+                is_varying = IsVaryingMapper(self.kernel)(expr)
                 registry = self.codegen_state.ast_builder.target.get_dtype_registry()
                 cast = var("("
                         f"{'varying' if is_varying else 'uniform'} "
@@ -409,7 +435,12 @@ class ISPCASTBuilder(CFamilyASTBuilder):
     # }}}
 
     # {{{ emit_...
-    def emit_assignment(self, codegen_state, insn):
+
+    def emit_assignment(
+                self,
+                codegen_state: CodeGenerationState,
+                insn: Assignment
+            ):
         kernel = codegen_state.kernel
         ecm = codegen_state.expression_to_code_mapper
 
@@ -442,83 +473,61 @@ class ISPCASTBuilder(CFamilyASTBuilder):
             from loopy.kernel.array import get_access_info
             from loopy.symbolic import simplify_using_aff
 
-            index_tuple = tuple(
-                    simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)
-            access_info = get_access_info(kernel, ary, index_tuple,
-                    lambda expr: evaluate(expr, codegen_state.var_subst_map),
-                    codegen_state.vectorization_info)
+            if not isinstance(lhs, p.Subscript):
+                raise LoopyError("streaming store must have a subscript as argument")
 
             from loopy.kernel.data import ArrayArg, TemporaryVariable
-
             if not isinstance(ary, (ArrayArg, TemporaryVariable)):
                 raise LoopyError("array type not supported in ISPC: %s"
-                        % type(ary).__name)
+                        % type(ary).__name__)
 
+            index_tuple = tuple(
+                    simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)
+
+            access_info = get_access_info(kernel, ary, index_tuple,
+                    lambda expr: cast("int",
+                        evaluate(expr, codegen_state.var_subst_map)),
+                    codegen_state.vectorization_info)
+
+            l0_inames = {
+                iname for iname in insn.within_inames
+                if kernel.inames[iname].tags_of_type(LocalInameTag)}
+
             if len(access_info.subscripts) != 1:
                 raise LoopyError("streaming stores must have a subscript")
             subscript, = access_info.subscripts
 
-            from pymbolic.primitives import Sum, Variable, flattened_sum
-            if isinstance(subscript, Sum):
-                terms = subscript.children
-            else:
-                terms = (subscript.children,)
-
-            new_terms = []
-
-            from loopy.kernel.data import LocalInameTag, filter_iname_tags_by_type
-            from loopy.symbolic import get_dependencies
-
-            saw_l0 = False
-            for term in terms:
-                if (isinstance(term, Variable)
-                        and kernel.iname_tags_of_type(term.name, LocalInameTag)):
-                    tag, = kernel.iname_tags_of_type(
-                        term.name, LocalInameTag, min_num=1, max_num=1)
-                    if tag.axis == 0:
-                        if saw_l0:
-                            raise LoopyError(
-                                "streaming store must have stride 1 in "
-                                "local index, got: %s" % subscript)
-                        saw_l0 = True
-                        continue
-                else:
-                    for dep in get_dependencies(term):
-                        if dep in kernel.all_inames() and (
-                                filter_iname_tags_by_type(kernel.inames[dep].tags,
-                                                          LocalInameTag)):
-                            tag, = filter_iname_tags_by_type(
-                                kernel.inames[dep].tags, LocalInameTag, 1)
-                            if tag.axis == 0:
-                                raise LoopyError(
-                                    "streaming store must have stride 1 in "
-                                    "local index, got: %s" % subscript)
-
-                new_terms.append(term)
-
-            if not saw_l0:
-                raise LoopyError("streaming store must have stride 1 in "
-                        "local index, got: %s" % subscript)
+            if l0_inames:
+                l0_iname, = l0_inames
+                coeffs = CoefficientCollector([l0_iname])(subscript)
+                if coeffs.get(p.Variable(l0_iname), 0) != 1:
+                    raise LoopyError("coefficient of streaming store index "
+                                     "in l.0 variable must be 1")
+
+                subscript = flatten(
+                    SubstitutionMapper(make_subst_func({l0_iname: 0}))(subscript))
+                del l0_iname
 
             if access_info.vector_index is not None:
                 raise LoopyError("streaming store may not use a short-vector "
                         "data type")
 
-            rhs_has_programindex = any(
-                isinstance(tag, LocalInameTag) and tag.axis == 0
-                for tag in kernel.iname_tags(dep)
-                for dep in get_dependencies(insn.expression))
-
-            if not rhs_has_programindex:
-                rhs_code = "broadcast(%s, 0)" % rhs_code
+            if (l0_inames
+                    and not IsVaryingMapper(codegen_state.kernel)(insn.expression)):
+                # rhs is uniform, must be cast to varying in order for streaming_store
+                # to perform a vector store.
+                registry = codegen_state.ast_builder.target.get_dtype_registry()
+                rhs_code = var("(varying "
+                        f"{registry.dtype_to_ctype(lhs_dtype)}"
+                        f") ({rhs_code})")
 
             from cgen import Statement
 
             return Statement(
                     "streaming_store(%s + %s, %s)"
                     % (
                         access_info.array_name,
-                        ecm(flattened_sum(new_terms), PREC_NONE, "i"),
+                        ecm(subscript, PREC_NONE, "i"),
                         rhs_code))
 
         # }}}
-- 
GitLab