diff --git a/.gitignore b/.gitignore index 74bc5dc7d2599153fbf3de05bc1cd5f5d1934b8f..5cdaf06d4400b9ad7e69bf20e721a5e0dd38c834 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .settings *~ .*.sw[po] +.sw[po] *.dat *.pyc build diff --git a/contrib/fortran-to-opencl/translate.py b/contrib/fortran-to-opencl/translate.py index d61a3775e317de2a4730dad154fc032d03952f52..d3c35a54cbbb77ec66fca6e189e90c8c6dcd89f2 100644 --- a/contrib/fortran-to-opencl/translate.py +++ b/contrib/fortran-to-opencl/translate.py @@ -4,6 +4,7 @@ import numpy as np import re from pymbolic.parser import Parser as ExpressionParserBase from pymbolic.mapper import CombineMapper +import pymbolic.primitives from pymbolic.mapper.c_code import CCodeMapper as CCodeMapperBase from warnings import warn @@ -80,6 +81,16 @@ _not = intern("not") _and = intern("and") _or = intern("or") +class TypedLiteral(pymbolic.primitives.Leaf): + def __init__(self, value, dtype): + self.value = value + self.dtype = np.dtype(dtype) + + def __getinitargs__(self): + return self.value, self.dtype + + mapper_method = intern("map_literal") + class FortranExpressionParser(ExpressionParserBase): # FIXME double/single prec literals @@ -99,15 +110,26 @@ class FortranExpressionParser(ExpressionParserBase): def __init__(self, tree_walker): self.tree_walker = tree_walker + _PREC_FUNC_ARGS = 1 + def parse_terminal(self, pstate): scope = self.tree_walker.scope_stack[-1] from pymbolic.primitives import Subscript, Call, Variable from pymbolic.parser import ( - _identifier, _openpar, _closepar) + _identifier, _openpar, _closepar, _float) next_tag = pstate.next_tag() - if next_tag is _identifier: + if next_tag is _float: + value = pstate.next_str_and_advance().lower() + if "d" in value: + dtype = np.float64 + else: + dtype = np.float32 + + return TypedLiteral(value.replace("d", "e"), dtype) + + elif next_tag is _identifier: name = pstate.next_str_and_advance() if pstate.is_at_end() or pstate.next_tag() is not _openpar: @@ -130,7 +152,7 @@ class FortranExpressionParser(ExpressionParserBase): pstate.advance() left_exp = cls(left_exp, ()) else: - args = self.parse_expression(pstate) + args = self.parse_expression(pstate, self._PREC_FUNC_ARGS) if not isinstance(args, tuple): args = (args,) left_exp = cls(left_exp, args) @@ -195,11 +217,18 @@ class FortranExpressionParser(ExpressionParserBase): left_exp, did_something = ExpressionParserBase.parse_postfix( self, pstate, min_precedence, left_exp) - if isinstance(left_exp, tuple): + if isinstance(left_exp, tuple) and min_precedence < self._PREC_FUNC_ARGS: # this must be a complex literal assert len(left_exp) == 2 r, i = left_exp - left_exp = float(r) + 1j*float(i) + + dtype = (r.dtype.type(0) + i.dtype.type(0)) + if dtype == np.float32: + dtype = np.complex64 + else: + dtype = np.complex128 + + left_exp = TypedLiteral(left_exp, dtype) return left_exp, did_something @@ -214,6 +243,9 @@ class TypeInferenceMapper(CombineMapper): def combine(self, dtypes): return sum(dtype.type(1) for dtype in dtypes).dtype + def map_literal(self, expr): + return expr.dtype + def map_constant(self, expr): return np.array(expr).dtype @@ -357,8 +389,9 @@ class CCodeMapper(ComplexCCodeMapper): # Stuff that deals with generating real-valued code # from complex code goes above. - def __init__(self, scope): + def __init__(self, translator, scope): ComplexCCodeMapper.__init__(self, scope.get_type_inference_mapper()) + self.translator = translator self.scope = scope def map_subscript(self, expr, enclosing_prec): @@ -416,11 +449,27 @@ class CCodeMapper(ComplexCCodeMapper): name = expr.name shape = self.scope.get_shape(name) name = self.scope.translate_var_name(name) - if expr.name in self.scope.arg_names or shape not in [(), None]: + if expr.name in self.scope.arg_names: + arg_idx = self.scope.arg_names.index(name) + if self.translator.arg_needs_pointer( + self.scope.subprogram_name, arg_idx): + return "*"+name + else: + return name + elif shape not in [(), None]: return "*"+name else: return name + def map_literal(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_NONE + if expr.dtype.kind == "c": + r, i = expr.value + return "{ %s, %s }" % (self.rec(r, PREC_NONE), self.rec(i, PREC_NONE)) + else: + return expr.value + + # }}} class Scope(object): @@ -492,11 +541,9 @@ class Scope(object): -class TreeWalker(object): - def __init__(self, addr_space_hints, force_casts): +class FTreeWalkerBase(object): + def __init__(self): self.scope_stack = [] - self.addr_space_hints = addr_space_hints - self.force_casts = force_casts self.expr_parser = FortranExpressionParser(self) @@ -519,20 +566,273 @@ class TreeWalker(object): % (type(self).__name__, type(expr))) + ENTITY_RE = re.compile( + r"^(?P[_0-9a-zA-Z]+)" + "(\((?P[-+*0-9:a-zA-Z,]+)\))?$") + + def parse_dimension_specs(self, dim_decls): + def parse_bounds(bounds_str): + start_end = bounds_str.split(":") + + assert 1 <= len(start_end) <= 2 + + return (self.parse_expr(s) for s in start_end) + + for decl in dim_decls: + entity_match = self.ENTITY_RE.match(decl) + assert entity_match + + groups = entity_match.groupdict() + name = groups["name"] + assert name + + if groups["shape"]: + shape = [parse_bounds(s) for s in groups["shape"].split(",")] + else: + shape = None + + yield name, shape + def __call__(self, expr, *args, **kwargs): return self.rec(expr, *args, **kwargs) + # {{{ expressions + + def parse_expr(self, expr_str): + return self.expr_parser(expr_str) + + # }}} + + + +class ArgumentAnalayzer(FTreeWalkerBase): + def __init__(self): + FTreeWalkerBase.__init__(self) + + # map (func, arg_nr) to + # 'w' for 'needs pointer' + # [] for no obstacle to de-pointerification known + # [(func_name, arg_nr), ...] # depends on how this arg is used + + self.arg_usage_info = {} + + def arg_needs_pointer(self, func, arg_nr): + data = self.arg_usage_info.get((func, arg_nr), []) + + if isinstance(data, list): + return any( + self.arg_needs_pointer(sub_func, sub_arg_nr) + for sub_func, sub_arg_nr in data) + + return True + + # {{{ map_XXX functions + + def map_BeginSource(self, node): + scope = Scope(None) + self.scope_stack.append(scope) + + for c in node.content: + self.rec(c) + + def map_Subroutine(self, node): + scope = Scope(node.name, list(node.args)) + self.scope_stack.append(scope) + + for c in node.content: + self.rec(c) + + self.scope_stack.pop() + + def map_EndSubroutine(self, node): + pass + + def map_Implicit(self, node): + pass + + # {{{ types, declarations + + def map_Equivalence(self, node): + raise NotImplementedError("equivalence") + + def map_Dimension(self, node): + scope = self.scope_stack[-1] + + for name, shape in self.parse_dimension_specs(node.items): + if name in scope.arg_names: + arg_idx = scope.arg_names.index(name) + self.arg_usage_info[scope.subprogram_name, arg_idx] = "w" + + def map_External(self, node): + pass + + def map_type_decl(self, node): + scope = self.scope_stack[-1] + + for name, shape in self.parse_dimension_specs(node.entity_decls): + if shape is not None and name in scope.arg_names: + arg_idx = scope.arg_names.index(name) + self.arg_usage_info[scope.subprogram_name, arg_idx] = "w" + + map_Logical = map_type_decl + map_Integer = map_type_decl + map_Real = map_type_decl + map_Complex = map_type_decl + + # }}} + + def map_Data(self, node): + pass + + def map_Parameter(self, node): + raise NotImplementedError("parameter") + + # {{{ I/O + + def map_Open(self, node): + pass + + def map_Format(self, node): + pass + + def map_Write(self, node): + pass + + def map_Print(self, node): + pass + + def map_Read1(self, node): + pass + + # }}} + + def map_Assignment(self, node): + scope = self.scope_stack[-1] + + lhs = self.parse_expr(node.variable) + from pymbolic.primitives import Subscript, Call + if isinstance(lhs, Subscript): + lhs_name = lhs.aggregate.name + elif isinstance(lhs, Call): + # in absence of dim info, subscripts get parsed as calls + lhs_name = lhs.function.name + else: + lhs_name = lhs.name + + if lhs_name in scope.arg_names: + arg_idx = scope.arg_names.index(lhs_name) + self.arg_usage_info[scope.subprogram_name, arg_idx] = "w" + + def map_Allocate(self, node): + raise NotImplementedError("allocate") + + def map_Deallocate(self, node): + raise NotImplementedError("deallocate") + + def map_Save(self, node): + raise NotImplementedError("save") + + def map_Line(self, node): + raise NotImplementedError + + def map_Program(self, node): + raise NotImplementedError + + def map_Entry(self, node): + raise NotImplementedError + + # {{{ control flow + + def map_Goto(self, node): + pass + + def map_Call(self, node): + scope = self.scope_stack[-1] + + from pymbolic.primitives import Subscript, Variable + for i, arg_str in enumerate(node.items): + arg = self.parse_expr(arg_str) + if isinstance(arg, (Variable, Subscript)): + if isinstance(arg, Subscript): + arg_name = arg.aggregate.name + else: + arg_name = arg.name + + if arg_name in scope.arg_names: + arg_idx = scope.arg_names.index(arg_name) + arg_usage = self.arg_usage_info.setdefault( + (scope.subprogram_name, arg_idx), + []) + if isinstance(arg_usage, list): + arg_usage.append((node.designator, i)) + + def map_Return(self, node): + pass + + def map_ArithmeticIf(self, node): + pass + + def map_If(self, node): + for c in node.content: + self.rec(c) + + def map_IfThen(self, node): + for c in node.content: + self.rec(c) + + def map_ElseIf(self, node): + pass + + def map_Else(self, node): + pass + + def map_EndIfThen(self, node): + pass + + def map_Do(self, node): + for c in node.content: + self.rec(c) + + def map_EndDo(self, node): + pass + + def map_Continue(self, node): + pass + + def map_Stop(self, node): + pass + + def map_Comment(self, node): + pass + + # }}} + + # }}} + + + + + +# {{{ translator + +class F2CLTranslator(FTreeWalkerBase): + def __init__(self, addr_space_hints, force_casts, arg_info, + use_restrict_pointers): + FTreeWalkerBase.__init__(self) + self.addr_space_hints = addr_space_hints + self.force_casts = force_casts + self.arg_info = arg_info + self.use_restrict_pointers = use_restrict_pointers + + def arg_needs_pointer(self, subprogram_name, arg_index): + return self.arg_info.arg_needs_pointer(subprogram_name, arg_index) + # {{{ declaration helpers def get_declarator(self, name): scope = self.scope_stack[-1] return POD(scope.get_type(name), name) - def format_constant(self, c): - if isinstance(c, complex): - return "{ %r, %r }" % (c.real, c.imag) - else: - return repr(c) def get_declarations(self): scope = self.scope_stack[-1] @@ -540,6 +840,9 @@ class TreeWalker(object): result = [] pre_func_decl = [] + def gen_shape(start_end): + return ":".join(self.gen_expr(s) for s in start_end) + for name in sorted(scope.known_names()): shape = scope.dim_map.get(name) @@ -547,7 +850,7 @@ class TreeWalker(object): dim_stmt = cgen.Statement( "dimension \"fortran\" %s[%s]" % ( scope.translate_var_name(name), - ", ".join(self.gen_expr(s) for s in shape) + ", ".join(gen_shape(s) for s in shape) )) # cannot omit 'dimension' decl even for rank-1 args: @@ -563,7 +866,7 @@ class TreeWalker(object): result.append( cgen.Initializer( self.get_declarator(name), - self.format_constant(data[0]) + self.gen_expr(data[0]) )) else: from cgen.opencl import CLConstant @@ -572,7 +875,7 @@ class TreeWalker(object): CLConstant( cgen.ArrayOf(self.get_declarator( "%s_%s" % (scope.subprogram_name, name)))), - "{ %s }" % ",\n".join(self.format_constant(x) for x in data) + "{ %s }" % ",\n".join(self.gen_expr(x) for x in data) )) else: if name not in scope.arg_names: @@ -615,7 +918,7 @@ class TreeWalker(object): assert not node.prefix assert not hasattr(node, "suffix") - scope = Scope(node.name, set(node.args)) + scope = Scope(node.name, list(node.args)) self.scope_stack.append(scope) body = self.map_statement_list(node.content) @@ -626,14 +929,18 @@ class TreeWalker(object): if isinstance(body[-1], cgen.Statement) and body[-1].text == "return": body.pop() - def get_arg_decl(arg_name): + def get_arg_decl(arg_idx, arg_name): decl = self.get_declarator(arg_name) - hint = self.addr_space_hints.get((node.name, arg_name)) - if hint: - decl = hint(cgen.Pointer(decl)) - else: - decl = cgen.RestrictPointer(decl) + if self.arg_needs_pointer(node.name, arg_idx): + hint = self.addr_space_hints.get((node.name, arg_name)) + if hint: + decl = hint(cgen.Pointer(decl)) + else: + if self.use_restrict_pointers: + decl = cgen.RestrictPointer(decl) + else: + decl = cgen.Pointer(decl) return decl @@ -641,7 +948,7 @@ class TreeWalker(object): result = cgen.FunctionBody( cgen.FunctionDeclaration( cgen.Value("void", node.name), - [get_arg_decl(arg) for arg in node.args] + [get_arg_decl(i, arg) for i, arg in enumerate(node.args)] ), cgen.Block(body)) @@ -683,31 +990,11 @@ class TreeWalker(object): ("complex", "16"): np.complex128, ("complex", "32"): np.complex256, + ("integer", ""): np.int32, ("integer", "4"): np.int32, ("complex", "8"): np.int64, } - ENTITY_RE = re.compile( - r"^(?P[_0-9a-zA-Z]+)" - "(\((?P[-+*0-9:a-zA-Z,]+)\))?$") - - def parse_dimension_specs(self, dim_decls): - for decl in dim_decls: - entity_match = self.ENTITY_RE.match(decl) - assert entity_match - - groups = entity_match.groupdict() - name = groups["name"] - assert name - - if groups["shape"]: - # FIXME colons - shape = [self.parse_expr(s) for s in groups["shape"].split(",")] - else: - shape = None - - yield name, shape - def dtype_from_stmt(self, stmt): length, kind = stmt.selector assert not kind @@ -833,7 +1120,9 @@ class TreeWalker(object): def map_Call(self, node): def transform_arg(i, arg_str): expr = self.parse_expr(arg_str) - result = "&%s" % self.gen_expr(expr) + result = self.gen_expr(expr) + if self.arg_needs_pointer(node.designator, i): + result = "&"+result cast = self.force_casts.get( (node.designator, i)) @@ -891,9 +1180,15 @@ class TreeWalker(object): i += 1 end_block() + def block_or_none(body): + if not body: + return None + else: + return cgen.block_if_necessary(body) + return cgen.make_multiple_ifs( blocks_and_conds, - cgen.block_if_necessary(else_block)) + block_or_none(else_block)) def map_EndIfThen(self, node): return [] @@ -936,9 +1231,8 @@ class TreeWalker(object): else: raise NotImplementedError("unbounded do loop") - def map_EndDo(self, node): - raise NotImplementedError + return [] def map_Continue(self, node): return cgen.Statement("label_%s:" % node.label) @@ -958,29 +1252,37 @@ class TreeWalker(object): # {{{ expressions - def parse_expr(self, expr_str): - return self.expr_parser(expr_str) - def gen_expr(self, expr): scope = self.scope_stack[-1] - return CCodeMapper(scope)(expr) + return CCodeMapper(self, scope)(expr) def transform_expr(self, expr_str): return self.gen_expr(self.expr_parser(expr_str)) # }}} +# }}} + def f2cl(source, free_form=False, strict=True, - addr_space_hints={}, force_casts={}): + addr_space_hints={}, force_casts={}, + do_arg_analysis=True, + use_restrict_pointers=False, + try_compile=False): from fparser import api tree = api.parse(source, isfree=free_form, isstrict=strict, analyze=False, ignore_comments=False) - source = TreeWalker(addr_space_hints, force_casts)(tree) + + arg_info = ArgumentAnalayzer() + if do_arg_analysis: + arg_info(tree) + + source = F2CLTranslator(addr_space_hints, force_casts, + arg_info, use_restrict_pointers=use_restrict_pointers)(tree) func_decls = [] for entry in source: @@ -989,28 +1291,42 @@ def f2cl(source, free_form=False, strict=True, mod = cgen.Module(func_decls + [cgen.Line()] + source) + #open("pre-cnd.cl", "w").write(str(mod)) + from cnd import transform_cl str_mod = transform_cl(str(mod)) + + if try_compile: + import pyopencl as cl + ctx = cl.create_some_context() + cl.Program(ctx, """ + #pragma OPENCL EXTENSION cl_khr_fp64: enable + #include + """).build() return str_mod +def f2cl_files(source_file, target_file, **kwargs): + mod = f2cl(open(source_file).read(), **kwargs) + open(target_file, "w").write(mod) + + + + if __name__ == "__main__": from cgen.opencl import CLConstant - mod = f2cl(open("hank107.f").read(), + f2cl_files("hank107.f", "hank107.cl", addr_space_hints={ ("hank107p", "p"): CLConstant, ("hank107pc", "p"): CLConstant, }, force_casts={ ("hank107p", 0): "__constant cdouble_t *", - #("hank107pc", 0): "__constant double *", - } - ) + }) - open("hank107.cl", "w").write(mod) + f2cl_files("cdjseval2d.f", "cdjseval2d.cl", try_compile=True) # vim: foldmethod=marker -