diff --git a/loopy/schedule/device_mapping.py b/loopy/schedule/device_mapping.py index 05aeb0436a96df9a33b08fb179dc41dec30c18c5..b1267bbe8201a5d49c8bec65f3fa07536b6b9476 100644 --- a/loopy/schedule/device_mapping.py +++ b/loopy/schedule/device_mapping.py @@ -26,19 +26,18 @@ from loopy.diagnostic import LoopyError def map_schedule_onto_host_or_device(kernel): - # Split the schedule onto host or device. - kernel = map_schedule_onto_host_or_device_impl(kernel) - if not kernel.target.split_kernel_at_global_barriers(): from loopy.schedule import CallKernel, ReturnFromKernel new_schedule = ( [CallKernel(kernel_name=kernel.name, extra_inames=[], - extra_temporaries=[])] + + extra_args=[])] + kernel.schedule + [ReturnFromKernel(kernel_name=kernel.name)]) return kernel.copy(schedule=new_schedule) + # Split the schedule onto host or device. + kernel = map_schedule_onto_host_or_device_impl(kernel) # Compute which temporaries and inames go into which kernel. kernel = restore_and_save_temporaries(kernel) return kernel @@ -120,6 +119,24 @@ def filter_temporaries(kernel, items): return result +def filter_scalars(kernel, items): + """ + Keep only the values in `items` which are scalars. + """ + from pymbolic.primitives import Subscript, Variable + result = set() + for item in items: + base = item + if isinstance(base, Subscript): + continue + if isinstance(base, Variable): + base = base.name + if base in kernel.temporary_variables and \ + len(kernel.temporary_variables[base].shape) == 0: + result.add(item) + return result + + def get_use_set(insn, include_subscripts=True): """ Return the use-set of the instruction, for liveness analysis. @@ -213,17 +230,11 @@ def compute_live_temporaries(kernel, schedule): elif isinstance(sched_item, RunInstruction): live_out[idx] = live_in[idx + 1] insn = id_to_insn[sched_item.insn_id] - # `defs` includes subscripts in liveness calculations, so that - # for code such as the following - # - # Loop i - # temp[i] := ... - # ... := f(temp[i]) - # End Loop - # - # the value temp[i] is not marked as live across the loop. - defs = filter_temporaries(kernel, get_def_set(insn)) - uses = filter_temporaries(kernel, get_use_set(insn)) + defs = filter_scalars(kernel, + filter_temporaries(kernel, + get_def_set(insn, include_subscripts=False))) + uses = filter_temporaries(kernel, + get_use_set(insn, include_subscripts=False)) live_in[idx] = (live_out[idx] - defs) | uses idx -= 1 @@ -241,6 +252,7 @@ def compute_live_temporaries(kernel, schedule): live_in = live_in[:-1] if 0: + print(kernel) print("Live-in values:") for i, li in enumerate(live_in): print("{}: {}".format(i, ", ".join(li))) @@ -271,11 +283,10 @@ def restore_and_save_temporaries(kernel): for idx, sched_item in enumerate(kernel.schedule): if isinstance(sched_item, CallKernel): inter_kernel_temporaries |= filter_out_subscripts(live_in[idx]) - call_count = 1 - elif isinstance(sched_item, ReturnFromKernel): - inter_kernel_temporaries |= filter_out_subscripts(live_out[idx]) + call_count += 1 if call_count == 1: + # XXX # Single kernel call - needs no saves / restores return kernel @@ -332,7 +343,8 @@ def restore_and_save_temporaries(kernel): assert temporary.base_storage is None, \ "Cannot promote temporaries with base_storage to global" - hw_inames = get_common_hw_inames(kernel, def_lists[temporary.name]) + hw_inames = get_common_hw_inames(kernel, + def_lists[temporary.name] + use_lists[temporary.name]) # This takes advantage of the fact that g < l in the alphabet :) hw_inames = sorted(hw_inames,