diff --git a/loopy/expression.py b/loopy/expression.py
index 3269bc09f064f57857eaa5218c8370383e0f735e..ec52cb9d2acd6ed1dd82043fe013cd4a331d6799 100644
--- a/loopy/expression.py
+++ b/loopy/expression.py
@@ -175,6 +175,16 @@ class VectorizabilityChecker(RecursiveMapper):
         # FIXME: Do this more carefully
         raise Unvectorizable()
 
+    def map_fused_multiply_add(self, expr):
+        # Named after FusedMultiplyAdd.mapper_method so that the mapper
+        # dispatch finds it.  The tuple is built eagerly so that every
+        # operand is visited before the results are combined.
+        return all((
+            self.rec(expr.mul_op1),
+            self.rec(expr.mul_op2),
+            self.rec(expr.add_op),
+            ))
+
 # }}}
 
 # vim: fdm=marker
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 74bb5c1d1a917fd00edad3bfcd2d5ba241d1ff49..25f0e15bbf68c48460adfe8b63b0f31c7f2a9ebe 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -109,6 +109,14 @@ class IdentityMapperMixin(object):
 
     map_rule_argument = map_group_hw_index
 
+    def map_fused_multiply_add(self, expr, *args):
+        """Rebuild the FMA node from its recursively mapped operands."""
+        return FusedMultiplyAdd(
+                self.rec(expr.mul_op1, *args),
+                self.rec(expr.mul_op2, *args),
+                self.rec(expr.add_op, *args),
+                )
+
 
 class IdentityMapper(IdentityMapperBase, IdentityMapperMixin):
     pass
@@ -151,6 +159,15 @@ class WalkMapper(WalkMapperBase):
 
     map_rule_argument = map_group_hw_index
 
+    def map_fused_multiply_add(self, expr, *args):
+        """Walk into all three operands unless visit() rejects the node."""
+        if not self.visit(expr):
+            return
+
+        self.rec(expr.mul_op1, *args)
+        self.rec(expr.mul_op2, *args)
+        self.rec(expr.add_op, *args)
+
 
 class CallbackMapper(CallbackMapperBase, IdentityMapper):
     map_reduction = CallbackMapperBase.map_constant
@@ -209,6 +226,16 @@ class StringifyMapper(StringifyMapperBase):
     def map_rule_argument(self, expr, enclosing_prec):
         return "<arg%d>" % expr.index
 
+    def map_fused_multiply_add(self, expr, enclosing_prec):
+        """Render the node in call form, ``fma(mul_op1, mul_op2, add_op)``."""
+        from pymbolic.mapper.stringifier import PREC_NONE
+        # Comma-separated operands stay unambiguous at PREC_NONE, unlike
+        # "a*b+c", which misreads when an operand is itself a sum.
+        return "fma(%s, %s, %s)" % (
+                self.rec(expr.mul_op1, PREC_NONE),
+                self.rec(expr.mul_op2, PREC_NONE),
+                self.rec(expr.add_op, PREC_NONE))
+
 
 class UnidirectionalUnifier(UnidirectionalUnifierBase):
     def map_reduction(self, expr, other, unis):
@@ -263,6 +290,14 @@ class DependencyMapper(DependencyMapperBase):
 
     map_linear_subscript = DependencyMapperBase.map_subscript
 
+    def map_fused_multiply_add(self, expr):
+        """Combine the dependencies of all three operands."""
+        return self.combine((
+            self.rec(expr.mul_op1),
+            self.rec(expr.mul_op2),
+            self.rec(expr.add_op),
+            ))
+
 
 class SubstitutionRuleExpander(IdentityMapper):
     def __init__(self, rules):
@@ -540,6 +575,29 @@ class RuleArgument(Expression):
 
     mapper_method = intern("map_rule_argument")
 
+
+class FusedMultiplyAdd(Expression):
+    """Represents a fused multiply-add: ``mul_op1*mul_op2 + add_op``.
+
+    The three operands are stored as the attributes *mul_op1*, *mul_op2*
+    and *add_op*.
+    """
+
+    init_arg_names = ("mul_op1", "mul_op2", "add_op")
+
+    def __init__(self, mul_op1, mul_op2, add_op):
+        self.mul_op1 = mul_op1
+        self.mul_op2 = mul_op2
+        self.add_op = add_op
+
+    def __getinitargs__(self):
+        return (self.mul_op1, self.mul_op2, self.add_op)
+
+    def stringifier(self):
+        return StringifyMapper
+
+    mapper_method = intern("map_fused_multiply_add")
+
 # }}}
 
 
@@ -914,7 +972,7 @@ class VarToTaggedVarMapper(IdentityMapper):
 
 
 class FunctionToPrimitiveMapper(IdentityMapper):
-    """Looks for invocations of a function called 'cse' or 'reduce' and
+    """Looks for invocations of a function called 'cse', 'reduce' or 'fma' and
     turns those into the actual pymbolic primitives used for that.
     """
 
@@ -982,6 +1040,13 @@ class FunctionToPrimitiveMapper(IdentityMapper):
             else:
                 raise TypeError("if takes three arguments")
 
+        elif name == 'fma':
+            # fma(a, b, c) becomes a FusedMultiplyAdd node for a*b + c.
+            if len(expr.parameters) == 3:
+                return FusedMultiplyAdd(
+                        *tuple(self.rec(p) for p in expr.parameters))
+            else:
+                raise TypeError("FMA takes 3 arguments: fma(a,b,c) -> a*b + c")
         else:
             # see if 'name' is an existing reduction op
 
diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py
index bd5a74782dc5dce7bf82985bea3a7c77404d9d26..74d58c197ab692c54fdce807b6184e8fcd65738d 100644
--- a/loopy/target/c/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -658,6 +658,12 @@ class ExpressionToCExpressionMapper(IdentityMapper):
 
         return base_impl(expr, type_context)
 
+    def map_fused_multiply_add(self, expr, type_context):
+        # No dedicated FMA code is emitted here: the node is lowered to the
+        # equivalent product-plus-sum expression and mapped as usual.
+        return self.rec(
+                expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
+
     # }}}
 
     def map_group_hw_index(self, expr, type_context):
diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index 16be9605c1735180abf624cb8f600ef895fb8874..40ee9474b3f30d122f5631d6a42ac59f44ebe4d4 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -374,6 +374,12 @@ class TypeInferenceMapper(CombineMapper):
 
         return [result[0]]
 
+    def map_fused_multiply_add(self, expr):
+        # Infer over the equivalent product-plus-sum expression so that all
+        # three operands contribute to the result, mirroring how the C
+        # expression mapper lowers this node.
+        return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op)
+
     # }}}
 
 