diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 5a736fd6b3b84993895a7dca5dd738fc91e5646d..5ad6903ff5670c24dc740aed6640632e8f84623f 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -45,8 +45,6 @@ from pyopencl.tools import context_dependent_memoize # {{{ preamble SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL// -#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable - #define WG_SIZE ${wg_size} /* SCAN_EXPR has no right know the indices it is scanning at because @@ -314,7 +312,7 @@ void ${name_prefix}_scan_intervals( } else %endif - sum = SCAN_EXPR(sum, tmp); + sum = SCAN_EXPR(tmp, sum); ldata[k][LID_0] = sum; } @@ -638,6 +636,16 @@ void ${name_prefix}_final_update( # {{{ helpers +def _process_code_for_macro(code): + if code.startswith("//CL//"): + code = code[6:] + + if "//" in code: + raise RuntimeError("end-of-line comments ('//') may not be used in " + "scan code snippets") + + return code.replace("\n", " \\\n") + def _round_down_to_power_of_2(val): result = 2**bitlog2(val) if result > val: @@ -697,7 +705,6 @@ _IGNORED_WORDS = set(""" pragma __attribute__ __global __kernel __local get_local_size get_local_id cl_khr_fp64 reqd_work_group_size get_num_groups barrier get_group_id - cl_khr_byte_addressable_store _final_update _scan_intervals @@ -790,7 +797,10 @@ class GenericScanKernel(object): (see :func:`pyopencl.tools.register_dtype`). :arg scan_expr: The associative, binary operation carrying out the scan, represented as a C string. Its two arguments are available as `a` - and `b` when it is evaluated. + and `b` when it is evaluated. `b` is guaranteed to be the + 'element being updated', and `a` is the increment. Thus, + if some data is supposed to just propagate along without being + modified by the scan, it should live in `b`. This expression may call functions given in the *preamble*. :arg input_expr: A C expression, encoded as a string, resulting @@ -869,7 +879,7 @@ class GenericScanKernel(object): self.is_segmented = is_segment_start_expr is not None if self.is_segmented: - is_segment_start_expr = is_segment_start_expr.replace("\n", " ") + is_segment_start_expr = _process_code_for_macro(is_segment_start_expr) use_lookbehind_update = "prev_item" in output_statement self.store_segment_start_flags = self.is_segmented and use_lookbehind_update @@ -895,8 +905,8 @@ class GenericScanKernel(object): scan_ctype=dtype_to_ctype(dtype), is_segmented=self.is_segmented, arg_ctypes=arg_ctypes, - scan_expr=scan_expr.replace("\n", " "), - neutral=neutral.replace("\n", " "), + scan_expr=_process_code_for_macro(scan_expr), + neutral=_process_code_for_macro(neutral), double_support=all( has_double_support(dev) for dev in devices), ) @@ -983,10 +993,10 @@ class GenericScanKernel(object): final_update_tpl = _make_template(UPDATE_SOURCE) final_update_src = str(final_update_tpl.render( wg_size=self.update_wg_size, - output_statement=output_statement.replace("\n", " "), - argument_signature=arguments.replace("\n", " "), + output_statement=_process_code_for_macro(output_statement), + argument_signature=_process_code_for_macro(arguments), is_segment_start_expr=is_segment_start_expr, - input_expr=input_expr.replace("\n", " "), + input_expr=_process_code_for_macro(input_expr), use_lookbehind_update=use_lookbehind_update, **self.code_variables)) @@ -1034,7 +1044,7 @@ class GenericScanKernel(object): wg_size=wg_size, input_expr=input_expr, k_group_size=k_group_size, - argument_signature=arguments.replace("\n", " "), + argument_signature=_process_code_for_macro(arguments), is_segment_start_expr=is_segment_start_expr, input_fetch_exprs=input_fetch_exprs, is_first_level=is_first_level, @@ -1362,7 +1372,8 @@ def _get_unique_kernel(ctx, dtype, is_equal_expr, scan_dtype, "%s %s" % (dtype_to_ctype(arg_dtype), name) for name, arg_dtype in extra_args_types] - key_expr_define = "#define IS_EQUAL_EXPR(a, b) %s\n" % is_equal_expr.replace("\n", " ") + key_expr_define = "#define IS_EQUAL_EXPR(a, b) %s\n" \ + % _process_code_for_macro(is_equal_expr) return GenericScanKernel( ctx, dtype, arguments=", ".join(arguments),