From 581d15cb2abcf161ddd882e77bcb15c19bb302c1 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 00:06:04 -0500 Subject: [PATCH 01/26] picks callables and fortran related diff --- doc/tutorial.rst | 4 +- .../fortran/ipython-integration-demo.ipynb | 17 +--- examples/fortran/matmul.floopy | 4 +- examples/fortran/sparse.floopy | 4 +- examples/fortran/tagging.floopy | 4 +- examples/fortran/volumeKernel.floopy | 4 +- loopy/__init__.py | 14 +-- loopy/frontend/fortran/__init__.py | 53 ++++++++++- loopy/ipython_ext.py | 2 +- loopy/kernel/creation.py | 94 +++++++++---------- loopy/kernel/instruction.py | 4 +- loopy/symbolic.py | 12 +-- loopy/transform/callable.py | 32 +++++-- loopy/transform/fusion.py | 5 + test/test_callables.py | 71 ++++++-------- test/test_fortran.py | 8 +- test/test_numa_diff.py | 20 ++-- 17 files changed, 198 insertions(+), 154 deletions(-) diff --git a/doc/tutorial.rst b/doc/tutorial.rst index befa5e30b..e6ef54b66 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1157,7 +1157,7 @@ this, :mod:`loopy` will complain that global barrier needs to be inserted: >>> cgr = lp.generate_code_v2(knl) Traceback (most recent call last): ... - loopy.diagnostic.MissingBarrierError: Dependency 'rotate depends on maketmp' (for variable 'arr') requires synchronization by a global barrier (add a 'no_sync_with' instruction option to state that no synchronization is needed) + loopy.diagnostic.MissingBarrierError: rotate_v1: Dependency 'rotate depends on maketmp' (for variable 'arr') requires synchronization by a global barrier (add a 'no_sync_with' instruction option to state that no synchronization is needed) The syntax for a inserting a global barrier instruction is ``... gbarrier``. :mod:`loopy` also supports manually inserting local @@ -1554,7 +1554,7 @@ information provided. Now we will count the operations: >>> op_map = lp.get_op_map(knl, subgroup_size=32) >>> print(lp.stringify_stats_mapping(op_map)) - Op(np:dtype('float32'), add, subgroup) : ... + Op(np:dtype('float32'), add, subgroup, loopy_kernel) : ... Each line of output will look roughly like:: diff --git a/examples/fortran/ipython-integration-demo.ipynb b/examples/fortran/ipython-integration-demo.ipynb index 7a5c8257b..1b0a9df8d 100644 --- a/examples/fortran/ipython-integration-demo.ipynb +++ b/examples/fortran/ipython-integration-demo.ipynb @@ -62,9 +62,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "split_amount = 128" @@ -91,7 +89,7 @@ "\n", "!$loopy begin\n", "!\n", - "! tr_fill, = lp.parse_fortran(SOURCE)\n", + "! tr_fill = lp.parse_fortran(SOURCE)\n", "! tr_fill = lp.split_iname(tr_fill, \"i\", split_amount,\n", "! outer_tag=\"g.0\", inner_tag=\"l.0\")\n", "! RESULT = [tr_fill]\n", @@ -107,15 +105,6 @@ "source": [ "print(tr_fill)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [] } ], "metadata": { @@ -134,7 +123,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.4" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/examples/fortran/matmul.floopy b/examples/fortran/matmul.floopy index 4b3552204..a8377bedd 100644 --- a/examples/fortran/matmul.floopy +++ b/examples/fortran/matmul.floopy @@ -13,7 +13,7 @@ subroutine dgemm(m,n,l,alpha,a,b,c) end subroutine !$loopy begin -! dgemm, = lp.parse_fortran(SOURCE, FILENAME) +! dgemm = lp.parse_fortran(SOURCE, FILENAME) ! dgemm = lp.split_iname(dgemm, "i", 16, ! outer_tag="g.0", inner_tag="l.1") ! dgemm = lp.split_iname(dgemm, "j", 8, @@ -24,5 +24,5 @@ end subroutine ! dgemm = lp.extract_subst(dgemm, "b_acc", "b[i1,i2]", parameters="i1, i2") ! dgemm = lp.precompute(dgemm, "a_acc", "k_inner,i_inner", default_tag="l.auto") ! dgemm = lp.precompute(dgemm, "b_acc", "j_inner,k_inner", default_tag="l.auto") -! RESULT = [dgemm] +! RESULT = dgemm !$loopy end diff --git a/examples/fortran/sparse.floopy b/examples/fortran/sparse.floopy index 18542e6b0..2b156bdd7 100644 --- a/examples/fortran/sparse.floopy +++ b/examples/fortran/sparse.floopy @@ -23,11 +23,11 @@ subroutine sparse(rowstarts, colindices, values, m, n, nvals, x, y) end !$loopy begin -! sparse, = lp.parse_fortran(SOURCE, FILENAME) +! sparse = lp.parse_fortran(SOURCE, FILENAME) ! sparse = lp.split_iname(sparse, "i", 128) ! sparse = lp.tag_inames(sparse, {"i_outer": "g.0"}) ! sparse = lp.tag_inames(sparse, {"i_inner": "l.0"}) ! sparse = lp.split_iname(sparse, "j", 4) ! sparse = lp.tag_inames(sparse, {"j_inner": "unr"}) -! RESULT = [sparse] +! RESULT = sparse !$loopy end diff --git a/examples/fortran/tagging.floopy b/examples/fortran/tagging.floopy index 87aacba68..c7ebb7566 100644 --- a/examples/fortran/tagging.floopy +++ b/examples/fortran/tagging.floopy @@ -23,13 +23,13 @@ end ! "factor 4.0", ! "real_type real*8", ! ]) -! fill, = lp.parse_fortran(SOURCE, FILENAME) +! fill = lp.parse_fortran(SOURCE, FILENAME) ! fill = lp.add_barrier(fill, "tag:init", "tag:mult", "gb1") ! fill = lp.split_iname(fill, "i", 128, ! outer_tag="g.0", inner_tag="l.0") ! fill = lp.split_iname(fill, "i_1", 128, ! outer_tag="g.0", inner_tag="l.0") -! RESULT = [fill] +! RESULT = fill ! !$loopy end diff --git a/examples/fortran/volumeKernel.floopy b/examples/fortran/volumeKernel.floopy index c5784b634..211c38049 100644 --- a/examples/fortran/volumeKernel.floopy +++ b/examples/fortran/volumeKernel.floopy @@ -67,7 +67,7 @@ end subroutine volumeKernel !$loopy begin ! -! volumeKernel, = lp.parse_fortran(SOURCE, FILENAME) +! volumeKernel = lp.parse_fortran(SOURCE, FILENAME) ! volumeKernel = lp.split_iname(volumeKernel, ! "e", 32, outer_tag="g.1", inner_tag="g.0") ! volumeKernel = lp.fix_parameters(volumeKernel, @@ -76,6 +76,6 @@ end subroutine volumeKernel ! i="l.0", j="l.1", k="l.2", ! i_1="l.0", j_1="l.1", k_1="l.2" ! )) -! RESULT = [volumeKernel] +! RESULT = volumeKernel ! !$loopy end diff --git a/loopy/__init__.py b/loopy/__init__.py index 1439cb1ff..058bc93ef 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -130,10 +130,10 @@ from loopy.type_inference import infer_unknown_types from loopy.preprocess import (preprocess_kernel, realize_reduction, preprocess_program) from loopy.schedule import generate_loop_schedules, get_one_scheduled_kernel -from loopy.statistics import (ToCountMap, CountGranularity, - Op, MemAccess, get_op_map, get_mem_access_map, - get_synchronization_map, - gather_access_footprints, gather_access_footprint_bytes) +from loopy.statistics import (ToCountMap, ToCountPolynomialMap, + CountGranularity, stringify_stats_mapping, Op, MemAccess, get_op_map, + get_mem_access_map, get_synchronization_map, + gather_access_footprints, gather_access_footprint_bytes, Sync) from loopy.codegen import ( PreambleInfo, generate_code, generate_code_v2, generate_body) @@ -269,9 +269,11 @@ __all__ = [ "PreambleInfo", "generate_code", "generate_code_v2", "generate_body", - "ToCountMap", "CountGranularity", "Op", - "MemAccess", "get_op_map", "get_mem_access_map", "get_synchronization_map", + "ToCountMap", "ToCountPolynomialMap", "CountGranularity", + "stringify_stats_mapping", "Op", "MemAccess", "get_op_map", + "get_mem_access_map", "get_synchronization_map", "gather_access_footprints", "gather_access_footprint_bytes", + "Sync", "CompiledKernel", diff --git a/loopy/frontend/fortran/__init__.py b/loopy/frontend/fortran/__init__.py index 3516ca29a..74c1ebf54 100644 --- a/loopy/frontend/fortran/__init__.py +++ b/loopy/frontend/fortran/__init__.py @@ -241,10 +241,54 @@ def parse_transformed_fortran(source, free_form=True, strict=True, return proc_dict["RESULT"] +def _add_assignees_to_calls(knl, all_kernels): + new_insns = [] + subroutine_dict = dict((kernel.name, kernel) for kernel in all_kernels) + from loopy.kernel.instruction import (Assignment, CallInstruction, + CInstruction, _DataObliviousInstruction, + modify_assignee_for_array_call) + from pymbolic.primitives import Call, Variable + + for insn in knl.instructions: + if isinstance(insn, CallInstruction): + if isinstance(insn.expression, Call) and ( + insn.expression.function.name in subroutine_dict): + assignees = [] + new_params = [] + subroutine = subroutine_dict[insn.expression.function.name] + for par, arg in zip(insn.expression.parameters, subroutine.args): + if arg.name in subroutine.get_written_variables(): + par = modify_assignee_for_array_call(par) + assignees.append(par) + if arg.name in subroutine.get_read_variables(): + new_params.append(par) + if arg.name not in (subroutine.get_written_variables() | + subroutine.get_read_variables()): + new_params.append(par) + + new_insns.append( + insn.copy( + assignees=tuple(assignees), + expression=Variable( + insn.expression.function.name)(*new_params))) + else: + new_insns.append(insn) + pass + elif isinstance(insn, (Assignment, CInstruction, + _DataObliviousInstruction)): + new_insns.append(insn) + else: + raise NotImplementedError(type(insn).__name__) + + return knl.copy(instructions=new_insns) + + def parse_fortran(source, filename="", free_form=None, strict=None, - seq_dependencies=None, auto_dependencies=None, target=None): + seq_dependencies=None, auto_dependencies=None, target=None, + return_list_of_knls=False): """ - :returns: a :class:`loopy.Program` + :returns: an instance of :class:`list` of :class:`loopy.LoopKernel`s if + *return_list_of_knls* is True else a :class:`loopy.Program`. """ parse_plog = ProcessLogger(logger, "parsing fortran file '%s'" % filename) @@ -286,6 +330,11 @@ def parse_fortran(source, filename="", free_form=None, strict=None, kernels = f2loopy.make_kernels(seq_dependencies=seq_dependencies) + if return_list_of_knls: + return kernels + + kernels = [_add_assignees_to_calls(knl, kernels) for knl in kernels] + from loopy.kernel.tools import identify_root_kernel from loopy.program import make_program from loopy.transform.callable import register_callable_kernel diff --git a/loopy/ipython_ext.py b/loopy/ipython_ext.py index ec1b10f1f..e44b183ed 100644 --- a/loopy/ipython_ext.py +++ b/loopy/ipython_ext.py @@ -9,7 +9,7 @@ import loopy as lp class LoopyMagics(Magics): @cell_magic def fortran_kernel(self, line, cell): - result = lp.parse_fortran(cell) + result = lp.parse_fortran(cell, return_list_of_knls=True) for knl in result: self.shell.user_ns[knl.name] = knl diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 1f896bb97..f36a90575 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -37,6 +37,7 @@ from loopy.kernel.data import ( SubstitutionRule, AddressSpace, ValueArg) from loopy.kernel.instruction import (CInstruction, _DataObliviousInstruction, CallInstruction) +from loopy.program import iterate_over_kernels_if_given_program from loopy.diagnostic import LoopyError, warn_with_kernel import islpy as isl from islpy import dim_type @@ -1753,6 +1754,7 @@ def add_inferred_inames(knl): # {{{ apply single-writer heuristic +@iterate_over_kernels_if_given_program def apply_single_writer_depencency_heuristic(kernel, warn_if_used=True): logger.debug("%s: default deps" % kernel.name) @@ -2175,56 +2177,55 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): # {{{ handle kernel language version - if not is_callee_kernel: - from loopy.version import LANGUAGE_VERSION_SYMBOLS + from loopy.version import LANGUAGE_VERSION_SYMBOLS - version_to_symbol = dict( - (getattr(loopy.version, lvs), lvs) - for lvs in LANGUAGE_VERSION_SYMBOLS) + version_to_symbol = dict( + (getattr(loopy.version, lvs), lvs) + for lvs in LANGUAGE_VERSION_SYMBOLS) - lang_version = kwargs.pop("lang_version", None) - if lang_version is None: - # {{{ peek into caller's module to look for LOOPY_KERNEL_LANGUAGE_VERSION + lang_version = kwargs.pop("lang_version", None) + if lang_version is None: + # {{{ peek into caller's module to look for LOOPY_KERNEL_LANGUAGE_VERSION - # This *is* gross. But it seems like the right thing interface-wise. - import inspect - caller_globals = inspect.currentframe().f_back.f_globals + # This *is* gross. But it seems like the right thing interface-wise. + import inspect + caller_globals = inspect.currentframe().f_back.f_globals - for ver_sym in LANGUAGE_VERSION_SYMBOLS: - try: - lang_version = caller_globals[ver_sym] - break - except KeyError: - pass + for ver_sym in LANGUAGE_VERSION_SYMBOLS: + try: + lang_version = caller_globals[ver_sym] + break + except KeyError: + pass - # }}} + # }}} - if lang_version is None: - from warnings import warn - from loopy.diagnostic import LoopyWarning - from loopy.version import ( - MOST_RECENT_LANGUAGE_VERSION, - FALLBACK_LANGUAGE_VERSION) - warn("'lang_version' was not passed to make_kernel(). " - "To avoid this warning, pass " - "lang_version={ver} in this invocation. " - "(Or say 'from loopy.version import " - "{sym_ver}' in " - "the global scope of the calling frame.)" - .format( - ver=MOST_RECENT_LANGUAGE_VERSION, - sym_ver=version_to_symbol[MOST_RECENT_LANGUAGE_VERSION] - ), - LoopyWarning, stacklevel=2) - - lang_version = FALLBACK_LANGUAGE_VERSION - - if lang_version not in version_to_symbol: - raise LoopyError("Language version '%s' is not known." % (lang_version,)) - if lang_version >= (2018, 1): - options = options.copy(enforce_variable_access_ordered=True) - if lang_version >= (2018, 2): - options = options.copy(ignore_boostable_into=True) + if lang_version is None: + from warnings import warn + from loopy.diagnostic import LoopyWarning + from loopy.version import ( + MOST_RECENT_LANGUAGE_VERSION, + FALLBACK_LANGUAGE_VERSION) + warn("'lang_version' was not passed to make_kernel(). " + "To avoid this warning, pass " + "lang_version={ver} in this invocation. " + "(Or say 'from loopy.version import " + "{sym_ver}' in " + "the global scope of the calling frame.)" + .format( + ver=MOST_RECENT_LANGUAGE_VERSION, + sym_ver=version_to_symbol[MOST_RECENT_LANGUAGE_VERSION] + ), + LoopyWarning, stacklevel=2) + + lang_version = FALLBACK_LANGUAGE_VERSION + + if lang_version not in version_to_symbol: + raise LoopyError("Language version '%s' is not known." % (lang_version,)) + if lang_version >= (2018, 1): + options = options.copy(enforce_variable_access_ordered=True) + if lang_version >= (2018, 2): + options = options.copy(ignore_boostable_into=True) # }}} @@ -2382,11 +2383,6 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): def make_function(*args, **kwargs): - lang_version = kwargs.pop('lang_version', None) - if lang_version: - raise LoopyError("lang_version should be set for program, not " - "functions.") - kwargs['is_callee_kernel'] = True return make_kernel(*args, **kwargs) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 9d85f5e84..1ba0dc7ec 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -1208,7 +1208,7 @@ def is_array_call(assignees, expression): return False -def modify_assignee_assignee_for_array_call(assignee): +def modify_assignee_for_array_call(assignee): """ Converts the assignee subscript or variable as a SubArrayRef. """ @@ -1258,7 +1258,7 @@ def make_assignment(assignees, expression, temp_var_types=None, **kwargs): # assignee as an instance of SubArrayRef. If not given as a # SubArrayRef return CallInstruction( - assignees=tuple(modify_assignee_assignee_for_array_call( + assignees=tuple(modify_assignee_for_array_call( assignee) for assignee in assignees), expression=expression, temp_var_types=temp_var_types, diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 6f3c6f2be..870f9fc2c 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -719,7 +719,7 @@ class RuleArgument(LoopyExpressionBase): mapper_method = intern("map_rule_argument") -class ResolvedFunction(p.Expression): +class ResolvedFunction(LoopyExpressionBase): """ A function invocation whose definition is known in a :mod:`loopy` kernel. Each instance of :class:`loopy.symbolic.ResolvedFunction` in an expression @@ -758,8 +758,8 @@ class ResolvedFunction(p.Expression): def __getinitargs__(self): return (self.function, ) - def stringifier(self): - return StringifyMapper + def make_stringifier(self, originating_stringifier=None): + return StringifyMapper() mapper_method = intern("map_resolved_function") @@ -807,7 +807,7 @@ class SweptInameStrideCollector(CoefficientCollectorBase): return super(SweptInameStrideCollector, self).map_algebraic_leaf(expr) -class SubArrayRef(p.Expression): +class SubArrayRef(LoopyExpressionBase): """ An algebraic expression to map an affine memory layout pattern (known as sub-arary) as consecutive elements of the sweeping axes which are defined @@ -871,8 +871,8 @@ class SubArrayRef(p.Expression): and other.subscript == self.subscript and other.swept_inames == self.swept_inames) - def stringifier(self): - return StringifyMapper + def make_stringifier(self, originating_stringifier=None): + return StringifyMapper() mapper_method = intern("map_sub_array_ref") diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 479843697..7534818d7 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -50,7 +50,7 @@ __doc__ = """ # {{{ register function lookup -def _resolved_callables_from_function_lookup(program, +def _resolve_callables_from_function_lookup(program, func_id_to_in_kernel_callable_mapper): """ Returns a copy of *program* with the expression nodes marked "Resolved" @@ -124,7 +124,7 @@ def register_function_id_to_in_knl_callable_mapper(program, new_func_id_mappers = program.func_id_to_in_knl_callable_mappers + ( [func_id_to_in_knl_callable_mapper]) - program = _resolved_callables_from_function_lookup(program, + program = _resolve_callables_from_function_lookup(program, func_id_to_in_knl_callable_mapper) new_program = program.copy( @@ -173,11 +173,17 @@ def register_callable_kernel(program, callee_kernel): # the number of assigness in the callee kernel intructions. expected_num_assignees = len([arg for arg in callee_kernel.args if arg.name in callee_kernel.get_written_variables()]) - expected_num_parameters = len([arg for arg in callee_kernel.args if + expected_max_num_parameters = len([arg for arg in callee_kernel.args if arg.name in callee_kernel.get_read_variables()]) + len( [arg for arg in callee_kernel.args if arg.name not in (callee_kernel.get_read_variables() | callee_kernel.get_written_variables())]) + expected_min_num_parameters = len([arg for arg in callee_kernel.args if + arg.name in callee_kernel.get_read_variables() and arg.name not in + callee_kernel.get_written_variables()]) + len( + [arg for arg in callee_kernel.args if arg.name not in + (callee_kernel.get_read_variables() | + callee_kernel.get_written_variables())]) for in_knl_callable in program.callables_table.values(): if isinstance(in_knl_callable, CallableKernel): caller_kernel = in_knl_callable.subkernel @@ -195,11 +201,21 @@ def register_callable_kernel(program, callee_kernel): "match." % ( callee_kernel.name, insn.id)) if len(insn.expression.parameters+tuple( - kw_parameters.values())) != expected_num_parameters: - raise LoopyError("The number of expected arguments " - "for the callee kernel %s and the number of " - "parameters in instruction %s do not match." - % (callee_kernel.name, insn.id)) + kw_parameters.values())) > expected_max_num_parameters: + raise LoopyError("The number of" + " parameters in instruction '%s' exceed" + " the max. number of arguments possible" + " for the callee kernel '%s' => arg matching" + " not possible." + % (insn.id, callee_kernel.name)) + if len(insn.expression.parameters+tuple( + kw_parameters.values())) < expected_min_num_parameters: + raise LoopyError("The number of" + " parameters in instruction '%s' is less than" + " the min. number of arguments possible" + " for the callee kernel '%s' => arg matching" + " not possible." + % (insn.id, callee_kernel.name)) elif isinstance(insn, (MultiAssignmentBase, CInstruction, _DataObliviousInstruction)): diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py index 9b83f242b..45e9c0a06 100644 --- a/loopy/transform/fusion.py +++ b/loopy/transform/fusion.py @@ -419,6 +419,11 @@ def fuse_kernels(programs, suffixes=None, data_flow=None): *data_flow* was added in version 2016.2 """ + from loopy.program import make_program + + programs = [make_program(knl) if isinstance(knl, LoopKernel) else knl for + knl in programs] + # all the resolved functions in programs must be registered in # main_callables_table main_prog_callables_info = ( diff --git a/test/test_callables.py b/test/test_callables.py index f2f3acbd6..731593ea3 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -63,38 +63,35 @@ def test_register_function_lookup(ctx_factory): def test_register_knl(ctx_factory, inline): ctx = ctx_factory() queue = cl.CommandQueue(ctx) - n = 2 ** 4 + n = 4 x = np.random.rand(n, n, n, n, n) y = np.random.rand(n, n, n, n, n) grandchild_knl = lp.make_function( - "{[i, j]:0<= i, j< 16}", + "{[i, j]:0<= i, j< 4}", """ c[i, j] = 2*a[i, j] + 3*b[i, j] """, name='linear_combo1') child_knl = lp.make_function( - "{[i, j]:0<=i, j < 16}", + "{[i, j]:0<=i, j < 4}", """ [i, j]: g[i, j] = linear_combo1([i, j]: e[i, j], [i, j]: f[i, j]) """, name='linear_combo2') parent_knl = lp.make_kernel( - "{[i, j, k, l, m]: 0<=i, j, k, l, m<16}", + "{[i, j, k, l, m]: 0<=i, j, k, l, m<4}", """ [j, l]: z[i, j, k, l, m] = linear_combo2([j, l]: x[i, j, k, l, m], [j, l]: y[i, j, k, l, m]) """, kernel_data=[ lp.GlobalArg( - name='x', + name='x, y', dtype=np.float64, - shape=(16, 16, 16, 16, 16)), - lp.GlobalArg( - name='y', - dtype=np.float64, - shape=(16, 16, 16, 16, 16)), '...'], + shape=(n, n, n, n, n)), + '...'] ) knl = lp.register_callable_kernel( @@ -115,36 +112,29 @@ def test_register_knl(ctx_factory, inline): def test_slices_with_negative_step(ctx_factory, inline): ctx = ctx_factory() queue = cl.CommandQueue(ctx) - n = 2 ** 4 + n = 4 x = np.random.rand(n, n, n, n, n) y = np.random.rand(n, n, n, n, n) child_knl = lp.make_function( - "{[i, j]:0<=i, j < 16}", + "{[i, j]:0<=i, j < 4}", """ g[i, j] = 2*e[i, j] + 3*f[i, j] """, name="linear_combo") parent_knl = lp.make_kernel( - "{[i, k, m]: 0<=i, k, m<16}", + "{[i, k, m]: 0<=i, k, m<4}", """ - z[i, 15:-1:-1, k, :, m] = linear_combo(x[i, :, k, :, m], + z[i, 3:-1:-1, k, :, m] = linear_combo(x[i, :, k, :, m], y[i, :, k, :, m]) """, kernel_data=[ lp.GlobalArg( - name='x', - dtype=np.float64, - shape=(16, 16, 16, 16, 16)), - lp.GlobalArg( - name='y', - dtype=np.float64, - shape=(16, 16, 16, 16, 16)), - lp.GlobalArg( - name='z', + name='x, y, z', dtype=np.float64, - shape=(16, 16, 16, 16, 16)), '...'], + shape=(n, n, n, n, n)), + '...'] ) knl = lp.register_callable_kernel( @@ -163,7 +153,7 @@ def test_register_knl_with_call_with_kwargs(ctx_factory, inline): ctx = ctx_factory() queue = cl.CommandQueue(ctx) - n = 2 ** 2 + n = 4 a_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float32) b_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float32) @@ -215,27 +205,27 @@ def test_register_knl_with_hw_axes(ctx_factory, inline): ctx = ctx_factory() queue = cl.CommandQueue(ctx) - n = 2 ** 5 + n = 4 x_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float64) y_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float64) callee_knl = lp.make_function( - "{[i, j]:0<=i, j < 32}", + "{[i, j]:0<=i, j < 4}", """ g[i, j] = 2*e[i, j] + 3*f[i, j] """, name='linear_combo') - callee_knl = lp.split_iname(callee_knl, "i", 2, inner_tag="l.0", outer_tag="g.0") + callee_knl = lp.split_iname(callee_knl, "i", 1, inner_tag="l.0", outer_tag="g.0") caller_knl = lp.make_kernel( - "{[i, j, k, l, m]: 0<=i, j, k, l, m<32}", + "{[i, j, k, l, m]: 0<=i, j, k, l, m<4}", """ [j, l]: z[i, j, k, l, m] = linear_combo([j, l]: x[i, j, k, l, m], [j, l]: y[i, j, k, l, m]) """ ) - caller_knl = lp.split_iname(caller_knl, "i", 8, inner_tag="l.1", outer_tag="g.1") + caller_knl = lp.split_iname(caller_knl, "i", 4, inner_tag="l.1", outer_tag="g.1") knl = lp.register_callable_kernel( caller_knl, callee_knl) @@ -252,8 +242,8 @@ def test_register_knl_with_hw_axes(ctx_factory, inline): x_host = x_dev.get() y_host = y_dev.get() - assert gsize == (16, 4) - assert lsize == (2, 8) + assert gsize == (4, 1) + assert lsize == (1, 4) assert np.linalg.norm(2*x_host+3*y_host-out['z'].get())/np.linalg.norm( 2*x_host+3*y_host) < 1e-15 @@ -484,13 +474,13 @@ def test_empty_sub_array_refs(ctx_factory, inline): def test_array_inputs_to_callee_kernels(ctx_factory, inline): ctx = ctx_factory() queue = cl.CommandQueue(ctx) - n = 2 ** 4 + n = 2 ** 3 x = np.random.rand(n, n) y = np.random.rand(n, n) child_knl = lp.make_function( - "{[i, j]:0<=i, j < 16}", + "{[i, j]:0<=i, j < 8}", """ g[i, j] = 2*e[i, j] + 3*f[i, j] """, name="linear_combo") @@ -502,17 +492,10 @@ def test_array_inputs_to_callee_kernels(ctx_factory, inline): """, kernel_data=[ lp.GlobalArg( - name='x', - dtype=np.float64, - shape=(16, 16)), - lp.GlobalArg( - name='y', - dtype=np.float64, - shape=(16, 16)), - lp.GlobalArg( - name='z', + name='x, y, z', dtype=np.float64, - shape=(16, 16)), '...'], + shape=(n, n)), + '...'] ) knl = lp.register_callable_kernel( diff --git a/test/test_fortran.py b/test/test_fortran.py index 437199810..1ab28409b 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -533,9 +533,11 @@ def test_parse_and_fuse_two_kernels(): !$loopy begin ! - ! prg = lp.parse_fortran(SOURCE) - ! fill = prg["fill"] - ! twice = prg["twice"] + ! # FIXME: correct this after the "Module" is done. + ! # prg = lp.parse_fortran(SOURCE) + ! # fill = prg["fill"] + ! # twice = prg["twice"] + ! fill, twice = lp.parse_fortran(SOURCE, return_list_of_knls=True) ! knl = lp.fuse_kernels((fill, twice)) ! print(knl) ! RESULT = knl diff --git a/test/test_numa_diff.py b/test/test_numa_diff.py index 1ba44e77e..55a2d2e11 100644 --- a/test/test_numa_diff.py +++ b/test/test_numa_diff.py @@ -60,7 +60,8 @@ def test_gnuma_horiz_kernel(ctx_factory, ilp_multiple, Nq, opt_level): # noqa source = source.replace("datafloat", "real*4") hsv_r, hsv_s = [ - knl for knl in lp.parse_fortran(source, filename, seq_dependencies=False) + knl for knl in lp.parse_fortran(source, filename, + seq_dependencies=False, return_list_of_knls=True) if "KernelR" in knl.name or "KernelS" in knl.name ] hsv_r = lp.tag_instructions(hsv_r, "rknl") @@ -229,6 +230,15 @@ def test_gnuma_horiz_kernel(ctx_factory, ilp_multiple, Nq, opt_level): # noqa hsv = tap_hsv + hsv = lp.set_options(hsv, + ignore_boostable_into=True, + cl_build_options=[ + "-cl-denorms-are-zero", + "-cl-fast-relaxed-math", + "-cl-finite-math-only", + "-cl-mad-enable", + "-cl-no-signed-zeros"]) + if 1: print("OPS") op_map = lp.get_op_map(hsv, subgroup_size=32) @@ -238,14 +248,6 @@ def test_gnuma_horiz_kernel(ctx_factory, ilp_multiple, Nq, opt_level): # noqa gmem_map = lp.get_mem_access_map(hsv, subgroup_size=32).to_bytes() print(lp.stringify_stats_mapping(gmem_map)) - hsv = lp.set_options(hsv, cl_build_options=[ - "-cl-denorms-are-zero", - "-cl-fast-relaxed-math", - "-cl-finite-math-only", - "-cl-mad-enable", - "-cl-no-signed-zeros", - ]) - # FIXME: renaming's a bit tricky in this program model. # add a simple transformation for it # hsv = hsv.copy(name="horizontalStrongVolumeKernel") -- GitLab From 6857c4ba818ac896ee677ac4dd4c69c90bb20108 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 15:32:08 -0500 Subject: [PATCH 02/26] adds some helpful comments --- loopy/frontend/fortran/__init__.py | 12 ++++++++++++ loopy/transform/callable.py | 3 +++ 2 files changed, 15 insertions(+) diff --git a/loopy/frontend/fortran/__init__.py b/loopy/frontend/fortran/__init__.py index 74c1ebf54..bc360b996 100644 --- a/loopy/frontend/fortran/__init__.py +++ b/loopy/frontend/fortran/__init__.py @@ -242,6 +242,18 @@ def parse_transformed_fortran(source, free_form=True, strict=True, def _add_assignees_to_calls(knl, all_kernels): + """ + Returns a copy of *knl* coming from the fortran parser adjusted to the + loopy specification that written variables of a call must appear in the + assignee. + + :param knl: An instance of :class:`loopy.LoopKernel`, which have incorrect + calls to the kernels in *all_kernels* by stuffing both the input and + output arguments into parameters. + + :param all_kernels: An instance of :class:`list` of loopy kernels which + may be called by *kernel*. + """ new_insns = [] subroutine_dict = dict((kernel.name, kernel) for kernel in all_kernels) from loopy.kernel.instruction import (Assignment, CallInstruction, diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 7534818d7..e0f4a79d7 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -173,6 +173,9 @@ def register_callable_kernel(program, callee_kernel): # the number of assigness in the callee kernel intructions. expected_num_assignees = len([arg for arg in callee_kernel.args if arg.name in callee_kernel.get_written_variables()]) + + # can only predict the range of actual number of parameters to a kernel + # call, as a variable intended for pure output can be read expected_max_num_parameters = len([arg for arg in callee_kernel.args if arg.name in callee_kernel.get_read_variables()]) + len( [arg for arg in callee_kernel.args if arg.name not in -- GitLab From 4d5f37e001c63de2f3adcae79b2c19fabbc3df2d Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 15:32:28 -0500 Subject: [PATCH 03/26] adds in-place update test --- test/test_callables.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/test_callables.py b/test/test_callables.py index 731593ea3..ce6b89e36 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -564,6 +564,29 @@ def test_unknown_stride_to_callee(): print(lp.generate_code_v2(prog).device_code()) +def test_argument_matching_for_inplace_update(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + twice = lp.make_function( + "{[i]: 0<=i<10}", + """ + x[i] = 2*x[i] + """, name='twice') + + knl = lp.make_kernel( + "{:}", + """ + x[:] = twice(x[:]) + """, [lp.GlobalArg('x', shape=(10,), dtype=np.float64)]) + + knl = lp.register_callable_kernel(knl, twice) + + x = np.random.randn(10) + evt, (out, ) = knl(queue, np.copy(x)) + + assert np.allclose(2*x, out) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab From df475fcf3c0c1ef57c26ee769d99a7e080b2f022 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 17:08:46 -0500 Subject: [PATCH 04/26] KernelArgument.is_output_only -> KernelArgument.is_output --- loopy/auto_test.py | 2 +- loopy/frontend/fortran/translator.py | 2 +- loopy/target/execution.py | 2 +- loopy/transform/make_scalar.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/loopy/auto_test.py b/loopy/auto_test.py index 4bca7ebdb..b5039bd2c 100644 --- a/loopy/auto_test.py +++ b/loopy/auto_test.py @@ -118,7 +118,7 @@ def make_ref_args(program, impl_arg_info, queue, parameters): shape = evaluate_shape(arg.unvec_shape, parameters) dtype = kernel_arg.dtype - is_output = kernel_arg.is_output_only + is_output = kernel_arg.is_output if arg.arg_class is ImageArg: storage_array = ary = cl_array.empty( diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py index 66961ce70..949a3d4cc 100644 --- a/loopy/frontend/fortran/translator.py +++ b/loopy/frontend/fortran/translator.py @@ -763,7 +763,7 @@ class F2LoopyTranslator(FTreeWalkerBase): arg_name, dtype=sub.get_type(arg_name), shape=sub.get_loopy_shape(arg_name), - is_output_only=False, + is_output=False, )) else: kernel_data.append( diff --git a/loopy/target/execution.py b/loopy/target/execution.py index 9d1d14376..96f6e065c 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -725,7 +725,7 @@ class KernelExecutorBase(object): self.packing_controller = SeparateArrayPackingController(program) self.output_names = tuple(arg.name for arg in self.program.args - if arg.is_output_only) + if arg.is_output) self.has_runtime_typed_args = any( arg.dtype is None diff --git a/loopy/transform/make_scalar.py b/loopy/transform/make_scalar.py index ab91fdf78..d0e7d1bc2 100644 --- a/loopy/transform/make_scalar.py +++ b/loopy/transform/make_scalar.py @@ -23,7 +23,7 @@ def make_scalar(kernel, var_name): kernel = ScalarChanger(rule_mapping_context, var_name).map_kernel(kernel) new_args = [ValueArg(arg.name, arg.dtype, target=arg.target, - is_output_only=arg.is_output_only) if arg.name == var_name else arg for + is_output=arg.is_output) if arg.name == var_name else arg for arg in kernel.args] new_temps = dict((tv.name, tv.copy(shape=(), dim_tags=None)) if tv.name == var_name else (tv.name, tv) for tv in -- GitLab From 71d7541dc55f5a2f2e1fefa83628543fe634ef53 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 17:08:51 -0500 Subject: [PATCH 05/26] Adds a kernel argument attribute is_input - Transmits changes in the function interface so that they also use is_input while performing caller<->callee argument matching - Makes changes in the test cases so that they set is_output, is_input correctly --- loopy/kernel/creation.py | 4 ++-- loopy/kernel/data.py | 27 ++++++++++++++------- loopy/kernel/function_interface.py | 13 ++++------ loopy/kernel/tools.py | 38 +++++++++++++++++++++--------- loopy/transform/callable.py | 36 ++++++---------------------- test/test_callables.py | 16 ++++++++----- 6 files changed, 70 insertions(+), 64 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index f36a90575..4be7e06b8 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -2367,8 +2367,8 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): check_for_duplicate_names(knl) check_written_variable_names(knl) - from loopy.kernel.tools import infer_args_are_output_only - knl = infer_args_are_output_only(knl) + from loopy.kernel.tools import infer_args_are_input_output + knl = infer_args_are_input_output(knl) from loopy.preprocess import prepare_for_caching knl = prepare_for_caching(knl) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 4c0959111..15a77b809 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -338,7 +338,8 @@ class KernelArgument(ImmutableRecord): dtype = None kwargs["dtype"] = dtype - kwargs["is_output_only"] = kwargs.pop("is_output_only", None) + kwargs["is_output"] = kwargs.pop("is_output", None) + kwargs["is_input"] = kwargs.pop("is_input", None) ImmutableRecord.__init__(self, **kwargs) @@ -351,20 +352,27 @@ class ArrayArg(ArrayBase, KernelArgument): An attribute of :class:`AddressSpace` defining the address space in which the array resides. - .. attribute:: is_output_only + .. attribute:: is_output An instance of :class:`bool`. If set to *True*, recorded to be returned from the kernel. + + .. attribute:: is_input + + An instance of :class:`bool`. If set to *True*, expected to be + provided by the user. """) allowed_extra_kwargs = [ "address_space", - "is_output_only"] + "is_output", + "is_input"] def __init__(self, *args, **kwargs): if "address_space" not in kwargs: raise TypeError("'address_space' must be specified") - kwargs["is_output_only"] = kwargs.pop("is_output_only", None) + kwargs["is_output"] = kwargs.pop("is_output", None) + kwargs["is_input"] = kwargs.pop("is_input", None) super(ArrayArg, self).__init__(*args, **kwargs) @@ -392,7 +400,8 @@ class ArrayArg(ArrayBase, KernelArgument): """ super(ArrayArg, self).update_persistent_hash(key_hash, key_builder) key_builder.rec(key_hash, self.address_space) - key_builder.rec(key_hash, self.is_output_only) + key_builder.rec(key_hash, self.is_output) + key_builder.rec(key_hash, self.is_input) # Making this a function prevents incorrect use in isinstance. @@ -413,7 +422,8 @@ class ConstantArg(ArrayBase, KernelArgument): max_target_axes = 1 # Constant Arg cannot be an output - is_output_only = False + is_output = False + is_input = True def get_arg_decl(self, ast_builder, name_suffix, shape, dtype, is_written): return ast_builder.get_constant_arg_decl(self.name + name_suffix, shape, @@ -436,13 +446,14 @@ class ImageArg(ArrayBase, KernelArgument): class ValueArg(KernelArgument): def __init__(self, name, dtype=None, approximately=1000, target=None, - is_output_only=False): + is_output=False, is_input=True): KernelArgument.__init__(self, name=name, dtype=dtype, approximately=approximately, target=target, - is_output_only=is_output_only) + is_output=is_output, + is_input=is_input) def __str__(self): import loopy as lp diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index d8c120db8..4b2d18ec5 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -226,16 +226,13 @@ def get_kw_pos_association(kernel): write_count = -1 for arg in kernel.args: - if arg.name in kernel.get_written_variables(): + if arg.is_output: kw_to_pos[arg.name] = write_count pos_to_kw[write_count] = arg.name write_count -= 1 - if arg.name in kernel.get_read_variables(): - kw_to_pos[arg.name] = read_count - pos_to_kw[read_count] = arg.name - read_count += 1 - if not (arg.name in kernel.get_read_variables() or arg.name in - kernel.get_written_variables()): + if arg.is_input: + # if an argument is both input and output then the input is given + # more significance in kw_to_pos kw_to_pos[arg.name] = read_count pos_to_kw[read_count] = arg.name read_count += 1 @@ -862,7 +859,7 @@ class CallableKernel(InKernelCallable): # insert the assignees at the required positions assignee_write_count = -1 for i, arg in enumerate(self.subkernel.args): - if arg.is_output_only: + if arg.is_output and not arg.is_input: assignee = assignees[-assignee_write_count-1] parameters.insert(i, assignee) par_dtypes.insert(i, self.arg_id_to_dtype[assignee_write_count]) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index e311fcc0f..46d70c054 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -1923,34 +1923,50 @@ def get_direct_callee_kernels(kernel, callables_table, insn_ids=None,): # {{{ direction helper tools -def infer_args_are_output_only(kernel): +def infer_args_are_input_output(kernel): """ - Returns a copy of *kernel* with the attribute ``is_output_only`` set. + Returns a copy of *kernel* with the attributes ``is_input`` and + ``is_output`` of the arguments set. .. note:: - If the attribute ``is_output_only`` is not supplied from an user, then - infers it as an output argument if it is written at some point in the - kernel. + If the attribute ``is_output`` of an argument is not supplied from an + user, then it is inferred as an output argument if it is written at + some point in the kernel. + + If the attribute ``is_input`` of an argument of is not supplied from + an user, then it is inferred as an input argument if it is either read + at some point in the kernel or it is neither read nor written. """ from loopy.kernel.data import ArrayArg, ValueArg, ConstantArg, ImageArg new_args = [] for arg in kernel.args: if isinstance(arg, (ArrayArg, ImageArg, ValueArg)): - if arg.is_output_only is not None: - assert isinstance(arg.is_output_only, bool) - new_args.append(arg) + if arg.is_output is not None: + assert isinstance(arg.is_output, bool) else: if arg.name in kernel.get_written_variables(): - new_args.append(arg.copy(is_output_only=True)) + arg = arg.copy(is_output=True) + else: + arg = arg.copy(is_output=False) + + if arg.is_input is not None: + assert isinstance(arg.is_input, bool) + else: + if arg.name in kernel.get_read_variables() or ( + (arg.name not in kernel.get_read_variables()) and ( + arg.name not in kernel.get_written_variables())): + arg = arg.copy(is_input=True) else: - new_args.append(arg.copy(is_output_only=False)) + arg = arg.copy(is_input=False) elif isinstance(arg, ConstantArg): - new_args.append(arg) + pass else: raise NotImplementedError("Unkonwn argument type %s." % type(arg)) + new_args.append(arg) + return kernel.copy(args=new_args) # }}} diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index e0f4a79d7..05866a105 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -171,22 +171,8 @@ def register_callable_kernel(program, callee_kernel): # check to make sure that the variables with 'out' direction is equal to # the number of assigness in the callee kernel intructions. - expected_num_assignees = len([arg for arg in callee_kernel.args if - arg.name in callee_kernel.get_written_variables()]) - - # can only predict the range of actual number of parameters to a kernel - # call, as a variable intended for pure output can be read - expected_max_num_parameters = len([arg for arg in callee_kernel.args if - arg.name in callee_kernel.get_read_variables()]) + len( - [arg for arg in callee_kernel.args if arg.name not in - (callee_kernel.get_read_variables() | - callee_kernel.get_written_variables())]) - expected_min_num_parameters = len([arg for arg in callee_kernel.args if - arg.name in callee_kernel.get_read_variables() and arg.name not in - callee_kernel.get_written_variables()]) + len( - [arg for arg in callee_kernel.args if arg.name not in - (callee_kernel.get_read_variables() | - callee_kernel.get_written_variables())]) + expected_num_assignees = sum(arg.is_output for arg in callee_kernel.args) + expected_num_arguments = sum(arg.is_input for arg in callee_kernel.args) for in_knl_callable in program.callables_table.values(): if isinstance(in_knl_callable, CallableKernel): caller_kernel = in_knl_callable.subkernel @@ -204,19 +190,11 @@ def register_callable_kernel(program, callee_kernel): "match." % ( callee_kernel.name, insn.id)) if len(insn.expression.parameters+tuple( - kw_parameters.values())) > expected_max_num_parameters: + kw_parameters.values())) != expected_num_arguments: raise LoopyError("The number of" - " parameters in instruction '%s' exceed" - " the max. number of arguments possible" - " for the callee kernel '%s' => arg matching" - " not possible." - % (insn.id, callee_kernel.name)) - if len(insn.expression.parameters+tuple( - kw_parameters.values())) < expected_min_num_parameters: - raise LoopyError("The number of" - " parameters in instruction '%s' is less than" - " the min. number of arguments possible" - " for the callee kernel '%s' => arg matching" + " arguments in instruction '%s' do match" + " the number of input arguments in" + " the callee kernel '%s' => arg matching" " not possible." % (insn.id, callee_kernel.name)) @@ -409,7 +387,7 @@ def _inline_call_instruction(caller_kernel, callee_knl, instruction): assignee_pos = 0 parameter_pos = 0 for i, arg in enumerate(callee_knl.args): - if arg.is_output_only: + if arg.is_output: arg_map[arg.name] = assignees[assignee_pos] assignee_pos += 1 else: diff --git a/test/test_callables.py b/test/test_callables.py index ce6b89e36..a241b21f2 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -327,6 +327,9 @@ def test_multi_arg_array_call(ctx_factory): lp.Assignment(id="update", assignee=acc_i, expression=p.Variable("min")(acc_i, a_i), depends_on="init1,init2")], + [ + lp.GlobalArg('acc_i, index', is_input=False, is_output=True), + "..."], name="custom_argmin") argmin_kernel = lp.fix_parameters(argmin_kernel, n=n) @@ -403,21 +406,22 @@ def test_non_sub_array_refs_arguments(ctx_factory): from loopy.transform.callable import _match_caller_callee_argument_dimension_ callee = lp.make_function("{[i] : 0 <= i < 6}", "a[i] = a[i] + j", - [lp.GlobalArg("a", dtype="double", shape=(6,), is_output_only=False), + [lp.GlobalArg("a", dtype="double", shape=(6,), is_output=True, + is_input=True), lp.ValueArg("j", dtype="int")], name="callee") caller1 = lp.make_kernel("{[j] : 0 <= j < 2}", "a[:] = callee(a[:], b[0])", - [lp.GlobalArg("a", dtype="double", shape=(6, ), is_output_only=False), - lp.GlobalArg("b", dtype="double", shape=(1, ), is_output_only=False)], + [lp.GlobalArg("a", dtype="double", shape=(6, ), is_output=False), + lp.GlobalArg("b", dtype="double", shape=(1, ), is_output=False)], name="caller", target=lp.CTarget()) caller2 = lp.make_kernel("{[j] : 0 <= j < 2}", "a[:]=callee(a[:], 3.1415926)", [lp.GlobalArg("a", dtype="double", shape=(6, ), - is_output_only=False)], + is_output=False)], name="caller", target=lp.CTarget()) caller3 = lp.make_kernel("{[j] : 0 <= j < 2}", "a[:]=callee(a[:], kappa)", [lp.GlobalArg("a", dtype="double", shape=(6, ), - is_output_only=False), '...'], + is_output=False), '...'], name="caller", target=lp.CTarget()) registered = lp.register_callable_kernel(caller1, callee) @@ -582,7 +586,7 @@ def test_argument_matching_for_inplace_update(ctx_factory): knl = lp.register_callable_kernel(knl, twice) x = np.random.randn(10) - evt, (out, ) = knl(queue, np.copy(x)) + evt, (out, ) = knl(queue, x=np.copy(x)) assert np.allclose(2*x, out) -- GitLab From a37db7a463cbf32ee88a94a06283175aecb6f933 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 20:21:01 -0500 Subject: [PATCH 06/26] fixes minor error in argument matching --- loopy/kernel/function_interface.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 4b2d18ec5..2b50a2dc9 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -859,10 +859,12 @@ class CallableKernel(InKernelCallable): # insert the assignees at the required positions assignee_write_count = -1 for i, arg in enumerate(self.subkernel.args): - if arg.is_output and not arg.is_input: - assignee = assignees[-assignee_write_count-1] - parameters.insert(i, assignee) - par_dtypes.insert(i, self.arg_id_to_dtype[assignee_write_count]) + if arg.is_output: + if not arg.is_input: + assignee = assignees[-assignee_write_count-1] + parameters.insert(i, assignee) + par_dtypes.insert(i, self.arg_id_to_dtype[assignee_write_count]) + assignee_write_count -= 1 # no type casting in array calls -- GitLab From ddbe1c97045b70446dab340b4a98ecaf139e3165 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 20:22:33 -0500 Subject: [PATCH 07/26] check the validity of a kernel call more diligenltly --- loopy/transform/callable.py | 80 +++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 05866a105..2b888c21b 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -154,6 +154,84 @@ class _RegisterCalleeKernel(ImmutableRecord): return None +def subarrayrefs_are_equiv(sar1, sar2): + """ + Compares if two instance of :class:`loopy.symbolic.SubArrayRef`s point + to the same array region. + """ + if len(sar1.swept_inames) != len(sar2.swept_inames): + return False + + iname_map = dict(zip(sar1.swept_inames, sar2.swept_inames)) + + from pymbolic.mapper.substitutor import make_subst_func + from loopy.symbolic import SubstitutionMapper + sar1_substed = SubstitutionMapper(make_subst_func(iname_map))(sar1) + + return sar1_substed == sar2 + + +def _check_correctness_of_args_and_assignees(insn, callee_kernel): + from loopy.kernel.function_interface import get_kw_pos_association + kw_to_pos, pos_to_kw = get_kw_pos_association(callee_kernel) + callee_args_to_insn_params = [[] for _ in callee_kernel.args] + expr = insn.expression + from pymbolic.primitives import Call, CallWithKwargs + if isinstance(expr, Call): + expr = CallWithKwargs(expr.function, expr.parameters, kw_parameters={}) + for i, param in enumerate(expr.parameters): + pos = kw_to_pos[callee_kernel.args[i].name] + if pos < 0: + raise LoopyError("#{} argument meant for output obtained as an" + " input in '{}'.".format(i, insn)) + + assert pos == i + + callee_args_to_insn_params[i].append(param) + + for kw, param in six.iteritems(expr.kw_parameters): + pos = kw_to_pos[kw] + if pos < 0: + raise LoopyError("KW-argument '{}' meant for output obtained as an" + " input in '{}'.".format(kw, insn)) + callee_args_to_insn_params[pos].append(param) + + num_pure_assignees = 0 + for i, assignee in enumerate(insn.assignees): + pos = kw_to_pos[pos_to_kw[-i-1]] + + if pos < 0: + pos = (len(expr.parameters) + + len(expr.kw_parameters)+num_pure_assignees) + num_pure_assignees += 1 + + callee_args_to_insn_params[pos].append(assignee) + + # TODO: Some of the checks might be redundant. + + for arg, insn_params in zip(callee_kernel.args, + callee_args_to_insn_params): + if len(insn_params) == 1: + # making sure that the argument is either only input or output + if arg.is_input == arg.is_output: + raise LoopyError("Argument '{}' in '{}' should be passed in" + " both assignees and parameters in Call.".format( + insn_params[0], insn)) + elif len(insn_params) == 2: + if arg.is_input != arg.is_output: + raise LoopyError("Found multiple parameters mapping to an" + " argument which is not both input and output in" + " ''.".format()) + if not subarrayrefs_are_equiv(insn_params[0], insn_params[1]): + raise LoopyError("'{}' and '{}' point to the same argument in" + " the callee, but are unequal.".format( + insn_params[0], insn_params[1])) + else: + raise LoopyError("Multiple(>2) arguments pointing to the same" + " argument for '{}' in '{}'.".format(callee_kernel.name, + insn)) + + def register_callable_kernel(program, callee_kernel): """Returns a copy of *caller_kernel*, which would resolve *function_name* in an expression as a call to *callee_kernel*. @@ -198,6 +276,8 @@ def register_callable_kernel(program, callee_kernel): " not possible." % (insn.id, callee_kernel.name)) + _check_correctness_of_args_and_assignees(insn, callee_kernel) + elif isinstance(insn, (MultiAssignmentBase, CInstruction, _DataObliviousInstruction)): pass -- GitLab From ed9697621aec711d8d6b2b8c0e0b38a5699a34d9 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 19 Sep 2019 23:36:28 -0500 Subject: [PATCH 08/26] new enforcement of argument matching find some bugs in the tests! --- test/test_callables.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_callables.py b/test/test_callables.py index a241b21f2..4fe8735dc 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -260,19 +260,19 @@ def test_shape_translation_through_sub_array_ref(ctx_factory, inline): callee1 = lp.make_function( "{[i]: 0<=i<6}", """ - a[i] = 2*abs(b[i]) + b[i] = 2*abs(a[i]) """, name="callee_fn1") callee2 = lp.make_function( "{[i, j]: 0<=i<3 and 0 <= j < 2}", """ - a[i, j] = 3*b[i, j] + b[i, j] = 3*a[i, j] """, name="callee_fn2") callee3 = lp.make_function( "{[i]: 0<=i<6}", """ - a[i] = 5*b[i] + b[i] = 5*a[i] """, name="callee_fn3") knl = lp.make_kernel( @@ -328,6 +328,7 @@ def test_multi_arg_array_call(ctx_factory): expression=p.Variable("min")(acc_i, a_i), depends_on="init1,init2")], [ + lp.GlobalArg('a'), lp.GlobalArg('acc_i, index', is_input=False, is_output=True), "..."], name="custom_argmin") -- GitLab From 89efdfc96376c4bb9786f7464b5868e47447a918 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 20 Sep 2019 02:32:37 -0500 Subject: [PATCH 09/26] Fixes SubArrayRef.get_begin_subscript(..) - Fixed all the places where it was invoked. - get_begin_subscript(..) should be only called when generating code, so made sure that it is not being called at unnecessary places in :mod:`loopy`. --- loopy/kernel/instruction.py | 3 ++- loopy/symbolic.py | 21 ++++++++++++++------- loopy/target/c/codegen/expression.py | 2 +- loopy/transform/callable.py | 3 ++- loopy/type_inference.py | 2 +- 5 files changed, 20 insertions(+), 11 deletions(-) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 1ba0dc7ec..97d0931bd 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -543,7 +543,8 @@ def _get_assignee_subscript_deps(expr): elif isinstance(expr, LinearSubscript): return get_dependencies(expr.index) elif isinstance(expr, SubArrayRef): - return get_dependencies(expr.get_begin_subscript().index) + return get_dependencies(expr.subscript.index) - ( + frozenset(iname.name for iname in expr.swept_inames)) else: raise RuntimeError("invalid lvalue '%s'" % expr) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 870f9fc2c..53d8d4431 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -198,7 +198,9 @@ class CombineMapper(CombineMapperBase): return self.rec(expr.expr, *args, **kwargs) def map_sub_array_ref(self, expr): - return self.rec(expr.get_begin_subscript()) + return self.combine(( + self.rec(expr.subscript), + self.combine(tuple(self.rec(idx) for idx in expr.swept_inames)))) map_linear_subscript = CombineMapperBase.map_subscript @@ -353,9 +355,9 @@ class DependencyMapper(DependencyMapperBase): def map_loopy_function_identifier(self, expr, *args, **kwargs): return set() - def map_sub_array_ref(self, expr, *args): - deps = self.rec(expr.subscript, *args) - return deps - set(iname for iname in expr.swept_inames) + def map_sub_array_ref(self, expr, *args, **kwargs): + deps = self.rec(expr.subscript, *args, **kwargs) + return deps - set(expr.swept_inames) map_linear_subscript = DependencyMapperBase.map_subscript @@ -845,7 +847,7 @@ class SubArrayRef(LoopyExpressionBase): self.swept_inames = swept_inames self.subscript = subscript - def get_begin_subscript(self): + def get_begin_subscript(self, kernel): """ Returns an instance of :class:`pymbolic.primitives.Subscript`, the beginning subscript of the array swept by the *SubArrayRef*. @@ -853,9 +855,14 @@ class SubArrayRef(LoopyExpressionBase): **Example:** Consider ``[i, k]: a[i, j, k, l]``. The beginning subscript would be ``a[0, j, 0, l]`` """ - # TODO: Set the zero to the minimum value of the iname. + + def _get_lower_bound(iname): + pwaff = kernel.get_iname_bounds(iname).lower_bound_pw_aff + return int(pw_aff_to_expr(pwaff)) + swept_inames_to_zeros = dict( - (swept_iname.name, 0) for swept_iname in self.swept_inames) + (swept_iname.name, _get_lower_bound(swept_iname.name)) for + swept_iname in self.swept_inames) return EvaluatorWithDeficientContext(swept_inames_to_zeros)( self.subscript) diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index c970901b1..5a066ddfb 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -167,7 +167,7 @@ class ExpressionToCExpressionMapper(IdentityMapper): return var(expr.name) def map_sub_array_ref(self, expr, type_context): - return var("&")(self.rec(expr.get_begin_subscript(), + return var("&")(self.rec(expr.get_begin_subscript(self.kernel), type_context)) def map_subscript(self, expr, type_context): diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 2b888c21b..56fab7561 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -368,7 +368,8 @@ class KernelInliner(SubstitutionMapper): "constant shape.".format(callee_arg)) flatten_index = 0 - for i, idx in enumerate(sar.get_begin_subscript().index_tuple): + for i, idx in enumerate(sar.get_begin_subscript( + self.caller).index_tuple): flatten_index += idx*caller_arg.dim_tags[i].stride flatten_index += sum( diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 281dcb43d..0d4430e0d 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -692,7 +692,7 @@ class TypeInferenceMapper(CombineMapper): for rec_result in rec_results] def map_sub_array_ref(self, expr): - return self.rec(expr.get_begin_subscript()) + return self.rec(expr.subscript) # }}} -- GitLab From 50250d247d38606cf33c3948c474d063d407d034 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 20 Sep 2019 02:36:13 -0500 Subject: [PATCH 10/26] minor fixes in the tests; test for a bug when the start of the swept iname is non zero --- test/test_callables.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/test/test_callables.py b/test/test_callables.py index 4fe8735dc..04eeae666 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -364,13 +364,13 @@ def test_packing_unpacking(ctx_factory, inline): callee1 = lp.make_function( "{[i]: 0<=i<6}", """ - a[i] = 2*b[i] + b[i] = 2*a[i] """, name="callee_fn1") callee2 = lp.make_function( "{[i, j]: 0<=i<2 and 0 <= j < 3}", """ - a[i, j] = 3*b[i, j] + b[i, j] = 3*a[i, j] """, name="callee_fn2") knl = lp.make_kernel( @@ -456,8 +456,7 @@ def test_empty_sub_array_refs(ctx_factory, inline): callee = lp.make_function( "{[d]:0<=d<1}", """ - a[d] = b[d] - c[d] - + c[d] = a[d] - b[d] """, name='wence_function') caller = lp.make_kernel("{[i]: 0<=i<10}", @@ -592,6 +591,29 @@ def test_argument_matching_for_inplace_update(ctx_factory): assert np.allclose(2*x, out) +def test_non_zero_start_in_subarray_ref(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + twice = lp.make_function( + "{[i]: 0<=i<10}", + """ + b[i] = 2*a[i] + """, name='twice') + + knl = lp.make_kernel( + "{[i, j]: -5<=i<5 and 0<=j<10}", + """ + [i]:y[i+5] = twice([j]: x[j]) + """, [lp.GlobalArg('x, y', shape=(10,), dtype=np.float64)]) + + knl = lp.register_callable_kernel(knl, twice) + + x = np.random.randn(10) + evt, (out, ) = knl(queue, x=np.copy(x)) + + assert np.allclose(2*x, out) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab From 74c049694a0e76ff0980cb1fa6595cdfe3c6516f Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 20 Sep 2019 02:38:08 -0500 Subject: [PATCH 11/26] correctly checks if 2 sub array refs refer to the same part of arrays --- loopy/transform/callable.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 56fab7561..9c05dc97f 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -154,24 +154,20 @@ class _RegisterCalleeKernel(ImmutableRecord): return None -def subarrayrefs_are_equiv(sar1, sar2): +def subarrayrefs_are_equiv(sar1, sar2, knl): """ Compares if two instance of :class:`loopy.symbolic.SubArrayRef`s point to the same array region. """ - if len(sar1.swept_inames) != len(sar2.swept_inames): - return False - - iname_map = dict(zip(sar1.swept_inames, sar2.swept_inames)) - - from pymbolic.mapper.substitutor import make_subst_func - from loopy.symbolic import SubstitutionMapper - sar1_substed = SubstitutionMapper(make_subst_func(iname_map))(sar1) + from loopy.kernel.function_interface import get_arg_descriptor_for_expression - return sar1_substed == sar2 + return get_arg_descriptor_for_expression(knl, sar1) == ( + get_arg_descriptor_for_expression(knl, sar2)) and ( + sar1.get_begin_subscript(knl) == + sar2.get_begin_subscript(knl)) -def _check_correctness_of_args_and_assignees(insn, callee_kernel): +def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): from loopy.kernel.function_interface import get_kw_pos_association kw_to_pos, pos_to_kw = get_kw_pos_association(callee_kernel) callee_args_to_insn_params = [[] for _ in callee_kernel.args] @@ -222,7 +218,8 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel): raise LoopyError("Found multiple parameters mapping to an" " argument which is not both input and output in" " ''.".format()) - if not subarrayrefs_are_equiv(insn_params[0], insn_params[1]): + if not subarrayrefs_are_equiv(insn_params[0], insn_params[1], + caller_knl): raise LoopyError("'{}' and '{}' point to the same argument in" " the callee, but are unequal.".format( insn_params[0], insn_params[1])) @@ -276,7 +273,8 @@ def register_callable_kernel(program, callee_kernel): " not possible." % (insn.id, callee_kernel.name)) - _check_correctness_of_args_and_assignees(insn, callee_kernel) + _check_correctness_of_args_and_assignees(insn, + callee_kernel, caller_kernel) elif isinstance(insn, (MultiAssignmentBase, CInstruction, _DataObliviousInstruction)): -- GitLab From 51b25d2a029bfa7d554a83f5d0f286b2dc476aaa Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 23 Sep 2019 05:11:22 -0500 Subject: [PATCH 12/26] minor fixes from the review --- loopy/kernel/data.py | 2 +- loopy/kernel/tools.py | 5 +++++ loopy/transform/callable.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 15a77b809..51367e64e 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -360,7 +360,7 @@ class ArrayArg(ArrayBase, KernelArgument): .. attribute:: is_input An instance of :class:`bool`. If set to *True*, expected to be - provided by the user. + provided by the caller. """) allowed_extra_kwargs = [ diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 46d70c054..d0e4ef084 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -1965,6 +1965,11 @@ def infer_args_are_input_output(kernel): else: raise NotImplementedError("Unkonwn argument type %s." % type(arg)) + if not (arg.is_input or arg.is_output): + raise LoopyError("Kernel argument must be either input or output." + " '{}' in '{}' does not follow it.".format(arg.name, + kernel.name)) + new_args.append(arg) return kernel.copy(args=new_args) diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 9c05dc97f..a87a43f4e 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -267,7 +267,7 @@ def register_callable_kernel(program, callee_kernel): if len(insn.expression.parameters+tuple( kw_parameters.values())) != expected_num_arguments: raise LoopyError("The number of" - " arguments in instruction '%s' do match" + " arguments in instruction '%s' do not match" " the number of input arguments in" " the callee kernel '%s' => arg matching" " not possible." -- GitLab From 7b4771017af6ba16b2198b01b17d66d97c528573 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 23 Sep 2019 05:26:39 -0500 Subject: [PATCH 13/26] rephrasing is_output docs --- loopy/kernel/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 51367e64e..f0d7b3789 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -354,8 +354,8 @@ class ArrayArg(ArrayBase, KernelArgument): .. attribute:: is_output - An instance of :class:`bool`. If set to *True*, recorded to be - returned from the kernel. + An instance of :class:`bool`. If set to *True*, the argument is used + to return information to the caller .. attribute:: is_input -- GitLab From 266fea05eeb4bf3c082cad5d313f8a9d97684c28 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 1 Oct 2019 22:09:35 -0500 Subject: [PATCH 14/26] Removes support for "return_list_of_knl" in parse_fortran --- loopy/frontend/fortran/__init__.py | 9 ++------- loopy/ipython_ext.py | 2 +- test/test_fortran.py | 7 +++---- test/test_numa_diff.py | 2 +- 4 files changed, 7 insertions(+), 13 deletions(-) diff --git a/loopy/frontend/fortran/__init__.py b/loopy/frontend/fortran/__init__.py index bc360b996..9b63c10f8 100644 --- a/loopy/frontend/fortran/__init__.py +++ b/loopy/frontend/fortran/__init__.py @@ -296,11 +296,9 @@ def _add_assignees_to_calls(knl, all_kernels): def parse_fortran(source, filename="", free_form=None, strict=None, - seq_dependencies=None, auto_dependencies=None, target=None, - return_list_of_knls=False): + seq_dependencies=None, auto_dependencies=None, target=None): """ - :returns: an instance of :class:`list` of :class:`loopy.LoopKernel`s if - *return_list_of_knls* is True else a :class:`loopy.Program`. + :returns: A :class:`loopy.Program`. """ parse_plog = ProcessLogger(logger, "parsing fortran file '%s'" % filename) @@ -342,9 +340,6 @@ def parse_fortran(source, filename="", free_form=None, strict=None, kernels = f2loopy.make_kernels(seq_dependencies=seq_dependencies) - if return_list_of_knls: - return kernels - kernels = [_add_assignees_to_calls(knl, kernels) for knl in kernels] from loopy.kernel.tools import identify_root_kernel diff --git a/loopy/ipython_ext.py b/loopy/ipython_ext.py index e44b183ed..ec1b10f1f 100644 --- a/loopy/ipython_ext.py +++ b/loopy/ipython_ext.py @@ -9,7 +9,7 @@ import loopy as lp class LoopyMagics(Magics): @cell_magic def fortran_kernel(self, line, cell): - result = lp.parse_fortran(cell, return_list_of_knls=True) + result = lp.parse_fortran(cell) for knl in result: self.shell.user_ns[knl.name] = knl diff --git a/test/test_fortran.py b/test/test_fortran.py index 1ab28409b..856d85c49 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -534,10 +534,9 @@ def test_parse_and_fuse_two_kernels(): !$loopy begin ! ! # FIXME: correct this after the "Module" is done. - ! # prg = lp.parse_fortran(SOURCE) - ! # fill = prg["fill"] - ! # twice = prg["twice"] - ! fill, twice = lp.parse_fortran(SOURCE, return_list_of_knls=True) + ! prg = lp.parse_fortran(SOURCE) + ! fill = prg["fill"] + ! twice = prg["twice"] ! knl = lp.fuse_kernels((fill, twice)) ! print(knl) ! RESULT = knl diff --git a/test/test_numa_diff.py b/test/test_numa_diff.py index 55a2d2e11..de0bcf70a 100644 --- a/test/test_numa_diff.py +++ b/test/test_numa_diff.py @@ -61,7 +61,7 @@ def test_gnuma_horiz_kernel(ctx_factory, ilp_multiple, Nq, opt_level): # noqa hsv_r, hsv_s = [ knl for knl in lp.parse_fortran(source, filename, - seq_dependencies=False, return_list_of_knls=True) + seq_dependencies=False) if "KernelR" in knl.name or "KernelS" in knl.name ] hsv_r = lp.tag_instructions(hsv_r, "rknl") -- GitLab From 435155d5b0a1134adc0cd93f678489a506bcd6c6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 1 Oct 2019 22:35:15 -0500 Subject: [PATCH 15/26] deprecates is_output_only --- loopy/kernel/data.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index f0d7b3789..c1acd5069 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -371,8 +371,16 @@ class ArrayArg(ArrayBase, KernelArgument): def __init__(self, *args, **kwargs): if "address_space" not in kwargs: raise TypeError("'address_space' must be specified") - kwargs["is_output"] = kwargs.pop("is_output", None) - kwargs["is_input"] = kwargs.pop("is_input", None) + + is_output_only = kwargs.pop("is_output_only", None) + if is_output_only is not None: + warn("'is_output_only' is deprecated. Use 'is_output', 'is_input'" + " instead.", DeprecationWarning, stacklevel=2) + kwargs["is_output"] = is_output_only + kwargs["is_input"] = not is_output_only + else: + kwargs["is_output"] = kwargs.pop("is_output", None) + kwargs["is_input"] = kwargs.pop("is_input", None) super(ArrayArg, self).__init__(*args, **kwargs) -- GitLab From 63979735f675e2d76033cb1e6177ee9d0187cd87 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 11:30:04 -0500 Subject: [PATCH 16/26] handles minor docs issues --- loopy/kernel/data.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index c1acd5069..0d74b7248 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -355,12 +355,14 @@ class ArrayArg(ArrayBase, KernelArgument): .. attribute:: is_output An instance of :class:`bool`. If set to *True*, the argument is used - to return information to the caller + to return information to the caller. If set to *False*, then the + callee should not write the array during execution. .. attribute:: is_input An instance of :class:`bool`. If set to *True*, expected to be - provided by the caller. + provided by the caller. If *False* then the callee should not depend + on the state of the array on entry to a function. """) allowed_extra_kwargs = [ -- GitLab From 6f177eb923b01e7e1e3c789f83fe2ce347387e9b Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 12:30:51 -0500 Subject: [PATCH 17/26] minor rewording in comments/error strings --- loopy/transform/callable.py | 36 +++++++++++------------------------- test/test_fortran.py | 2 +- 2 files changed, 12 insertions(+), 26 deletions(-) diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index a87a43f4e..2cde66767 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -154,19 +154,6 @@ class _RegisterCalleeKernel(ImmutableRecord): return None -def subarrayrefs_are_equiv(sar1, sar2, knl): - """ - Compares if two instance of :class:`loopy.symbolic.SubArrayRef`s point - to the same array region. - """ - from loopy.kernel.function_interface import get_arg_descriptor_for_expression - - return get_arg_descriptor_for_expression(knl, sar1) == ( - get_arg_descriptor_for_expression(knl, sar2)) and ( - sar1.get_begin_subscript(knl) == - sar2.get_begin_subscript(knl)) - - def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): from loopy.kernel.function_interface import get_kw_pos_association kw_to_pos, pos_to_kw = get_kw_pos_association(callee_kernel) @@ -178,8 +165,8 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): for i, param in enumerate(expr.parameters): pos = kw_to_pos[callee_kernel.args[i].name] if pos < 0: - raise LoopyError("#{} argument meant for output obtained as an" - " input in '{}'.".format(i, insn)) + raise LoopyError("#{}(1-based) argument meant for output obtained as an" + " input in '{}'.".format(i+1, insn)) assert pos == i @@ -188,7 +175,7 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): for kw, param in six.iteritems(expr.kw_parameters): pos = kw_to_pos[kw] if pos < 0: - raise LoopyError("KW-argument '{}' meant for output obtained as an" + raise LoopyError("Keyword argument '{}' meant for output obtained as an" " input in '{}'.".format(kw, insn)) callee_args_to_insn_params[pos].append(param) @@ -203,8 +190,6 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): callee_args_to_insn_params[pos].append(assignee) - # TODO: Some of the checks might be redundant. - for arg, insn_params in zip(callee_kernel.args, callee_args_to_insn_params): if len(insn_params) == 1: @@ -218,14 +203,15 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): raise LoopyError("Found multiple parameters mapping to an" " argument which is not both input and output in" " ''.".format()) - if not subarrayrefs_are_equiv(insn_params[0], insn_params[1], - caller_knl): - raise LoopyError("'{}' and '{}' point to the same argument in" - " the callee, but are unequal.".format( - insn_params[0], insn_params[1])) + if insn_params[0] != insn_params[1]: + raise LoopyError("Unequal SubArrayRefs '{}', '{}' passed as '{}'" + " to '{}'.".format(insn_params[0], insn_params[1], + arg.name, callee_kernel.name)) else: - raise LoopyError("Multiple(>2) arguments pointing to the same" - " argument for '{}' in '{}'.".format(callee_kernel.name, + # repitition due incorrect usage of kwargs and + # positional args + raise LoopyError("Multiple(>2) arguments obtained for" + " '{}' in '{}'.".format(callee_kernel.name, insn)) diff --git a/test/test_fortran.py b/test/test_fortran.py index 856d85c49..c6b7e8e37 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -533,7 +533,7 @@ def test_parse_and_fuse_two_kernels(): !$loopy begin ! - ! # FIXME: correct this after the "Module" is done. + ! # FIXME: correct this after the "TranslationUnit" is done. ! prg = lp.parse_fortran(SOURCE) ! fill = prg["fill"] ! twice = prg["twice"] -- GitLab From 2dadb47d8c45c1a068316bdcbefdedbd1ca4071d Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 12:31:42 -0500 Subject: [PATCH 18/26] cache the results of slice->SAR during the processing of an instruction --- loopy/kernel/creation.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 4be7e06b8..5582b0c63 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1888,9 +1888,18 @@ class SliceToInameReplacer(IdentityMapper): self.var_name_gen = var_name_gen self.knl = knl + # caching to map equivalent slices to equivalent SubArrayRefs + self.cache = {} + self.subarray_ref_bounds = [] + def clear_cache(self): + self.cache = {} + def map_subscript(self, expr): + if expr in self.cache: + return self.cache[expr] + subscript_iname_bounds = {} self.subarray_ref_bounds.append(subscript_iname_bounds) @@ -1919,11 +1928,15 @@ class SliceToInameReplacer(IdentityMapper): new_index.append(index) if swept_inames: - return SubArrayRef(tuple(swept_inames), Subscript( + result = SubArrayRef(tuple(swept_inames), Subscript( self.rec(expr.aggregate), self.rec(tuple(new_index)))) else: - return IdentityMapper.map_subscript(self, expr) + result = IdentityMapper.map_subscript(self, expr) + + self.cache[expr] = result + + return result def map_call(self, expr): def _convert_array_to_slices(arg): @@ -2014,6 +2027,8 @@ def realize_slices_array_inputs_as_sub_array_refs(kernel): raise NotImplementedError("Unknown type of instruction -- %s" % type(insn)) + slice_replacer.clear_cache() + return kernel.copy( domains=( kernel.domains -- GitLab From 584c4d0de273295c320694ced999f7bf01ba4301 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 13:16:16 -0500 Subject: [PATCH 19/26] minor docs fix --- loopy/kernel/tools.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index d0e4ef084..7dfe4f48b 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -1930,13 +1930,13 @@ def infer_args_are_input_output(kernel): .. note:: - If the attribute ``is_output`` of an argument is not supplied from an - user, then it is inferred as an output argument if it is written at + If the :attr:`~loopy.ArrayArg.is_output` is not supplied from a user, + then the array is inferred as an output argument if it is written at some point in the kernel. - If the attribute ``is_input`` of an argument of is not supplied from - an user, then it is inferred as an input argument if it is either read - at some point in the kernel or it is neither read nor written. + If the :attr:`~loopy.ArrayArg.is_input` is not supplied from a user, + then the array is inferred as an input argument if it is either read at + some point in the kernel or it is neither read nor written. """ from loopy.kernel.data import ArrayArg, ValueArg, ConstantArg, ImageArg new_args = [] -- GitLab From 44d4c497b3aa22f07ca004b7c97e7860297bbf6e Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 13:31:47 -0500 Subject: [PATCH 20/26] fuse_kernel should take in LoopKernels --- loopy/transform/fusion.py | 150 ++++++++++++-------------------------- 1 file changed, 45 insertions(+), 105 deletions(-) diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py index 45e9c0a06..287c810e2 100644 --- a/loopy/transform/fusion.py +++ b/loopy/transform/fusion.py @@ -32,8 +32,6 @@ from loopy.diagnostic import LoopyError from pymbolic import var from loopy.kernel import LoopKernel -from loopy.kernel.function_interface import CallableKernel -from loopy.program import rename_resolved_functions_in_a_single_kernel def _apply_renames_in_exprs(kernel, var_renames): @@ -291,7 +289,51 @@ def _fuse_two_kernels(knla, knlb): # }}} -def fuse_loop_kernels(kernels, suffixes=None, data_flow=None): +def fuse_kernels(kernels, suffixes=None, data_flow=None): + """Return a kernel that performs all the operations in all entries + of *kernels*. + + :arg kernels: A list of :class:`loopy.LoopKernel` instances to be fused. + :arg suffixes: If given, must be a list of strings of a length matching + that of *kernels*. This will be used to disambiguate the names + of temporaries, as described below. + :arg data_flow: A list of data dependencies + ``[(var_name, from_kernel, to_kernel), ...]``. + Based on this, the fuser will create dependencies between all + writers of *var_name* in ``kernels[from_kernel]`` to + readers of *var_name* in ``kernels[to_kernel]``. + *from_kernel* and *to_kernel* are indices into *kernels*. + + The components of the kernels are fused as follows: + + * The resulting kernel will have a domain involving all the inames + and parameters occurring across *kernels*. + Inames with matching names across *kernels* are fused in such a way + that they remain a single iname in the fused kernel. + Use :func:`loopy.rename_iname` if this is not desired. + + * The projection of the domains of each pair of kernels onto their + common subset of inames must match in order for fusion to + succeed. + + * Assumptions are fused by taking their conjunction. + + * If kernel arguments with matching names are encountered across + *kernels*, their declarations must match in order for fusion to + succeed. + + * Temporaries are automatically renamed to remain uniquely associated + with each instruction stream. + + * The resulting kernel will contain all instructions from each entry + of *kernels*. Clashing instruction IDs will be renamed to ensure + uniqueness. + + .. versionchanged:: 2016.2 + + *data_flow* was added in version 2016.2 + """ + assert all(isinstance(knl, LoopKernel) for knl in kernels) kernels = list(kernels) @@ -373,106 +415,4 @@ def fuse_loop_kernels(kernels, suffixes=None, data_flow=None): return result - -def fuse_kernels(programs, suffixes=None, data_flow=None): - """Return a kernel that performs all the operations in all entries - of *kernels*. - - :arg kernels: A list of :class:`loopy.LoopKernel` instances to be fused. - :arg suffixes: If given, must be a list of strings of a length matching - that of *kernels*. This will be used to disambiguate the names - of temporaries, as described below. - :arg data_flow: A list of data dependencies - ``[(var_name, from_kernel, to_kernel), ...]``. - Based on this, the fuser will create dependencies between all - writers of *var_name* in ``kernels[from_kernel]`` to - readers of *var_name* in ``kernels[to_kernel]``. - *from_kernel* and *to_kernel* are indices into *kernels*. - - The components of the kernels are fused as follows: - - * The resulting kernel will have a domain involving all the inames - and parameters occurring across *kernels*. - Inames with matching names across *kernels* are fused in such a way - that they remain a single iname in the fused kernel. - Use :func:`loopy.rename_iname` if this is not desired. - - * The projection of the domains of each pair of kernels onto their - common subset of inames must match in order for fusion to - succeed. - - * Assumptions are fused by taking their conjunction. - - * If kernel arguments with matching names are encountered across - *kernels*, their declarations must match in order for fusion to - succeed. - - * Temporaries are automatically renamed to remain uniquely associated - with each instruction stream. - - * The resulting kernel will contain all instructions from each entry - of *kernels*. Clashing instruction IDs will be renamed to ensure - uniqueness. - - .. versionchanged:: 2016.2 - - *data_flow* was added in version 2016.2 - """ - - from loopy.program import make_program - - programs = [make_program(knl) if isinstance(knl, LoopKernel) else knl for - knl in programs] - - # all the resolved functions in programs must be registered in - # main_callables_table - main_prog_callables_info = ( - programs[0].callables_table) - old_root_kernel_callable = ( - programs[0].callables_table[programs[0].name]) - kernels = [programs[0].root_kernel] - - # removing the callable collisions that maybe present - for prog in programs[1:]: - root_kernel = prog.root_kernel - renames_needed = {} - for old_func_id, in_knl_callable in prog.callables_table.items(): - if isinstance(in_knl_callable, CallableKernel): - # Fusing programs with multiple callable kernels is tough. - # Reason: Need to first figure out the order in which the - # callable kernels must be resolved into - # main_callables_table, because of renaming is - # needed to be done in the callable kernels before registering. - # Hence disabling it until required. - if in_knl_callable.subkernel.name != prog.name: - raise LoopyError("fuse_kernels cannot fuse programs with " - "multiple callable kernels.") - - # root kernel are dealt at the end after performing all the - # renaming. - continue - main_prog_callables_info, new_func_id = ( - main_prog_callables_info.with_added_callable(var(old_func_id), - in_knl_callable)) - - if old_func_id != new_func_id: - renames_needed[old_func_id] = new_func_id - - if renames_needed: - root_kernel = rename_resolved_functions_in_a_single_kernel( - root_kernel, renames_needed) - - kernels.append(root_kernel) - - new_root_kernel = fuse_loop_kernels(kernels, suffixes, data_flow) - new_root_kernel_callable = old_root_kernel_callable.copy( - subkernel=new_root_kernel.copy(name=programs[0].name)) - - # TODO: change the name of the final root kernel. - main_prog_callables_info, _ = main_prog_callables_info.with_added_callable( - var(programs[0].name), new_root_kernel_callable) - - return programs[0].copy( - callables_table=main_prog_callables_info) - # vim: foldmethod=marker -- GitLab From 71b05d5be15c38b4534dfdc92d056ebb6bfbf44a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 14:36:51 -0500 Subject: [PATCH 21/26] way better docs for _check_correctness_of_args_and_assignees --- loopy/transform/callable.py | 72 +++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 2cde66767..2fb9168ec 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -154,14 +154,29 @@ class _RegisterCalleeKernel(ImmutableRecord): return None -def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): +def _check_correctness_of_args_and_assignees(insn, callee_kernel): + """ + Checks that -- + 1. the call in *insn* agrees the :attr:`~loopy.ArrayArg.is_input` and + :attr:`~loopy.ArrayArg.is_output` for the corresponding arguments in + *callee_kernel*, + 2. the call does not get multiple values for a keyword argument, + 3. only the arguments that are both output and input appear in the + assignees as well as parameters in *insn*'s call. + """ from loopy.kernel.function_interface import get_kw_pos_association kw_to_pos, pos_to_kw = get_kw_pos_association(callee_kernel) + + # mapping from argument index in callee to the assignees/paramters mapping + # to it callee_args_to_insn_params = [[] for _ in callee_kernel.args] expr = insn.expression - from pymbolic.primitives import Call, CallWithKwargs + from pymbolic.primitives import Call if isinstance(expr, Call): expr = CallWithKwargs(expr.function, expr.parameters, kw_parameters={}) + + # {{{ check that call parameters are input arguments in callee + for i, param in enumerate(expr.parameters): pos = kw_to_pos[callee_kernel.args[i].name] if pos < 0: @@ -179,6 +194,20 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): " input in '{}'.".format(kw, insn)) callee_args_to_insn_params[pos].append(param) + # }}} + + # {{{ check that positional and Keyword arguments and positional do not map + # to the same callee arg + + if any(len(pars) >= 2 for pars in callee_args_to_insn_params): + raise LoopyError("{}() got multiple values for keyword argument" + " '{}'".format(callee_kernel.name, callee_kernel.args[i].name)) + + # }}} + + # {{{ check that only the args which are both input and output appear both + # in assignees and parameters + num_pure_assignees = 0 for i, assignee in enumerate(insn.assignees): pos = kw_to_pos[pos_to_kw[-i-1]] @@ -195,7 +224,7 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): if len(insn_params) == 1: # making sure that the argument is either only input or output if arg.is_input == arg.is_output: - raise LoopyError("Argument '{}' in '{}' should be passed in" + raise LoopyError("Parameter '{}' in '{}' should be passed in" " both assignees and parameters in Call.".format( insn_params[0], insn)) elif len(insn_params) == 2: @@ -208,11 +237,10 @@ def _check_correctness_of_args_and_assignees(insn, callee_kernel, caller_knl): " to '{}'.".format(insn_params[0], insn_params[1], arg.name, callee_kernel.name)) else: - # repitition due incorrect usage of kwargs and - # positional args - raise LoopyError("Multiple(>2) arguments obtained for" - " '{}' in '{}'.".format(callee_kernel.name, - insn)) + # should not reach here + assert False + + # }}} def register_callable_kernel(program, callee_kernel): @@ -230,37 +258,13 @@ def register_callable_kernel(program, callee_kernel): assert isinstance(callee_kernel, LoopKernel), ('{0} !=' '{1}'.format(type(callee_kernel), LoopKernel)) - # check to make sure that the variables with 'out' direction is equal to - # the number of assigness in the callee kernel intructions. - expected_num_assignees = sum(arg.is_output for arg in callee_kernel.args) - expected_num_arguments = sum(arg.is_input for arg in callee_kernel.args) for in_knl_callable in program.callables_table.values(): if isinstance(in_knl_callable, CallableKernel): caller_kernel = in_knl_callable.subkernel for insn in caller_kernel.instructions: if isinstance(insn, CallInstruction) and ( insn.expression.function.name == callee_kernel.name): - if isinstance(insn.expression, CallWithKwargs): - kw_parameters = insn.expression.kw_parameters - else: - kw_parameters = {} - if len(insn.assignees) != expected_num_assignees: - raise LoopyError("The number of arguments with 'out' " - "direction " "in callee kernel %s and the number " - "of assignees in " "instruction %s do not " - "match." % ( - callee_kernel.name, insn.id)) - if len(insn.expression.parameters+tuple( - kw_parameters.values())) != expected_num_arguments: - raise LoopyError("The number of" - " arguments in instruction '%s' do not match" - " the number of input arguments in" - " the callee kernel '%s' => arg matching" - " not possible." - % (insn.id, callee_kernel.name)) - - _check_correctness_of_args_and_assignees(insn, - callee_kernel, caller_kernel) + _check_correctness_of_args_and_assignees(insn, callee_kernel) elif isinstance(insn, (MultiAssignmentBase, CInstruction, _DataObliviousInstruction)): @@ -439,8 +443,6 @@ def _inline_call_instruction(caller_kernel, callee_knl, instruction): parameters = instruction.expression.parameters # reads # add keyword parameters - from pymbolic.primitives import CallWithKwargs - if isinstance(instruction.expression, CallWithKwargs): from loopy.kernel.function_interface import get_kw_pos_association -- GitLab From 1363a694946cad14db8d085eb3bb5bb709fa4bec Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 14:45:19 -0500 Subject: [PATCH 22/26] SubArrayRef.begin_subscript -> get_start_subscript_from_sar --- loopy/symbolic.py | 41 ++++++++++++++-------------- loopy/target/c/codegen/expression.py | 3 +- loopy/transform/callable.py | 3 +- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 53d8d4431..6a664f60e 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -809,6 +809,27 @@ class SweptInameStrideCollector(CoefficientCollectorBase): return super(SweptInameStrideCollector, self).map_algebraic_leaf(expr) +def get_start_subscript_from_sar(sar, kernel): + """ + Returns an instance of :class:`pymbolic.primitives.Subscript`, the + beginning subscript of the array swept by the *SubArrayRef*. + + **Example:** Consider ``[i, k]: a[i, j, k, l]``. The beginning + subscript would be ``a[0, j, 0, l]`` + """ + + def _get_lower_bound(iname): + pwaff = kernel.get_iname_bounds(iname).lower_bound_pw_aff + return int(pw_aff_to_expr(pwaff)) + + swept_inames_to_zeros = dict( + (swept_iname.name, _get_lower_bound(swept_iname.name)) for + swept_iname in sar.swept_inames) + + return EvaluatorWithDeficientContext(swept_inames_to_zeros)( + sar.subscript) + + class SubArrayRef(LoopyExpressionBase): """ An algebraic expression to map an affine memory layout pattern (known as @@ -847,26 +868,6 @@ class SubArrayRef(LoopyExpressionBase): self.swept_inames = swept_inames self.subscript = subscript - def get_begin_subscript(self, kernel): - """ - Returns an instance of :class:`pymbolic.primitives.Subscript`, the - beginning subscript of the array swept by the *SubArrayRef*. - - **Example:** Consider ``[i, k]: a[i, j, k, l]``. The beginning - subscript would be ``a[0, j, 0, l]`` - """ - - def _get_lower_bound(iname): - pwaff = kernel.get_iname_bounds(iname).lower_bound_pw_aff - return int(pw_aff_to_expr(pwaff)) - - swept_inames_to_zeros = dict( - (swept_iname.name, _get_lower_bound(swept_iname.name)) for - swept_iname in self.swept_inames) - - return EvaluatorWithDeficientContext(swept_inames_to_zeros)( - self.subscript) - def __getinitargs__(self): return (self.swept_inames, self.subscript) diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index 5a066ddfb..b0bc187eb 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -167,7 +167,8 @@ class ExpressionToCExpressionMapper(IdentityMapper): return var(expr.name) def map_sub_array_ref(self, expr, type_context): - return var("&")(self.rec(expr.get_begin_subscript(self.kernel), + from loopy.symbolic import get_start_subscript_from_sar + return var("&")(self.rec(get_start_subscript_from_sar(expr, self.kernel), type_context)) def map_subscript(self, expr, type_context): diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 2fb9168ec..1bbdb1201 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -356,7 +356,8 @@ class KernelInliner(SubstitutionMapper): "constant shape.".format(callee_arg)) flatten_index = 0 - for i, idx in enumerate(sar.get_begin_subscript( + from loopy.symbolic import get_start_subscript_from_sar + for i, idx in enumerate(get_start_subscript_from_sar(sar, self.caller).index_tuple): flatten_index += idx*caller_arg.dim_tags[i].stride -- GitLab From 65c25393a2f8741dc39da9a7a34c85f70bd576c2 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 2 Oct 2019 14:50:37 -0500 Subject: [PATCH 23/26] better phrasing of comment --- loopy/kernel/function_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 2b50a2dc9..38beeaf44 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -231,8 +231,8 @@ def get_kw_pos_association(kernel): pos_to_kw[write_count] = arg.name write_count -= 1 if arg.is_input: - # if an argument is both input and output then the input is given - # more significance in kw_to_pos + # if an argument is both input and output then kw_to_pos is + # overwritten with its expected position in the parameters kw_to_pos[arg.name] = read_count pos_to_kw[read_count] = arg.name read_count += 1 -- GitLab From 9c660247cf42d095ebc02994d9661e797d617cd9 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 7 Oct 2019 02:34:44 -0500 Subject: [PATCH 24/26] no assumptions about is_output of args in fortran frontend --- loopy/frontend/fortran/translator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py index 949a3d4cc..caa8fa681 100644 --- a/loopy/frontend/fortran/translator.py +++ b/loopy/frontend/fortran/translator.py @@ -763,7 +763,6 @@ class F2LoopyTranslator(FTreeWalkerBase): arg_name, dtype=sub.get_type(arg_name), shape=sub.get_loopy_shape(arg_name), - is_output=False, )) else: kernel_data.append( -- GitLab From 3d5efe6c4b3fea49419ebc6396e7cd6a3d31b089 Mon Sep 17 00:00:00 2001 From: "[6~" Date: Wed, 23 Oct 2019 20:34:43 -0500 Subject: [PATCH 25/26] Program.__setstate__: reinstate _program_executor_cache --- loopy/program.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/loopy/program.py b/loopy/program.py index 1fb691531..c874d7b39 100644 --- a/loopy/program.py +++ b/loopy/program.py @@ -403,6 +403,12 @@ class Program(ImmutableRecord): strify_callable(clbl) for name, clbl in self.callables_table.items()) + + def __setstate__(self, state_obj): + super(Program, self).__setstate__(state_obj) + + self._program_executor_cache = {} + # }}} -- GitLab From 09052c072768684d0d4f870d553728f4c58db872 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 6 Apr 2020 19:40:42 -0500 Subject: [PATCH 26/26] merge leftover: handle is_input/is_output correctly --- loopy/kernel/tools.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 6120b41a1..ead996445 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -1942,9 +1942,8 @@ def infer_args_are_input_output(kernel): for arg in kernel.args: if isinstance(arg, ArrayArg): - if arg.is_output_only is not None: - assert isinstance(arg.is_output_only, bool) - new_args.append(arg) + if arg.is_output is not None: + assert isinstance(arg.is_output, bool) else: if arg.name in kernel.get_written_variables(): arg = arg.copy(is_output=True) @@ -1959,9 +1958,9 @@ def infer_args_are_input_output(kernel): arg.name not in kernel.get_written_variables())): arg = arg.copy(is_input=True) else: - new_args.append(arg.copy(is_output_only=False)) + arg = arg.copy(is_input=False) elif isinstance(arg, (ConstantArg, ImageArg, ValueArg)): - new_args.append(arg) + pass else: raise NotImplementedError("Unkonwn argument type %s." % type(arg)) -- GitLab