diff --git a/doc/tutorial.rst b/doc/tutorial.rst index b0d9cebd4620ee6db4ea0e51c1f6297b99f02908..4aeb42428488046bac9d57e9ad57a7cf29c2c0d9 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1318,7 +1318,7 @@ The kernel translates into two OpenCL kernels. int tmp; <BLANKLINE> tmp = tmp_save_slot[16 * gid(0) + lid(0)]; - arr[(lid(0) + gid(0) * 16 + 1) % n] = tmp; + arr[(1 + lid(0) + gid(0) * 16) % n] = tmp; } Now we can execute the kernel. diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 884c26d2f5194d84066a3f6405e9dcf13ae751a0..84477749f5073b39a67acfb494b920849ef2b702 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -1,5 +1,7 @@ from __future__ import annotations +from loopy.symbolic import flatten + __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" @@ -1318,7 +1320,7 @@ def get_access_info(kernel: "LoopKernel", "make_temporaries_for_offsets_and_strides " "during preprocessing.") - subscripts[dim_tag.target_axis] += (stride // vector_size)*idx + subscripts[dim_tag.target_axis] += flatten((stride // vector_size)*idx) elif isinstance(dim_tag, SeparateArrayArrayDimTag): raise AssertionError() diff --git a/loopy/statistics.py b/loopy/statistics.py index 29ea91259d0834ed1f143c713b621a12ac6c0889..2d0537fdb58798d9c98dcf976e46a3c92e431108 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -36,7 +36,7 @@ import loopy as lp from loopy.diagnostic import LoopyError, warn_with_kernel from loopy.kernel.data import AddressSpace, MultiAssignmentBase, TemporaryVariable from loopy.kernel.function_interface import CallableKernel -from loopy.symbolic import CoefficientCollector +from loopy.symbolic import CoefficientCollector, flatten from loopy.translation_unit import TranslationUnit @@ -1167,7 +1167,7 @@ def _get_lid_and_gid_strides(knl, array, index): total_iname_stride += axis_tag_stride*coeff - tag_to_stride_dict[tag] = total_iname_stride + tag_to_stride_dict[tag] = flatten(total_iname_stride) return tag_to_stride_dict diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 6727423a829e607af38e3d758765d3756c6d47cf..86e854bd2c8a1a31196d7c31cf0e33efab0285b3 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -29,7 +29,17 @@ THE SOFTWARE. import re from functools import cached_property, reduce from sys import intern -from typing import TYPE_CHECKING, AbstractSet, Any, ClassVar, Mapping, Sequence, Tuple +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + ClassVar, + Mapping, + Sequence, + Tuple, + TypeVar, + cast, +) import immutables import numpy as np @@ -39,6 +49,7 @@ import pymbolic.primitives # FIXME: also import by full name to allow sphinx to import pymbolic.primitives as p import pytools.lex from islpy import dim_type +from pymbolic import ArithmeticExpressionT from pymbolic.mapper import ( CachedCombineMapper as CombineMapperBase, CachedIdentityMapper as IdentityMapperBase, @@ -200,8 +211,14 @@ class FlattenMapper(FlattenMapperBase, IdentityMapperMixin): pass -def flatten(expr: ExpressionT) -> ExpressionT: - return FlattenMapper()(expr) +ArithmeticOrExpressionT = TypeVar( + "ArithmeticOrExpressionT", + ArithmeticExpressionT, + ExpressionT) + + +def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT: + return cast(ArithmeticOrExpressionT, FlattenMapper()(expr)) class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): @@ -2127,7 +2144,8 @@ def simplify_using_aff(kernel, expr): try: aff = guarded_aff_from_expr(domain.space, expr) except ExpressionToAffineConversionError: - return expr + # Accomplish at least *some* simplification + return flatten(expr) # FIXME: Deal with assumptions, too. aff = aff.gist(domain)