diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 7817546eb39367688bbf28ff7666083351159899..c549938d13b62438593b153b34b9fa838b0b385a 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -141,6 +141,9 @@ class IdentityMapperBase(object): # leaf -- no need to rebuild return expr + def map_function_symbol(self, expr, *args): + return expr + def map_call(self, expr, *args): return expr.__class__( self.rec(expr.function, *args), diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index d4e7b87cf14d7d985fc4c236fb6a9755b60a48bf..e299c004848dd7234ef01d95b6387fa05472ccf0 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -38,6 +38,9 @@ class DependencyMapper(CombineMapper): def map_variable(self, expr): return set([expr]) + def map_function_symbol(self, expr): + return set() + def map_call(self, expr): if self.include_calls == "descend_args": return self.combine( diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 4c73a735bbf00c5910261fe41bb86ecc40ed54ff..192641452c904835285674edc5e915d4045c6fe2 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -60,6 +60,9 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def map_variable(self, expr, enclosing_prec): return expr.name + def map_function_symbol(self, expr, enclosing_prec): + return expr.__class__.__name__ + def map_call(self, expr, enclosing_prec): return self.format("%s(%s)", self.rec(expr.function, PREC_CALL), diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index fe1c2b3f9846303bf24d6eee493c67ea35b21c8e..f16303f433b64b0b1e9e984ae434fee0ef32f1bc 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -232,11 +232,47 @@ class Variable(Leaf): def get_mapper_method(self, mapper): return mapper.map_variable + + + +class FunctionSymbol(AlgebraicLeaf): + """Represents the name of a function. + + May optionally have an `arg_count` attribute, which will + allow `Call` to check the number of arguments. + """ + + def __getinitargs__(self): + return () + + def is_equal(self, other): + return self.__class__ == other.__class + + def get_hash(self): + return hash(self.__class__) + + def get_mapper_method(self, mapper): + return mapper.map_function_symbol + + + + + class Call(AlgebraicLeaf): def __init__(self, function, parameters): self.function = function self.parameters = parameters + try: + arg_count = self.function.arg_count + except AttributeError: + pass + else: + if len(self.parameters) != arg_count: + raise TypeError("%s called with wrong number of arguments " + "(need %d, got %d)" % ( + self.function, arg_count, len(parameters))) + def __getinitargs__(self): return self.function, self.parameters