From b007e88ee84ab7a90335fa35de60c13cf3bb83e5 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 18 Sep 2018 14:53:38 -0400 Subject: [PATCH 1/2] add basic vectorizability checking improvments & tests --- loopy/expression.py | 153 +++++++++++++++++++++++++++++++++++++------- test/test_loopy.py | 92 ++++++++++++++++++++++++++ 2 files changed, 223 insertions(+), 22 deletions(-) diff --git a/loopy/expression.py b/loopy/expression.py index 3269bc09f..06fe3bb06 100644 --- a/loopy/expression.py +++ b/loopy/expression.py @@ -63,6 +63,30 @@ class VectorizabilityChecker(RecursiveMapper): .. attribute:: vec_iname """ + # this is a simple list of math functions from OpenCL-1.2 + # https://www.khronos.org/registry/OpenCL/sdk/1.2/docs/man/xhtml/mathFunctions.html + # this could be expanded / moved to it's own target specific VecCheck if + # necessary + functions = """acos acosh acospi asin + asinh asinpi atan atan2 + atanh atanpi atan2pi cbrt + ceil copysign cos cosh + cospi erfc erf exp + exp2 exp10 expm1 fabs + fdim floor fma fmax + fmin fmod fract frexp + hypot ilogb ldexp lgamma + lgamma_r log log2 log10 + log1p logb mad maxmag + minmag modf nan nextafter + pow pown powr remainder + remquo rint rootn round + rsqrt sin sincos sinh + sinpi sqrt tan tanh + tanpi tgamma trunc""" + + functions = [x.strip() for x in functions.split() if x.strip()] + def __init__(self, kernel, vec_iname, vec_iname_length): self.kernel = kernel self.vec_iname = vec_iname @@ -75,7 +99,7 @@ class VectorizabilityChecker(RecursiveMapper): return reduce(and_, vectorizabilities) def map_sum(self, expr): - return any(self.rec(child) for child in expr.children) + return any([self.rec(child) for child in expr.children]) map_product = map_sum @@ -84,6 +108,16 @@ class VectorizabilityChecker(RecursiveMapper): or self.rec(expr.denominator)) + map_remainder = map_quotient + + def map_floor_div(self, expr): + """ + (a) - ( ((a)<0) ? ((b)-1) : 0 ) ) / (b) + """ + a, b = expr.numerator, expr.denominator + return self.rec(a) and self.rec(a.lt(0)) and self.rec(b - 1) and \ + self.rec((a - (b - 1)) / b) and self.rec(a / b) + def map_linear_subscript(self, expr): return False @@ -93,10 +127,54 @@ class VectorizabilityChecker(RecursiveMapper): rec_pars = [ self.rec(child) for child in expr.parameters] if any(rec_pars): - raise Unvectorizable("fucntion calls cannot yet be vectorized") + if str(expr.function) not in VectorizabilityChecker.functions: + return Unvectorizable( + 'Function {} is not known to be vectorizable'.format( + str(expr.function))) + return True return False + @staticmethod + def compile_time_constants(kernel, vec_iname): + """ + Returns a dictionary of (non-vector) inames and temporary variables whose + value is known at "compile" time. These are used (in combination with a + codegen state's variable substitution map) to simplifying access expressions + in :func:`get_access_info`. + + Note: inames are mapped to the :class:`Variable` version of themselves, + while temporary variables are mapped to their integer value + + .. parameter:: kernel + The kernel to check + .. parameter:: vec_iname + the vector iname + + """ + + # determine allowed symbols as non-vector inames + from pymbolic.primitives import Variable + allowed_symbols = dict((sym, Variable(sym)) for sym in kernel.all_inames() + if sym != vec_iname) + from loopy.kernel.instruction import Assignment + from loopy.tools import is_integer + from six import iteritems + + # and compile time integer temporaries + compile_time_assign = dict((str(insn.assignee), insn.expression) + for insn in kernel.instructions if + isinstance(insn, Assignment) and is_integer( + insn.expression)) + allowed_symbols.update( + dict((sym, compile_time_assign[sym]) for sym, var in iteritems( + kernel.temporary_variables) + # temporary variables w/ no initializer, no shape + if var.initializer is None and not var.shape + # compile time integers + and sym in compile_time_assign)) + return allowed_symbols + def map_subscript(self, expr): name = expr.aggregate.name @@ -114,29 +192,45 @@ class VectorizabilityChecker(RecursiveMapper): index = expr.index_tuple - from loopy.symbolic import get_dependencies + from loopy.symbolic import get_dependencies, DependencyMapper from loopy.kernel.array import VectorArrayDimTag - from pymbolic.primitives import Variable possible = None for i in range(len(var.shape)): - if ( - isinstance(var.dim_tags[i], VectorArrayDimTag) - and isinstance(index[i], Variable) - and index[i].name == self.vec_iname): + dep_mapper = DependencyMapper(composite_leaves=False) + deps = dep_mapper(index[i]) + # if we're on the vector index + if isinstance(var.dim_tags[i], VectorArrayDimTag): if var.shape[i] != self.vec_iname_length: raise Unvectorizable("vector length was mismatched") - if possible is None: - possible = True - - else: - if self.vec_iname in get_dependencies(index[i]): - raise Unvectorizable("vectorizing iname '%s' occurs in " - "unvectorized subscript axis %d (1-based) of " - "expression '%s'" - % (self.vec_iname, i+1, expr)) - break + possible = self.vec_iname in [str(x) for x in deps] + # or, if not vector index, and vector iname is present + elif self.vec_iname in set(x.name for x in deps): + # check whether we can simplify out the vector iname + context = dict((str(x), x) for x in deps if x.name != self.vec_iname) + allowed_symbols = self.compile_time_constants( + self.kernel, self.vec_iname) + + from pymbolic import substitute + from pymbolic.mapper.evaluator import UnknownVariableError + from loopy.tools import is_integer + for veci in range(self.vec_iname_length): + ncontext = context.copy() + ncontext[self.vec_iname] = veci + try: + idi = substitute(index[i], ncontext) + if not is_integer(idi) and not all( + x in allowed_symbols + for x in get_dependencies(idi)): + raise Unvectorizable( + "vectorizing iname '%s' occurs in " + "unvectorized subscript axis %d (1-based) of " + "expression '%s', and could not be simplified" + "to compile-time constants." + % (self.vec_iname, i+1, expr)) + except UnknownVariableError: + break return bool(possible) @@ -160,16 +254,31 @@ class VectorizabilityChecker(RecursiveMapper): return False def map_comparison(self, expr): - # FIXME: These actually can be vectorized: # https://www.khronos.org/registry/cl/sdk/1.0/docs/man/xhtml/relationalFunctions.html + # even better for OpenCL <, <=, >, >=, !=, == are all vectorizable by default + # (see: sec 6.3.d-6.d.3 in OpenCL-1.2 docs) + + if expr.operator in ["<", "<=", ">", ">=", "!=", "=="]: + return any([self.rec(x) for x in [expr.left, expr.right]]) + raise Unvectorizable() def map_logical_not(self, expr): - raise Unvectorizable() + # 6.3.h in OpenCL-1.2 docs + return self.rec(expr.child) + + def map_logical_and(self, expr): + # 6.3.h in OpenCL-1.2 docs + return any(self.rec(x) for x in expr.children) + + map_logical_or = map_logical_and - map_logical_and = map_logical_not - map_logical_or = map_logical_not + # sec 6.3.f in OpenCL-1.2 docs + map_bitwise_not = map_logical_not + map_bitwise_or = map_logical_and + map_bitwise_xor = map_logical_and + map_bitwise_and = map_logical_and def map_reduction(self, expr): # FIXME: Do this more carefully diff --git a/test/test_loopy.py b/test/test_loopy.py index accf9c1df..179e936f9 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2800,6 +2800,98 @@ def test_add_prefetch_works_in_lhs_index(): assert "a1_map" not in get_dependencies(insn.assignees) +def test_vectorizability(): + # check new vectorizability conditions + from loopy.kernel.array import VectorArrayDimTag + from loopy.kernel.data import VectorizeTag, filter_iname_tags_by_type + + def create_and_test(insn, exception=None, a=None, b=None): + a = np.zeros((3, 4), dtype=np.int32) if a is None else a + data = [lp.GlobalArg('a', shape=(12,), dtype=a.dtype)] + kwargs = dict(a=a) + if b is not None: + data += [lp.GlobalArg('b', shape=(12,), dtype=b.dtype)] + kwargs['b'] = b + names = [d.name for d in data] + + knl = lp.make_kernel(['{[i]: 0 <= i < 12}'], + """ + for i + %(insn)s + end + """ % dict(insn=insn), + data + ) + + knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') + knl = lp.split_array_axis(knl, names, 0, 4) + knl = lp.tag_array_axes(knl, names, 'N0,vec') + knl = lp.preprocess_kernel(knl) + lp.generate_code_v2(knl).device_code() + assert knl.instructions[0].within_inames & set(['i_inner']) + assert isinstance(knl.args[0].dim_tags[-1], VectorArrayDimTag) + assert isinstance(knl.args[0].dim_tags[-1], VectorArrayDimTag) + assert filter_iname_tags_by_type(knl.iname_to_tags['i_inner'], VectorizeTag) + + def run(op_list=[], unary_operators=[], func_list=[], unary_funcs=[], + rvals=['1', 'a[i]']): + for op in op_list: + template = 'a[i] = a[i] %(op)s %(rval)s' \ + if op not in unary_operators else 'a[i] = %(op)s a[i]' + for rval in rvals: + create_and_test(template % dict(op=op, rval=rval)) + for func in func_list: + template = 'a[i] = %(func)s(a[i], %(rval)s)' \ + if func not in unary_funcs else 'a[i] = %(func)s(a[i])' + for rval in rvals: + create_and_test(template % dict(func=func, rval=rval)) + + # 1) comparisons + run(['>', '>=', '<', '<=', '==', '!=']) + + # 2) logical operators + run(['and', 'or', 'not'], ['not']) + + # 3) bitwise operators + # bitwise xor '^' not not implemented in codegen + run(['~', '|', '&'], ['~']) + + # 4) functions -- a random selection of the enabled math functions in opencl + run(func_list=['acos', 'exp10', 'atan2', 'round'], + unary_funcs=['round', 'acos', 'exp10']) + + # 5) remainders and floor division (use 4 instead of 1 to avoid pymbolic + # optimizing out the a[i] % 1) + run(['%', '//'], rvals=['a[i]', '4']) + + # 6) check vectorizability of subscripts w/ compile-time constants directly + # make a kernel + knl = lp.make_kernel(['{[i,j]: 0 <= i,j < 12}'], + """ + <> c = 4 + a[j, i + c] = 1 + """, + [lp.GlobalArg('a', shape=(12, 16), dtype=np.int32)] + ) + + knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') + knl = lp.split_array_axis(knl, 'a', 1, 4) + knl = lp.tag_array_axes(knl, 'a', 'N1,N0,vec') + knl = lp.preprocess_kernel(knl) + + # get checker + from loopy.expression import VectorizabilityChecker + from loopy.codegen import Unvectorizable + # test CTC's + assert (set(VectorizabilityChecker.compile_time_constants(knl, 'i_inner').keys()) + == set(['j', 'c', 'i_outer'])) + # test that the VC doesn't throw an Unvectorizable + try: + VectorizabilityChecker(knl, 'i', 4)(knl.instructions[0].assignee) + except Unvectorizable: + assert False + + def test_check_for_variable_access_ordering(): knl = lp.make_kernel( "{[i]: 0<=i Date: Tue, 18 Sep 2018 15:15:05 -0400 Subject: [PATCH 2/2] Implement basic use of compile-time-constants in get_access_info & test Note: this functionality will be somewhat limited until the shuffle / load logic in a forthcoming MR gets added --- loopy/kernel/array.py | 52 ++++++++++++++++++++-------- loopy/target/c/codegen/expression.py | 29 +++++++++++----- loopy/target/ispc.py | 24 +++++++++---- test/test_loopy.py | 50 +++++++++++++++++--------- 4 files changed, 110 insertions(+), 45 deletions(-) diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 6bf733a84..867a315a4 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -1212,33 +1212,58 @@ class AccessInfo(ImmutableRecord): """ -def get_access_info(target, ary, index, eval_expr, vectorization_info): +def get_access_info(target, ary, index, var_subst_map, vectorization_info): """ :arg ary: an object of type :class:`ArrayBase` :arg index: a tuple of indices representing a subscript into ary + :arg var_subst_map: a context of variable substitutions from the calling codegen + state and potentially other compile-time "constants" (inames and + integer temporaries w/ known values), used in detection of loads / shuffles :arg vectorization_info: an instance of :class:`loopy.codegen.VectorizationInfo`, or *None*. """ import loopy as lp from pymbolic import var + from loopy.codegen import Unvectorizable + from loopy.symbolic import get_dependencies - def eval_expr_assert_integer_constant(i, expr): + def eval_expr_assert_constant(i, expr, kwargs): from pymbolic.mapper.evaluator import UnknownVariableError + # determine error type -- if vectorization_info is None, we're in the + # unvec fallback (and should raise a LoopyError) + # if vectorization_info is 'True', we should raise an Unvectorizable + # on failure + error_type = LoopyError if vectorization_info is None else Unvectorizable + from pymbolic import evaluate try: - result = eval_expr(expr) + result = evaluate(expr, kwargs) except UnknownVariableError as e: - raise LoopyError("When trying to index the array '%s' along axis " + if vectorization_info: + # failed vectorization + raise Unvectorizable( + "When trying to vectorize the array '%s' along axis " "%d (tagged '%s'), the index was not a compile-time " "constant (but it has to be in order for code to be " - "generated). You likely want to unroll the iname(s) '%s'." + "generated). You likely want to unroll the iname(s) '%s'" % (ary.name, i, ary.dim_tags[i], str(e))) + else: + raise LoopyError( + "When trying to unroll the array '%s' along axis " + "%d (tagged '%s'), the index was not an unrollable-iname " + "or constant (but it has to be in order for code to be " + "generated). You likely want to unroll/change array index(s)" + " '%s'" % (ary.name, i, ary.dim_tags[i], str(e))) if not is_integer(result): - raise LoopyError("subscript '%s[%s]' has non-constant " + # try to simplify further + from loopy.isl_helpers import simplify_via_aff + result = simplify_via_aff(result) + + if any([x not in var_subst_map for x in get_dependencies(result)]): + raise error_type("subscript '%s[%s]' has non-constant " "index for separate-array axis %d (0-based)" % ( ary.name, index, i)) - return result def apply_offset(sub): @@ -1289,7 +1314,7 @@ def get_access_info(target, ary, index, eval_expr, vectorization_info): for i, (idx, dim_tag) in enumerate(zip(index, ary.dim_tags)): if isinstance(dim_tag, SeparateArrayArrayDimTag): - idx = eval_expr_assert_integer_constant(i, idx) + idx = eval_expr_assert_constant(i, idx, var_subst_map) array_name += "_s%d" % idx # }}} @@ -1317,18 +1342,17 @@ def get_access_info(target, ary, index, eval_expr, vectorization_info): elif isinstance(dim_tag, VectorArrayDimTag): from pymbolic.primitives import Variable - if (vectorization_info is not None - and isinstance(index[i], Variable) + if (vectorization_info and isinstance(index[i], Variable) and index[i].name == vectorization_info.iname): # We'll do absolutely nothing here, which will result # in the vector being returned. pass else: - idx = eval_expr_assert_integer_constant(i, idx) - - assert vector_index is None - vector_index = idx + if vector_index is None: + # if we haven't processed the vector index yet + idx = eval_expr_assert_constant(i, idx, var_subst_map) + vector_index = idx else: raise LoopyError("unsupported array dim implementation tag '%s' " diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index dd2104d0c..10f48d3b1 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -182,16 +182,22 @@ class ExpressionToCExpressionMapper(IdentityMapper): ary = self.find_array(expr) - from loopy.kernel.array import get_access_info - from pymbolic import evaluate - from loopy.symbolic import simplify_using_aff index_tuple = tuple( simplify_using_aff(self.kernel, idx) for idx in expr.index_tuple) + from loopy.kernel.array import get_access_info + from loopy.expression import VectorizabilityChecker + var_subst_map = self.codegen_state.var_subst_map.copy() + if self.codegen_state.vectorization_info: + ctc_iname = self.codegen_state.vectorization_info.iname + ctc = VectorizabilityChecker.compile_time_constants( + self.codegen_state.kernel, + ctc_iname) + var_subst_map.update(ctc) + access_info = get_access_info(self.kernel.target, ary, index_tuple, - lambda expr: evaluate(expr, self.codegen_state.var_subst_map), - self.codegen_state.vectorization_info) + var_subst_map, self.codegen_state.vectorization_info) from loopy.kernel.data import ( ImageArg, ArrayArg, TemporaryVariable, ConstantArg) @@ -400,10 +406,17 @@ class ExpressionToCExpressionMapper(IdentityMapper): ary = self.find_array(arg) from loopy.kernel.array import get_access_info - from pymbolic import evaluate + from loopy.expression import VectorizabilityChecker + var_subst_map = self.codegen_state.var_subst_map.copy() + if self.codegen_state.vectorization_info: + ctc_iname = self.codegen_state.vectorization_info.iname + ctc = VectorizabilityChecker.compile_time_constants( + self.codegen_state.kernel, + ctc_iname) + var_subst_map.update(ctc) + access_info = get_access_info(self.kernel.target, ary, arg.index, - lambda expr: evaluate(expr, self.codegen_state.var_subst_map), - self.codegen_state.vectorization_info) + var_subst_map, self.codegen_state.vectorization_info) from loopy.kernel.data import ImageArg if isinstance(ary, ImageArg): diff --git a/loopy/target/ispc.py b/loopy/target/ispc.py index 771f2cdf6..0422f4816 100644 --- a/loopy/target/ispc.py +++ b/loopy/target/ispc.py @@ -109,11 +109,17 @@ class ExprToISPCExprMapper(ExpressionToCExpressionMapper): if lsize: lsize, = lsize from loopy.kernel.array import get_access_info - from pymbolic import evaluate + + var_subst_map = self.codegen_state.var_subst_map.copy() + if self.codegen_state.vectorization_info: + from loopy.expression import VectorizabilityChecker + ctc = VectorizabilityChecker.compile_time_constants( + self.codegen_state.kernel, + self.codegen_state.vectorization_info.iname) + var_subst_map.update(ctc) access_info = get_access_info(self.kernel.target, ary, expr.index, - lambda expr: evaluate(expr, self.codegen_state.var_subst_map), - self.codegen_state.vectorization_info) + var_subst_map, self.codegen_state.vectorization_info) subscript, = access_info.subscripts result = var(access_info.array_name)[ @@ -397,15 +403,21 @@ class ISPCASTBuilder(CASTBuilder): ary = ecm.find_array(lhs) from loopy.kernel.array import get_access_info - from pymbolic import evaluate from loopy.symbolic import simplify_using_aff index_tuple = tuple( simplify_using_aff(kernel, idx) for idx in lhs.index_tuple) + var_subst_map = codegen_state.var_subst_map.copy() + if codegen_state.vectorization_info: + from loopy.expression import VectorizabilityChecker + ctc = VectorizabilityChecker.compile_time_constants( + codegen_state.kernel, + codegen_state.vectorization_info.iname) + var_subst_map.update(ctc) + access_info = get_access_info(kernel.target, ary, index_tuple, - lambda expr: evaluate(expr, self.codegen_state.var_subst_map), - codegen_state.vectorization_info) + var_subst_map, codegen_state.vectorization_info) from loopy.kernel.data import ArrayArg, TemporaryVariable diff --git a/test/test_loopy.py b/test/test_loopy.py index 179e936f9..f4546086e 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2865,31 +2865,47 @@ def test_vectorizability(): run(['%', '//'], rvals=['a[i]', '4']) # 6) check vectorizability of subscripts w/ compile-time constants directly - # make a kernel - knl = lp.make_kernel(['{[i,j]: 0 <= i,j < 12}'], - """ - <> c = 4 - a[j, i + c] = 1 - """, - [lp.GlobalArg('a', shape=(12, 16), dtype=np.int32)] - ) + def _get_offset_kernel(as_temporary=True): + data = [lp.GlobalArg('a', shape=(12, 16), dtype=np.int32)] + if as_temporary: + pre = '<> c = 4' + else: + pre = '' + data.append(lp.ValueArg('c', dtype=np.int32)) - knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') - knl = lp.split_array_axis(knl, 'a', 1, 4) - knl = lp.tag_array_axes(knl, 'a', 'N1,N0,vec') - knl = lp.preprocess_kernel(knl) + # make a kernel + knl = lp.make_kernel(['{[i,j]: 0 <= i,j < 12}'], + """ + {pre} + a[j, i + c] = 1 + """.format(pre=pre), data) + + knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') + knl = lp.split_array_axis(knl, 'a', 1, 4) + knl = lp.tag_array_axes(knl, 'a', 'N1,N0,vec') + knl = lp.preprocess_kernel(knl) + return knl # get checker from loopy.expression import VectorizabilityChecker - from loopy.codegen import Unvectorizable + from loopy.diagnostic import LoopyError # test CTC's + knl = _get_offset_kernel() assert (set(VectorizabilityChecker.compile_time_constants(knl, 'i_inner').keys()) == set(['j', 'c', 'i_outer'])) # test that the VC doesn't throw an Unvectorizable - try: - VectorizabilityChecker(knl, 'i', 4)(knl.instructions[0].assignee) - except Unvectorizable: - assert False + VectorizabilityChecker(knl, 'i', 4)(knl.instructions[0].assignee) + + # and finally test that we can generate code + with pytest.raises(LoopyError): + # This test is broken in this MR as the shuffle / load logic in + # `get_access_info` is in a forthcoming MR + print(lp.generate_code_v2(knl).device_code()) + + # fix the parameter and allow vectorization + knl = _get_offset_kernel(False) + knl = lp.fix_parameters(knl, c=4) + print(lp.generate_code_v2(knl).device_code()) def test_check_for_variable_access_ordering(): -- GitLab