From f1012ba21a9b14d361619e7f362c9424e4333be9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 3 May 2015 22:08:22 -0500 Subject: [PATCH] Refactor code generation so that generating pure C function bodies is possible --- loopy/codegen/__init__.py | 111 ++++++++++----------------- loopy/kernel/data.py | 31 +------- loopy/target/c/__init__.py | 55 +++++++++++++ loopy/target/c/codegen/expression.py | 23 +----- loopy/target/opencl/__init__.py | 37 +++++++++ loopy/tools.py | 36 +++++++++ 6 files changed, 174 insertions(+), 119 deletions(-) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 57393bd6d..ecbb55203 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -1,6 +1,4 @@ -from __future__ import division -from __future__ import absolute_import -import six +from __future__ import division, absolute_import __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" @@ -24,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six from loopy.diagnostic import LoopyError from pytools import Record @@ -146,6 +145,26 @@ def add_comment(cmt, code): # }}} +class SeenFunction(Record): + """ + .. attribute:: name + .. attribute:: c_name + .. attribute:: arg_dtypes + + a tuple of arg dtypes + """ + + def __init__(self, name, c_name, arg_dtypes): + Record.__init__(self, + name=name, + c_name=c_name, + arg_dtypes=arg_dtypes) + + def __hash__(self): + return hash((type(self),) + + tuple((f, getattr(self, f)) for f in type(self).fields)) + + # {{{ code generation state class CodeGenerationState(object): @@ -371,30 +390,13 @@ def generate_code(kernel, device=None): from loopy.check import pre_codegen_checks pre_codegen_checks(kernel) - from cgen import (FunctionBody, FunctionDeclaration, - Value, Module, Block, - Line, Const, LiteralLines, Initializer) - logger.info("%s: generate code: start" % kernel.name) - from cgen.opencl import (CLKernel, CLRequiredWorkGroupSize) - - allow_complex = False - for var in kernel.args + list(six.itervalues(kernel.temporary_variables)): - if var.dtype.kind == "c": - allow_complex = True - - mod = [] - - seen_dtypes = set() - seen_functions = set() - - body = Block() - # {{{ examine arg list - from loopy.kernel.data import ImageArg, ValueArg + from loopy.kernel.data import ValueArg from loopy.kernel.array import ArrayBase + from cgen import Const impl_arg_info = [] @@ -417,30 +419,15 @@ def generate_code(kernel, device=None): else: raise ValueError("argument type not understood: '%s'" % type(arg)) - if any(isinstance(arg, ImageArg) for arg in kernel.args): - body.append(Initializer(Const(Value("sampler_t", "loopy_sampler")), - "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP " - "| CLK_FILTER_NEAREST")) + allow_complex = False + for var in kernel.args + list(six.itervalues(kernel.temporary_variables)): + if var.dtype.kind == "c": + allow_complex = True # }}} - mod.extend([ - LiteralLines(r""" - #define lid(N) ((%(idx_ctype)s) get_local_id(N)) - #define gid(N) ((%(idx_ctype)s) get_group_id(N)) - """ % dict(idx_ctype=kernel.target.dtype_to_typename(kernel.index_dtype))), - Line()]) - - # {{{ build lmem array declarators for temporary variables - - body.extend( - idi.cgen_declarator - for tv in six.itervalues(kernel.temporary_variables) - for idi in tv.decl_info( - kernel.target, - is_written=True, index_dtype=kernel.index_dtype)) - - # }}} + seen_dtypes = set() + seen_functions = set() initial_implemented_domain = isl.BasicSet.from_params(kernel.assumptions) codegen_state = CodeGenerationState( @@ -449,24 +436,12 @@ def generate_code(kernel, device=None): expression_to_code_mapper=kernel.target.get_expression_to_code_mapper( kernel, seen_dtypes, seen_functions, allow_complex)) - from loopy.codegen.loop import set_up_hw_parallel_loops - gen_code = set_up_hw_parallel_loops(kernel, 0, codegen_state) - - body.append(Line()) - - if isinstance(gen_code.ast, Block): - body.extend(gen_code.ast.contents) - else: - body.append(gen_code.ast) + code_str, implemented_domains = kernel.target.generate_code( + kernel, codegen_state, impl_arg_info) - mod.append( - FunctionBody( - CLRequiredWorkGroupSize( - kernel.get_grid_sizes_as_exprs()[1], - CLKernel(FunctionDeclaration( - Value("void", kernel.name), - [iai.cgen_declarator for iai in impl_arg_info]))), - body)) + from loopy.check import check_implemented_domains + assert check_implemented_domains(kernel, implemented_domains, + code_str) # {{{ handle preambles @@ -492,20 +467,18 @@ def generate_code(kernel, device=None): seen_preamble_tags.add(tag) dedup_preambles.append(preamble) - mod = ([LiteralLines(lines) for lines in dedup_preambles] - + [Line()] + mod) - - # }}} + from loopy.tools import remove_common_indentation + preamble_codes = [ + remove_common_indentation(lines) + "\n" + for lines in dedup_preambles] - result = str(Module(mod)) + code_str = "".join(preamble_codes) + code_str - from loopy.check import check_implemented_domains - assert check_implemented_domains(kernel, gen_code.implemented_domains, - result) + # }}} logger.info("%s: generate code: done" % kernel.name) - result = result, impl_arg_info + result = code_str, impl_arg_info if CACHING_ENABLED: code_gen_cache[input_kernel] = result diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index ab747b9a6..32c062689 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -747,34 +747,6 @@ class ExpressionInstruction(InstructionBase): # }}} -def _remove_common_indentation(code): - if "\n" not in code: - return code - - # accommodate pyopencl-ish syntax highlighting - code = code.lstrip("//CL//") - - if not code.startswith("\n"): - return code - - lines = code.split("\n") - while lines[0].strip() == "": - lines.pop(0) - while lines[-1].strip() == "": - lines.pop(-1) - - if lines: - base_indent = 0 - while lines[0][base_indent] in " \t": - base_indent += 1 - - for line in lines[1:]: - if line[:base_indent].strip(): - raise ValueError("inconsistent indentation") - - return "\n".join(line[base_indent:] for line in lines) - - # {{{ c instruction class CInstruction(InstructionBase): @@ -873,7 +845,8 @@ class CInstruction(InstructionBase): # }}} self.iname_exprs = new_iname_exprs - self.code = _remove_common_indentation(code) + from loopy.tools import remove_common_indentation + self.code = remove_common_indentation(code) self.read_variables = read_variables self.assignees = new_assignees diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index bb1427739..f74485e05 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -24,6 +24,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six + import numpy as np # noqa from loopy.target import TargetBase @@ -56,3 +58,56 @@ class CTarget(TargetBase): from loopy.target.c.codegen.expression import LoopyCCodeMapper return (LoopyCCodeMapper(kernel, seen_dtypes, seen_functions, allow_complex=allow_complex)) + + # {{{ code generation + + def generate_code(self, kernel, codegen_state, impl_arg_info): + from cgen import FunctionBody, FunctionDeclaration, Value, Module + + body, implemented_domains = kernel.target.generate_body( + kernel, codegen_state) + + mod = Module([ + FunctionBody( + kernel.target.wrap_function_declaration( + kernel, + FunctionDeclaration( + Value("void", kernel.name), + [iai.cgen_declarator for iai in impl_arg_info])), + body) + ]) + + return str(mod), implemented_domains + + def wrap_function_declaration(self, kernel, fdecl): + return fdecl + + def generate_body(self, kernel, codegen_state): + from cgen import Block + body = Block() + + # {{{ declare temporaries + + body.extend( + idi.cgen_declarator + for tv in six.itervalues(kernel.temporary_variables) + for idi in tv.decl_info( + kernel.target, + is_written=True, index_dtype=kernel.index_dtype)) + + # }}} + + from loopy.codegen.loop import set_up_hw_parallel_loops + gen_code = set_up_hw_parallel_loops(kernel, 0, codegen_state) + + from cgen import Line + body.append(Line()) + + if isinstance(gen_code.ast, Block): + body.extend(gen_code.ast.contents) + else: + body.append(gen_code.ast) + + return body, gen_code.implemented_domains + + # }}} diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index ccfa88040..1dc3aafcf 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -32,7 +32,6 @@ from pymbolic.mapper import RecursiveMapper from pymbolic.mapper.stringifier import (PREC_NONE, PREC_CALL, PREC_PRODUCT, PREC_POWER) import islpy as isl -from pytools import Record from loopy.expression import dtype_to_type_context, TypeInferenceMapper @@ -47,26 +46,6 @@ def get_opencl_vec_member(idx): return "s%s" % hex(int(idx))[2:] -class SeenFunction(Record): - """ - .. attribute:: name - .. attribute:: c_name - .. attribute:: arg_dtypes - - a tuple of arg dtypes - """ - - def __init__(self, name, c_name, arg_dtypes): - Record.__init__(self, - name=name, - c_name=c_name, - arg_dtypes=arg_dtypes) - - def __hash__(self): - return hash((type(self),) - + tuple((f, getattr(self, f)) for f in type(self).fields)) - - # {{{ C code mapper class LoopyCCodeMapper(RecursiveMapper): @@ -330,6 +309,7 @@ class LoopyCCodeMapper(RecursiveMapper): def seen_func(name): idt = self.kernel.index_dtype + from loopy.codegen import SeenFunction self.seen_functions.add(SeenFunction(name, name, (idt, idt))) if den_nonneg: @@ -451,6 +431,7 @@ class LoopyCCodeMapper(RecursiveMapper): "for function '%s' not understood" % identifier) + from loopy.codegen import SeenFunction self.seen_functions.add(SeenFunction(identifier, c_name, par_dtypes)) if str_parameters is None: # /!\ FIXME For some functions (e.g. 'sin'), it makes sense to diff --git a/loopy/target/opencl/__init__.py b/loopy/target/opencl/__init__.py index b297eb9ff..e68040c09 100644 --- a/loopy/target/opencl/__init__.py +++ b/loopy/target/opencl/__init__.py @@ -237,6 +237,43 @@ class OpenCLTarget(CTarget): def get_vector_dtype(self, base, count): return vec.types[base, count] + def wrap_function_declaration(self, kernel, fdecl): + from cgen.opencl import CLKernel, CLRequiredWorkGroupSize + return CLRequiredWorkGroupSize( + kernel.get_grid_sizes_as_exprs()[1], + CLKernel(fdecl)) + + def generate_code(self, kernel, codegen_state, impl_arg_info): + code, implemented_domains = ( + super(OpenCLTarget, self).generate_code( + kernel, codegen_state, impl_arg_info)) + + from loopy.tools import remove_common_indentation + code = ( + remove_common_indentation(""" + #define lid(N) ((%(idx_ctype)s) get_local_id(N)) + #define gid(N) ((%(idx_ctype)s) get_group_id(N)) + """ % dict(idx_ctype=self.dtype_to_typename(kernel.index_dtype))) + + "\n\n" + + code) + + return code, implemented_domains + + def generate_body(self, kernel, codegen_state): + body, implemented_domains = ( + super(OpenCLTarget, self).generate_body(kernel, codegen_state)) + + from loopy.kernel.data import ImageArg + + if any(isinstance(arg, ImageArg) for arg in kernel.args): + from cgen import Value, Const, Initializer + body.contents.insert(0, + Initializer(Const(Value("sampler_t", "loopy_sampler")), + "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP " + "| CLK_FILTER_NEAREST")) + + return body, implemented_domains + # }}} # vim: foldmethod=marker diff --git a/loopy/tools.py b/loopy/tools.py index 0c8898bb1..1f4716067 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -98,6 +98,8 @@ class LoopyKeyBuilder(KeyBuilderBase): # }}} +# {{{ picklable dtype + class PicklableDtype(object): """This object works around several issues with pickling :class:`numpy.dtype` objects. It does so by serving as a picklable wrapper around the original @@ -157,4 +159,38 @@ class PicklableDtype(object): def assert_has_target(self): assert self.target is not None +# }}} + + +# {{{ remove common indentation + +def remove_common_indentation(code): + if "\n" not in code: + return code + + # accommodate pyopencl-ish syntax highlighting + code = code.lstrip("//CL//") + + if not code.startswith("\n"): + return code + + lines = code.split("\n") + while lines[0].strip() == "": + lines.pop(0) + while lines[-1].strip() == "": + lines.pop(-1) + + if lines: + base_indent = 0 + while lines[0][base_indent] in " \t": + base_indent += 1 + + for line in lines[1:]: + if line[:base_indent].strip(): + raise ValueError("inconsistent indentation") + + return "\n".join(line[base_indent:] for line in lines) + +# }}} + # vim: foldmethod=marker -- GitLab