From 632f5db320b5b79c3ff072de10e560d75897ca18 Mon Sep 17 00:00:00 2001 From: Andreas Stock <astock@dam.brown.edu> Date: Tue, 2 Jun 2009 16:30:26 -0400 Subject: [PATCH] Add FunctionSymbol as a way of treating predefined functions. --- pymbolic/mapper/__init__.py | 3 +++ pymbolic/mapper/dependency.py | 3 +++ pymbolic/mapper/stringifier.py | 3 +++ pymbolic/primitives.py | 36 ++++++++++++++++++++++++++++++++++ 4 files changed, 45 insertions(+) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 7817546..c549938 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -141,6 +141,9 @@ class IdentityMapperBase(object): # leaf -- no need to rebuild return expr + def map_function_symbol(self, expr, *args): + return expr + def map_call(self, expr, *args): return expr.__class__( self.rec(expr.function, *args), diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index d4e7b87..e299c00 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -38,6 +38,9 @@ class DependencyMapper(CombineMapper): def map_variable(self, expr): return set([expr]) + def map_function_symbol(self, expr): + return set() + def map_call(self, expr): if self.include_calls == "descend_args": return self.combine( diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 4c73a73..1926414 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -60,6 +60,9 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def map_variable(self, expr, enclosing_prec): return expr.name + def map_function_symbol(self, expr, enclosing_prec): + return expr.__class__.__name__ + def map_call(self, expr, enclosing_prec): return self.format("%s(%s)", self.rec(expr.function, PREC_CALL), diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index fe1c2b3..f16303f 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -232,11 +232,47 @@ class Variable(Leaf): def get_mapper_method(self, mapper): return mapper.map_variable + + + +class FunctionSymbol(AlgebraicLeaf): + """Represents the name of a function. + + May optionally have an `arg_count` attribute, which will + allow `Call` to check the number of arguments. + """ + + def __getinitargs__(self): + return () + + def is_equal(self, other): + return self.__class__ == other.__class + + def get_hash(self): + return hash(self.__class__) + + def get_mapper_method(self, mapper): + return mapper.map_function_symbol + + + + + class Call(AlgebraicLeaf): def __init__(self, function, parameters): self.function = function self.parameters = parameters + try: + arg_count = self.function.arg_count + except AttributeError: + pass + else: + if len(self.parameters) != arg_count: + raise TypeError("%s called with wrong number of arguments " + "(need %d, got %d)" % ( + self.function, arg_count, len(parameters))) + def __getinitargs__(self): return self.function, self.parameters -- GitLab