diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py
index ebbe98072ef4789d9158ef4fa8ef3c098dc2413d..210570ae91d2a1d92b451b3b0b7161184b97cc2b 100644
--- a/pymbolic/mapper/__init__.py
+++ b/pymbolic/mapper/__init__.py
@@ -98,7 +98,7 @@ class Mapper(object):
     attribute.
     """
 
-    def handle_unsupported_expression(self, expr, *args):
+    def handle_unsupported_expression(self, expr, *args, **kwargs):
         """Mapper method that is invoked for
         :class:`pymbolic.primitives.Expression` subclasses for which a mapper
         method does not exist in this mapper.
@@ -135,35 +135,35 @@ class Mapper(object):
 
     rec = __call__
 
-    def map_variable(self, expr, *args):
-        return self.map_algebraic_leaf(expr, *args)
+    def map_variable(self, expr, *args, **kwargs):
+        return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_subscript(self, expr, *args):
-        return self.map_algebraic_leaf(expr, *args)
+    def map_subscript(self, expr, *args, **kwargs):
+        return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_call(self, expr, *args):
-        return self.map_algebraic_leaf(expr, *args)
+    def map_call(self, expr, *args, **kwargs):
+        return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_lookup(self, expr, *args):
-        return self.map_algebraic_leaf(expr, *args)
+    def map_lookup(self, expr, *args, **kwargs):
+        return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_if_positive(self, expr, *args):
-        return self.map_algebraic_leaf(expr, *args)
+    def map_if_positive(self, expr, *args, **kwargs):
+        return self.map_algebraic_leaf(expr, *args, **kwargs)
 
-    def map_rational(self, expr, *args):
-        return self.map_quotient(expr, *args)
+    def map_rational(self, expr, *args, **kwargs):
+        return self.map_quotient(expr, *args, **kwargs)
 
-    def map_foreign(self, expr, *args):
+    def map_foreign(self, expr, *args, **kwargs):
         """Mapper method dispatch for non-:mod:`pymbolic` objects."""
 
         if isinstance(expr, primitives.VALID_CONSTANT_CLASSES):
-            return self.map_constant(expr, *args)
+            return self.map_constant(expr, *args, **kwargs)
         elif isinstance(expr, list):
-            return self.map_list(expr, *args)
+            return self.map_list(expr, *args, **kwargs)
         elif isinstance(expr, tuple):
-            return self.map_tuple(expr, *args)
+            return self.map_tuple(expr, *args, **kwargs)
         elif is_numpy_array(expr):
-            return self.map_numpy_array(expr, *args)
+            return self.map_numpy_array(expr, *args, **kwargs)
         else:
             raise ValueError(
                     "%s encountered invalid foreign object: %s" % (
@@ -195,65 +195,67 @@ class CombineMapper(RecursiveMapper):
     :class:`pymbolic.mapper.dependency.DependencyMapper` is another example.
     """
 
-    def map_call(self, expr, *args):
+    def map_call(self, expr, *args, **kwargs):
         return self.combine(
-                (self.rec(expr.function, *args),) +
+                (self.rec(expr.function, *args, **kwargs),) +
                 tuple(
-                    self.rec(child, *args) for child in expr.parameters)
+                    self.rec(child, *args, **kwargs) for child in expr.parameters)
                 )
 
-    def map_call_with_kwargs(self, expr, *args):
+    def map_call_with_kwargs(self, expr, *args, **kwargs):
         return self.combine(
-                (self.rec(expr.function, *args),)
+                (self.rec(expr.function, *args, **kwargs),)
                 + tuple(
-                    self.rec(child, *args) for child in expr.parameters)
+                    self.rec(child, *args, **kwargs)
+                    for child in expr.parameters)
                 + tuple(
-                    self.rec(child, *args) for child in expr.kw_parameters.values())
+                    self.rec(child, *args, **kwargs)
+                    for child in expr.kw_parameters.values())
                 )
 
-    def map_subscript(self, expr, *args):
+    def map_subscript(self, expr, *args, **kwargs):
         return self.combine(
-                [self.rec(expr.aggregate, *args),
-                    self.rec(expr.index, *args)])
+                [self.rec(expr.aggregate, *args, **kwargs),
+                    self.rec(expr.index, *args, **kwargs)])
 
-    def map_lookup(self, expr, *args):
-        return self.rec(expr.aggregate, *args)
+    def map_lookup(self, expr, *args, **kwargs):
+        return self.rec(expr.aggregate, *args, **kwargs)
 
-    def map_sum(self, expr, *args):
-        return self.combine(self.rec(child, *args)
+    def map_sum(self, expr, *args, **kwargs):
+        return self.combine(self.rec(child, *args, **kwargs)
                 for child in expr.children)
 
     map_product = map_sum
 
-    def map_quotient(self, expr, *args):
+    def map_quotient(self, expr, *args, **kwargs):
         return self.combine((
-            self.rec(expr.numerator, *args),
-            self.rec(expr.denominator, *args)))
+            self.rec(expr.numerator, *args, **kwargs),
+            self.rec(expr.denominator, *args, **kwargs)))
 
     map_floor_div = map_quotient
     map_remainder = map_quotient
 
-    def map_power(self, expr, *args):
+    def map_power(self, expr, *args, **kwargs):
         return self.combine((
-                self.rec(expr.base, *args),
-                self.rec(expr.exponent, *args)))
+                self.rec(expr.base, *args, **kwargs),
+                self.rec(expr.exponent, *args, **kwargs)))
 
-    def map_polynomial(self, expr, *args):
+    def map_polynomial(self, expr, *args, **kwargs):
         return self.combine(
-                (self.rec(expr.base, *args),) +
+                (self.rec(expr.base, *args, **kwargs),) +
                 tuple(
-                    self.rec(coeff, *args) for exp, coeff in expr.data)
+                    self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data)
                 )
 
-    def map_left_shift(self, expr, *args):
+    def map_left_shift(self, expr, *args, **kwargs):
         return self.combine(
-                self.rec(expr.shiftee, *args),
-                self.rec(expr.shift, *args))
+                self.rec(expr.shiftee, *args, **kwargs),
+                self.rec(expr.shift, *args, **kwargs))
 
     map_right_shift = map_left_shift
 
-    def map_bitwise_not(self, expr, *args):
-        return self.rec(expr.child, *args)
+    def map_bitwise_not(self, expr, *args, **kwargs):
+        return self.rec(expr.child, *args, **kwargs)
     map_bitwise_or = map_sum
     map_bitwise_xor = map_sum
     map_bitwise_and = map_sum
@@ -262,27 +264,27 @@ class CombineMapper(RecursiveMapper):
     map_logical_and = map_sum
     map_logical_or = map_sum
 
-    def map_comparison(self, expr, *args):
+    def map_comparison(self, expr, *args, **kwargs):
         return self.combine((
-            self.rec(expr.left, *args),
-            self.rec(expr.right, *args)))
+            self.rec(expr.left, *args, **kwargs),
+            self.rec(expr.right, *args, **kwargs)))
 
     map_max = map_sum
     map_min = map_sum
 
-    def map_list(self, expr, *args):
-        return self.combine(self.rec(child, *args) for child in expr)
+    def map_list(self, expr, *args, **kwargs):
+        return self.combine(self.rec(child, *args, **kwargs) for child in expr)
 
     map_tuple = map_list
 
-    def map_numpy_array(self, expr, *args):
+    def map_numpy_array(self, expr, *args, **kwargs):
         return self.combine(self.rec(el) for el in expr.flat)
 
-    def map_multivector(self, expr, *args):
+    def map_multivector(self, expr, *args, **kwargs):
         return self.combine(self.rec(coeff) for bits, coeff in expr.data.iteritems())
 
-    def map_common_subexpression(self, expr, *args):
-        return self.rec(expr.child, *args)
+    def map_common_subexpression(self, expr, *args, **kwargs):
+        return self.rec(expr.child, *args, **kwargs)
 
     def map_if_positive(self, expr):
         return self.combine([
@@ -333,83 +335,83 @@ class IdentityMapper(Mapper):
     See :ref:`custom-manipulation` for an example of the
     manipulations that can be implemented this way.
     """
-    def map_constant(self, expr, *args):
+    def map_constant(self, expr, *args, **kwargs):
         # leaf -- no need to rebuild
         return expr
 
-    def map_variable(self, expr, *args):
+    def map_variable(self, expr, *args, **kwargs):
         # leaf -- no need to rebuild
         return expr
 
-    def map_function_symbol(self, expr, *args):
+    def map_function_symbol(self, expr, *args, **kwargs):
         return expr
 
-    def map_call(self, expr, *args):
+    def map_call(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.function, *args),
-                tuple(self.rec(child, *args)
+                self.rec(expr.function, *args, **kwargs),
+                tuple(self.rec(child, *args, **kwargs)
                     for child in expr.parameters))
 
-    def map_call_with_kwargs(self, expr, *args):
+    def map_call_with_kwargs(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.function, *args),
-                tuple(self.rec(child, *args)
+                self.rec(expr.function, *args, **kwargs),
+                tuple(self.rec(child, *args, **kwargs)
                     for child in expr.parameters),
                 dict(
-                    (key, self.rec(val, *args))
+                    (key, self.rec(val, *args, **kwargs))
                     for key, val in expr.kw_parameters.iteritems())
                     )
 
-    def map_subscript(self, expr, *args):
+    def map_subscript(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.aggregate, *args),
-                self.rec(expr.index, *args))
+                self.rec(expr.aggregate, *args, **kwargs),
+                self.rec(expr.index, *args, **kwargs))
 
-    def map_lookup(self, expr, *args):
+    def map_lookup(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.aggregate, *args),
+                self.rec(expr.aggregate, *args, **kwargs),
                 expr.name)
 
-    def map_sum(self, expr, *args):
+    def map_sum(self, expr, *args, **kwargs):
         from pymbolic.primitives import flattened_sum
         return flattened_sum(tuple(
-            self.rec(child, *args) for child in expr.children))
+            self.rec(child, *args, **kwargs) for child in expr.children))
 
-    def map_product(self, expr, *args):
+    def map_product(self, expr, *args, **kwargs):
         from pymbolic.primitives import flattened_product
         return flattened_product(tuple(
-            self.rec(child, *args) for child in expr.children))
+            self.rec(child, *args, **kwargs) for child in expr.children))
 
-    def map_quotient(self, expr, *args):
-        return expr.__class__(self.rec(expr.numerator, *args),
-                              self.rec(expr.denominator, *args))
+    def map_quotient(self, expr, *args, **kwargs):
+        return expr.__class__(self.rec(expr.numerator, *args, **kwargs),
+                              self.rec(expr.denominator, *args, **kwargs))
 
     map_floor_div = map_quotient
     map_remainder = map_quotient
 
-    def map_power(self, expr, *args):
-        return expr.__class__(self.rec(expr.base, *args),
-                              self.rec(expr.exponent, *args))
+    def map_power(self, expr, *args, **kwargs):
+        return expr.__class__(self.rec(expr.base, *args, **kwargs),
+                              self.rec(expr.exponent, *args, **kwargs))
 
-    def map_polynomial(self, expr, *args):
-        return expr.__class__(self.rec(expr.base, *args),
-                              ((exp, self.rec(coeff, *args))
+    def map_polynomial(self, expr, *args, **kwargs):
+        return expr.__class__(self.rec(expr.base, *args, **kwargs),
+                              ((exp, self.rec(coeff, *args, **kwargs))
                                   for exp, coeff in expr.data))
 
-    def map_left_shift(self, expr, *args):
+    def map_left_shift(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.shiftee, *args),
-                self.rec(expr.shift, *args))
+                self.rec(expr.shiftee, *args, **kwargs),
+                self.rec(expr.shift, *args, **kwargs))
 
     map_right_shift = map_left_shift
 
-    def map_bitwise_not(self, expr, *args):
+    def map_bitwise_not(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.child, *args))
+                self.rec(expr.child, *args, **kwargs))
 
-    def map_bitwise_or(self, expr, *args):
+    def map_bitwise_or(self, expr, *args, **kwargs):
         return type(expr)(tuple(
-            self.rec(child, *args) for child in expr.children))
+            self.rec(child, *args, **kwargs) for child in expr.children))
 
     map_bitwise_xor = map_bitwise_or
     map_bitwise_and = map_bitwise_or
@@ -418,17 +420,17 @@ class IdentityMapper(Mapper):
     map_logical_or = map_bitwise_or
     map_logical_and = map_bitwise_or
 
-    def map_comparison(self, expr, *args):
+    def map_comparison(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.left, *args),
+                self.rec(expr.left, *args, **kwargs),
                 expr.operator,
-                self.rec(expr.right, *args))
+                self.rec(expr.right, *args, **kwargs))
 
-    def map_list(self, expr, *args):
-        return [self.rec(child, *args) for child in expr]
+    def map_list(self, expr, *args, **kwargs):
+        return [self.rec(child, *args, **kwargs) for child in expr]
 
-    def map_tuple(self, expr, *args):
-        return tuple(self.rec(child, *args) for child in expr)
+    def map_tuple(self, expr, *args, **kwargs):
+        return tuple(self.rec(child, *args, **kwargs) for child in expr)
 
     def map_numpy_array(self, expr):
         import numpy
@@ -438,8 +440,8 @@ class IdentityMapper(Mapper):
             result[i] = self.rec(expr[i])
         return result
 
-    def map_multivector(self, expr, *args):
-        return expr.map(lambda ch: self.rec(ch, *args))
+    def map_multivector(self, expr, *args, **kwargs):
+        return expr.map(lambda ch: self.rec(ch, *args, **kwargs))
 
     def map_common_subexpression(self, expr, *args, **kwargs):
         from pymbolic.primitives import is_zero
@@ -453,38 +455,38 @@ class IdentityMapper(Mapper):
                 expr.scope,
                 **expr.get_extra_properties())
 
-    def map_substitution(self, expr, *args):
+    def map_substitution(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.child, *args),
+                self.rec(expr.child, *args, **kwargs),
                 expr.variables,
-                tuple(self.rec(v, *args) for v in expr.values))
+                tuple(self.rec(v, *args, **kwargs) for v in expr.values))
 
-    def map_derivative(self, expr, *args):
+    def map_derivative(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.child, *args),
+                self.rec(expr.child, *args, **kwargs),
                 expr.variables)
 
-    def map_slice(self, expr, *args):
+    def map_slice(self, expr, *args, **kwargs):
         def do_map(expr):
             if expr is None:
                 return expr
             else:
-                return self.rec(expr, *args)
+                return self.rec(expr, *args, **kwargs)
 
         return type(expr)(
                 tuple(do_map(ch) for ch in expr.children))
 
-    def map_if_positive(self, expr, *args):
+    def map_if_positive(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.criterion, *args),
-                self.rec(expr.then, *args),
-                self.rec(expr.else_, *args))
+                self.rec(expr.criterion, *args, **kwargs),
+                self.rec(expr.then, *args, **kwargs),
+                self.rec(expr.else_, *args, **kwargs))
 
-    def map_if(self, expr, *args):
+    def map_if(self, expr, *args, **kwargs):
         return type(expr)(
-                self.rec(expr.condition, *args),
-                self.rec(expr.then, *args),
-                self.rec(expr.else_, *args))
+                self.rec(expr.condition, *args, **kwargs),
+                self.rec(expr.then, *args, **kwargs),
+                self.rec(expr.else_, *args, **kwargs))
 
 # }}}
 
@@ -496,99 +498,99 @@ class WalkMapper(RecursiveMapper):
     without propagating any result. Also calls :meth:`visit` for each
     visited subexpression.
 
-    .. method:: visit(expr, *args)
+    .. method:: visit(expr, *args, **kwargs)
     """
-    def map_constant(self, expr, *args):
-        self.visit(expr, *args)
+    def map_constant(self, expr, *args, **kwargs):
+        self.visit(expr, *args, **kwargs)
 
-    def map_variable(self, expr, *args):
-        self.visit(expr, *args)
+    def map_variable(self, expr, *args, **kwargs):
+        self.visit(expr, *args, **kwargs)
 
-    def map_function_symbol(self, expr, *args):
+    def map_function_symbol(self, expr, *args, **kwargs):
         self.visit(expr)
 
-    def map_call(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_call(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.function, *args)
+        self.rec(expr.function, *args, **kwargs)
         for child in expr.parameters:
-            self.rec(child, *args)
+            self.rec(child, *args, **kwargs)
 
-    def map_call_with_kwargs(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_call_with_kwargs(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.function, *args)
+        self.rec(expr.function, *args, **kwargs)
         for child in expr.parameters:
-            self.rec(child, *args)
+            self.rec(child, *args, **kwargs)
 
         for child in expr.kw_parameters.values():
-            self.rec(child, *args)
+            self.rec(child, *args, **kwargs)
 
-    def map_subscript(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_subscript(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.aggregate, *args)
-        self.rec(expr.index, *args)
+        self.rec(expr.aggregate, *args, **kwargs)
+        self.rec(expr.index, *args, **kwargs)
 
-    def map_lookup(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_lookup(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.aggregate, *args)
+        self.rec(expr.aggregate, *args, **kwargs)
 
-    def map_sum(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_sum(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
         for child in expr.children:
-            self.rec(child, *args)
+            self.rec(child, *args, **kwargs)
 
     map_product = map_sum
 
-    def map_quotient(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_quotient(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.numerator, *args)
-        self.rec(expr.denominator, *args)
+        self.rec(expr.numerator, *args, **kwargs)
+        self.rec(expr.denominator, *args, **kwargs)
 
     map_floor_div = map_quotient
     map_remainder = map_quotient
 
-    def map_power(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_power(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.base, *args)
-        self.rec(expr.exponent, *args)
+        self.rec(expr.base, *args, **kwargs)
+        self.rec(expr.exponent, *args, **kwargs)
 
-    def map_polynomial(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_polynomial(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.base, *args)
+        self.rec(expr.base, *args, **kwargs)
         for exp, coeff in expr.data:
-            self.rec(coeff, *args)
+            self.rec(coeff, *args, **kwargs)
 
-    def map_list(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_list(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
         for child in expr:
-            self.rec(child, *args)
+            self.rec(child, *args, **kwargs)
 
     map_tuple = map_list
 
-    def map_numpy_array(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_numpy_array(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
         from pytools import indices_in_shape
         for i in indices_in_shape(expr.shape):
-            self.rec(expr[i], *args)
+            self.rec(expr[i], *args, **kwargs)
 
     def map_multivector(self, expr, *args):
         if not self.visit(expr, *args):
@@ -598,72 +600,72 @@ class WalkMapper(RecursiveMapper):
             self.rec(coeff)
 
     def map_common_subexpression(self, expr, *args, **kwargs):
-        if not self.visit(expr, *args):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.child, *args)
+        self.rec(expr.child, *args, **kwargs)
 
-    def map_left_shift(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_left_shift(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.shift, *args)
-        self.rec(expr.shiftee, *args)
+        self.rec(expr.shift, *args, **kwargs)
+        self.rec(expr.shiftee, *args, **kwargs)
 
     map_right_shift = map_left_shift
 
-    def map_bitwise_not(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_bitwise_not(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.child, *args)
+        self.rec(expr.child, *args, **kwargs)
 
     map_bitwise_or = map_sum
     map_bitwise_xor = map_sum
     map_bitwise_and = map_sum
 
-    def map_comparison(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_comparison(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.left, *args)
-        self.rec(expr.right, *args)
+        self.rec(expr.left, *args, **kwargs)
+        self.rec(expr.right, *args, **kwargs)
 
     map_logical_not = map_bitwise_not
     map_logical_and = map_sum
     map_logical_or = map_sum
 
-    def map_if(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_if(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.condition, *args)
-        self.rec(expr.then, *args)
-        self.rec(expr.else_, *args)
+        self.rec(expr.condition, *args, **kwargs)
+        self.rec(expr.then, *args, **kwargs)
+        self.rec(expr.else_, *args, **kwargs)
 
-    def map_if_positive(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_if_positive(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.criterion, *args)
-        self.rec(expr.then, *args)
-        self.rec(expr.else_, *args)
+        self.rec(expr.criterion, *args, **kwargs)
+        self.rec(expr.then, *args, **kwargs)
+        self.rec(expr.else_, *args, **kwargs)
 
-    def map_substitution(self, expr, *args):
+    def map_substitution(self, expr, *args, **kwargs):
         if not self.visit(expr):
             return
 
-        self.rec(expr.child, *args)
+        self.rec(expr.child, *args, **kwargs)
         for v in expr.values:
-            self.rec(v, *args)
+            self.rec(v, *args, **kwargs)
 
-    def map_derivative(self, expr, *args):
-        if not self.visit(expr, *args):
+    def map_derivative(self, expr, *args, **kwargs):
+        if not self.visit(expr, *args, **kwargs):
             return
 
-        self.rec(expr.child, *args)
+        self.rec(expr.child, *args, **kwargs)
 
-    def visit(self, expr, *args):
+    def visit(self, expr, *args, **kwargs):
         return True
 
 # }}}
@@ -677,8 +679,8 @@ class CallbackMapper(RecursiveMapper):
         self.fallback_mapper = fallback_mapper
         fallback_mapper.rec = self.rec
 
-    def map_constant(self, expr, *args):
-        return self.function(expr, self, *args)
+    def map_constant(self, expr, *args, **kwargs):
+        return self.function(expr, self, *args, **kwargs)
 
     map_variable = map_constant
     map_function_symbol = map_constant
diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py
index b0fd24f6b35df628fe82ec3ab7be5ae02940c5a0..4b34a76b4c7bf2803a0a3a9d5a63e7c4570d13e2 100644
--- a/pymbolic/mapper/stringifier.py
+++ b/pymbolic/mapper/stringifier.py
@@ -89,9 +89,9 @@ class StringifyMapper(pymbolic.mapper.Mapper):
     def join(self, joiner, iterable):
         return self.format(joiner.join("%s" for i in iterable), *iterable)
 
-    def join_rec(self, joiner, iterable, prec):
+    def join_rec(self, joiner, iterable, prec, *args, **kwargs):
         f = joiner.join("%s" for i in iterable)
-        return self.format(f, *[self.rec(i, prec) for i in iterable])
+        return self.format(f, *[self.rec(i, prec, *args, **kwargs) for i in iterable])
 
     def parenthesize(self, s):
         return "(%s)" % s
@@ -106,14 +106,14 @@ class StringifyMapper(pymbolic.mapper.Mapper):
 
     # {{{ mappings
 
-    def handle_unsupported_expression(self, victim, enclosing_prec):
+    def handle_unsupported_expression(self, victim, enclosing_prec, *args, **kwargs):
         strifier = victim.stringifier()
         if isinstance(self, strifier):
             raise ValueError("stringifier '%s' can't handle '%s'"
                     % (self, victim.__class__))
-        return strifier(self.constant_mapper)(victim, enclosing_prec)
+        return strifier(self.constant_mapper)(victim, enclosing_prec, *args, **kwargs)
 
-    def map_constant(self, expr, enclosing_prec):
+    def map_constant(self, expr, enclosing_prec, *args, **kwargs):
         result = self.constant_mapper(expr)
 
         if not (result.startswith("(") and result.endswith(")")) \
@@ -123,72 +123,73 @@ class StringifyMapper(pymbolic.mapper.Mapper):
         else:
             return result
 
-    def map_variable(self, expr, enclosing_prec):
+    def map_variable(self, expr, enclosing_prec, *args, **kwargs):
         return expr.name
 
-    def map_function_symbol(self, expr, enclosing_prec):
+    def map_function_symbol(self, expr, enclosing_prec, *args, **kwargs):
         return expr.__class__.__name__
 
-    def map_call(self, expr, enclosing_prec):
+    def map_call(self, expr, enclosing_prec, *args, **kwargs):
         return self.format("%s(%s)",
-                self.rec(expr.function, PREC_CALL),
-                self.join_rec(", ", expr.parameters, PREC_NONE))
+                self.rec(expr.function, PREC_CALL, *args, **kwargs),
+                self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs))
 
-    def map_call_with_kwargs(self, expr, enclosing_prec):
+    def map_call_with_kwargs(self, expr, enclosing_prec, *args, **kwargs):
         args_strings = (
-                tuple(self.rec(ch, PREC_NONE) for ch in expr.parameters)
+                tuple(self.rec(ch, PREC_NONE, *args, **kwargs)
+                      for ch in expr.parameters)
                 +
-                tuple("%s=%s" % (name, self.rec(ch, PREC_NONE))
+                tuple("%s=%s" % (name, self.rec(ch, PREC_NONE, *args, **kwargs))
                     for name, ch in expr.kw_parameters.items()))
         return self.format("%s(%s)",
-                self.rec(expr.function, PREC_CALL),
+                self.rec(expr.function, PREC_CALL, *args, **kwargs),
                 ", ".join(args_strings))
 
-    def map_subscript(self, expr, enclosing_prec):
+    def map_subscript(self, expr, enclosing_prec, *args, **kwargs):
         if isinstance(expr.index, tuple):
-            index_str = self.join_rec(", ", expr.index, PREC_NONE)
+            index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs)
         else:
-            index_str = self.rec(expr.index, PREC_NONE)
+            index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs)
 
         return self.parenthesize_if_needed(
                 self.format("%s[%s]",
-                    self.rec(expr.aggregate, PREC_CALL),
+                    self.rec(expr.aggregate, PREC_CALL, *args, **kwargs),
                     index_str),
                 enclosing_prec, PREC_CALL)
 
-    def map_lookup(self, expr, enclosing_prec):
+    def map_lookup(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
                 self.format("%s.%s",
-                    self.rec(expr.aggregate, PREC_CALL),
+                    self.rec(expr.aggregate, PREC_CALL, *args, **kwargs),
                     expr.name),
                 enclosing_prec, PREC_CALL)
 
-    def map_sum(self, expr, enclosing_prec):
+    def map_sum(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                self.join_rec(" + ", expr.children, PREC_SUM),
+                self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs),
                 enclosing_prec, PREC_SUM)
 
-    def map_product(self, expr, enclosing_prec):
+    def map_product(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                self.join_rec("*", expr.children, PREC_PRODUCT),
+                self.join_rec("*", expr.children, PREC_PRODUCT, *args, **kwargs),
                 enclosing_prec, PREC_PRODUCT)
 
-    def map_quotient(self, expr, enclosing_prec):
+    def map_quotient(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
                 self.format("%s / %s",
                     # space is necessary--otherwise '/*' becomes
                     # start-of-comment in C. ('*' from dereference)
-                    self.rec(expr.numerator, PREC_PRODUCT),
-                    self.rec(expr.denominator, PREC_POWER)),  # analogous to ^{-1}
+                    self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs),
+                    self.rec(expr.denominator, PREC_POWER, *args, **kwargs)),  # analogous to ^{-1}
                 enclosing_prec, PREC_PRODUCT)
 
-    def map_floor_div(self, expr, enclosing_prec):
+    def map_floor_div(self, expr, enclosing_prec, *args, **kwargs):
         # (-1) * ((-1)*x // 5) should not reassociate. Therefore raise precedence
         # on the numerator and shield against surrounding products.
 
         result = self.format("%s // %s",
-                    self.rec(expr.numerator, PREC_POWER),
-                    self.rec(expr.denominator, PREC_POWER))  # analogous to ^{-1}
+                    self.rec(expr.numerator, PREC_POWER, *args, **kwargs),
+                    self.rec(expr.denominator, PREC_POWER, *args, **kwargs))  # analogous to ^{-1}
 
         # Note ">=", not ">" as in parenthesize_if_needed().
         if enclosing_prec >= PREC_PRODUCT:
@@ -196,101 +197,101 @@ class StringifyMapper(pymbolic.mapper.Mapper):
         else:
             return result
 
-    def map_power(self, expr, enclosing_prec):
+    def map_power(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
                 self.format("%s**%s",
-                    self.rec(expr.base, PREC_POWER),
-                    self.rec(expr.exponent, PREC_POWER)),
+                    self.rec(expr.base, PREC_POWER, *args, **kwargs),
+                    self.rec(expr.exponent, PREC_POWER, *args, **kwargs)),
                 enclosing_prec, PREC_POWER)
 
-    def map_remainder(self, expr, enclosing_prec):
+    def map_remainder(self, expr, enclosing_prec, *args, **kwargs):
         return self.format("(%s %% %s)",
-                    self.rec(expr.numerator, PREC_PRODUCT),
-                    self.rec(expr.denominator, PREC_POWER))  # analogous to ^{-1}
+                    self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs),
+                    self.rec(expr.denominator, PREC_POWER, *args, **kwargs))  # analogous to ^{-1}
 
-    def map_polynomial(self, expr, enclosing_prec):
+    def map_polynomial(self, expr, enclosing_prec, *args, **kwargs):
         from pymbolic.primitives import flattened_sum
         return self.rec(flattened_sum(
             [coeff*expr.base**exp for exp, coeff in expr.data[::-1]]),
-            enclosing_prec)
+            enclosing_prec, *args, **kwargs)
 
-    def map_left_shift(self, expr, enclosing_prec):
+    def map_left_shift(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
                 self.format("%s << %s",
-                    self.rec(expr.shiftee, PREC_SHIFT),
-                    self.rec(expr.shift, PREC_SHIFT)),
+                    self.rec(expr.shiftee, PREC_SHIFT, *args, **kwargs),
+                    self.rec(expr.shift, PREC_SHIFT, *args, **kwargs)),
                 enclosing_prec, PREC_SHIFT)
 
-    def map_right_shift(self, expr, enclosing_prec):
+    def map_right_shift(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
                 self.format("%s >> %s",
-                    self.rec(expr.shiftee, PREC_SHIFT),
-                    self.rec(expr.shift, PREC_SHIFT)),
+                    self.rec(expr.shiftee, PREC_SHIFT, *args, **kwargs),
+                    self.rec(expr.shift, PREC_SHIFT, *args, **kwargs)),
                 enclosing_prec, PREC_SHIFT)
 
-    def map_bitwise_not(self, expr, enclosing_prec):
+    def map_bitwise_not(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                "~" + self.rec(expr.child, PREC_UNARY),
+                "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
                 enclosing_prec, PREC_UNARY)
 
-    def map_bitwise_or(self, expr, enclosing_prec):
+    def map_bitwise_or(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                self.join_rec(" | ", expr.children, PREC_BITWISE_OR),
+                self.join_rec(" | ", expr.children, PREC_BITWISE_OR, *args, **kwargs),
                 enclosing_prec, PREC_BITWISE_OR)
 
-    def map_bitwise_xor(self, expr, enclosing_prec):
+    def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR),
+                self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs),
                 enclosing_prec, PREC_BITWISE_XOR)
 
-    def map_bitwise_and(self, expr, enclosing_prec):
+    def map_bitwise_and(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                self.join_rec(" ^ ", expr.children, PREC_BITWISE_AND),
+                self.join_rec(" ^ ", expr.children, PREC_BITWISE_AND, *args, **kwargs),
                 enclosing_prec, PREC_BITWISE_AND)
 
-    def map_comparison(self, expr, enclosing_prec):
+    def map_comparison(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
                 self.format("%s %s %s",
-                    self.rec(expr.left, PREC_COMPARISON),
+                    self.rec(expr.left, PREC_COMPARISON, *args, **kwargs),
                     expr.operator,
-                    self.rec(expr.right, PREC_COMPARISON)),
+                    self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)),
                 enclosing_prec, PREC_COMPARISON)
 
-    def map_logical_not(self, expr, enclosing_prec):
+    def map_logical_not(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                "not " + self.rec(expr.child, PREC_UNARY),
+                "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
                 enclosing_prec, PREC_UNARY)
 
-    def map_logical_or(self, expr, enclosing_prec):
+    def map_logical_or(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                self.join_rec(" or ", expr.children, PREC_LOGICAL_OR),
+                self.join_rec(" or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
                 enclosing_prec, PREC_LOGICAL_OR)
 
-    def map_logical_and(self, expr, enclosing_prec):
+    def map_logical_and(self, expr, enclosing_prec, *args, **kwargs):
         return self.parenthesize_if_needed(
-                self.join_rec(" and ", expr.children, PREC_LOGICAL_AND),
+                self.join_rec(" and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs),
                 enclosing_prec, PREC_LOGICAL_AND)
 
-    def map_list(self, expr, enclosing_prec):
-        return self.format("[%s]", self.join_rec(", ", expr, PREC_NONE))
+    def map_list(self, expr, enclosing_prec, *args, **kwargs):
+        return self.format("[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs))
 
     map_vector = map_list
 
-    def map_tuple(self, expr, enclosing_prec):
-        el_str = ", ".join(self.rec(child, PREC_NONE) for child in expr)
+    def map_tuple(self, expr, enclosing_prec, *args, **kwargs):
+        el_str = ", ".join(self.rec(child, PREC_NONE, *args, **kwargs) for child in expr)
         if len(expr) == 1:
             el_str += ","
 
         return "(%s)" % el_str
 
-    def map_numpy_array(self, expr, enclosing_prec):
+    def map_numpy_array(self, expr, enclosing_prec, *args, **kwargs):
         import numpy
 
         from pytools import indices_in_shape
         str_array = numpy.zeros(expr.shape, dtype="object")
         max_length = 0
         for i in indices_in_shape(expr.shape):
-            s = self.rec(expr[i], PREC_NONE)
+            s = self.rec(expr[i], PREC_NONE, *args, **kwargs)
             max_length = max(len(s), max_length)
             str_array[i] = s.replace("\n", "\n  ")
 
@@ -306,10 +307,10 @@ class StringifyMapper(pymbolic.mapper.Mapper):
             else:
                 return "array(\n%s)" % "".join(lines)
 
-    def map_multivector(self, expr, enclosing_prec):
-        return expr.stringify(self.rec, enclosing_prec)
+    def map_multivector(self, expr, enclosing_prec, *args, **kwargs):
+        return expr.stringify(self.rec, enclosing_prec, *args, **kwargs)
 
-    def map_common_subexpression(self, expr, enclosing_prec):
+    def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs):
         from pymbolic.primitives import CommonSubexpression
         if type(expr) is CommonSubexpression:
             type_name = "CSE"
@@ -317,49 +318,49 @@ class StringifyMapper(pymbolic.mapper.Mapper):
             type_name = type(expr).__name__
 
         return self.format("%s(%s)",
-                type_name, self.rec(expr.child, PREC_NONE))
+                type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs))
 
-    def map_if(self, expr, enclosing_prec):
+    def map_if(self, expr, enclosing_prec, *args, **kwargs):
         return "If(%s, %s, %s)" % (
-                self.rec(expr.condition, PREC_NONE),
-                self.rec(expr.then, PREC_NONE),
-                self.rec(expr.else_, PREC_NONE))
+                self.rec(expr.condition, PREC_NONE, *args, **kwargs),
+                self.rec(expr.then, PREC_NONE, *args, **kwargs),
+                self.rec(expr.else_, PREC_NONE, *args, **kwargs))
 
-    def map_if_positive(self, expr, enclosing_prec):
+    def map_if_positive(self, expr, enclosing_prec, *args, **kwargs):
         return "If(%s > 0, %s, %s)" % (
-                self.rec(expr.criterion, PREC_NONE),
-                self.rec(expr.then, PREC_NONE),
-                self.rec(expr.else_, PREC_NONE))
+                self.rec(expr.criterion, PREC_NONE, *args, **kwargs),
+                self.rec(expr.then, PREC_NONE, *args, **kwargs),
+                self.rec(expr.else_, PREC_NONE, *args, **kwargs))
 
-    def map_min(self, expr, enclosing_prec):
+    def map_min(self, expr, enclosing_prec, *args, **kwargs):
         what = type(expr).__name__.lower()
         return self.format("%s(%s)",
-                what, self.join_rec(", ", expr.children, PREC_NONE))
+                what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs))
 
     map_max = map_min
 
-    def map_derivative(self, expr, enclosing_prec):
+    def map_derivative(self, expr, enclosing_prec, *args, **kwargs):
         derivs = " ".join(
                 "d/d%s" % v
                 for v in expr.variables)
 
         return "%s %s" % (
-                derivs, self.rec(expr.child, PREC_PRODUCT))
+                derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs))
 
-    def map_substitution(self, expr, enclosing_prec):
+    def map_substitution(self, expr, enclosing_prec, *args, **kwargs):
         substs = ", ".join(
-                "%s=%s" % (name, self.rec(val, PREC_NONE))
+                "%s=%s" % (name, self.rec(val, PREC_NONE, *args, **kwargs))
                 for name, val in zip(expr.variables, expr.values))
 
-        return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE), substs)
+        return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE, *args, **kwargs), substs)
 
-    def map_slice(self, expr, enclosing_prec):
+    def map_slice(self, expr, enclosing_prec, *args, **kwargs):
         children = []
         for child in expr.children:
             if child is None:
                 children.append("")
             else:
-                children.append(self.rec(child, PREC_NONE))
+                children.append(self.rec(child, PREC_NONE, *args, **kwargs))
 
         return self.parenthesize_if_needed(
                 self.join(":", children),
@@ -367,13 +368,13 @@ class StringifyMapper(pymbolic.mapper.Mapper):
 
     # }}}
 
-    def __call__(self, expr, prec=PREC_NONE):
+    def __call__(self, expr, prec=PREC_NONE, *args, **kwargs):
         """Return a string corresponding to *expr*. If the enclosing
         precedence level *prec* is higher than *prec* (see :ref:`prec-constants`),
         parenthesize the result.
         """
 
-        return pymbolic.mapper.Mapper.__call__(self, expr, prec)
+        return pymbolic.mapper.Mapper.__call__(self, expr, prec, *args, **kwargs)
 
 # }}}
 
@@ -404,7 +405,7 @@ class CSESplittingStringifyMapperMixin(object):
     See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example
     of the use of this mix-in.
     """
-    def map_common_subexpression(self, expr, enclosing_prec):
+    def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs):
         try:
             self.cse_to_name
         except AttributeError:
@@ -415,7 +416,7 @@ class CSESplittingStringifyMapperMixin(object):
         try:
             cse_name = self.cse_to_name[expr.child]
         except KeyError:
-            str_child = self.rec(expr.child, PREC_NONE)
+            str_child = self.rec(expr.child, PREC_NONE, *args, **kwargs)
 
             if expr.prefix is not None:
                 def generate_cse_names():
@@ -456,15 +457,15 @@ class SortingStringifyMapper(StringifyMapper):
         StringifyMapper.__init__(self, constant_mapper)
         self.reverse = reverse
 
-    def map_sum(self, expr, enclosing_prec):
-        entries = [self.rec(i, PREC_SUM) for i in expr.children]
+    def map_sum(self, expr, enclosing_prec, *args, **kwargs):
+        entries = [self.rec(i, PREC_SUM, *args, **kwargs) for i in expr.children]
         entries.sort(reverse=self.reverse)
         return self.parenthesize_if_needed(
                 self.join(" + ", entries),
                 enclosing_prec, PREC_SUM)
 
-    def map_product(self, expr, enclosing_prec):
-        entries = [self.rec(i, PREC_PRODUCT) for i in expr.children]
+    def map_product(self, expr, enclosing_prec, *args, **kwargs):
+        entries = [self.rec(i, PREC_PRODUCT, *args, **kwargs) for i in expr.children]
         entries.sort(reverse=self.reverse)
         return self.parenthesize_if_needed(
                 self.join("*", entries),
@@ -480,7 +481,7 @@ class SimplifyingSortingStringifyMapper(StringifyMapper):
         StringifyMapper.__init__(self, constant_mapper)
         self.reverse = reverse
 
-    def map_sum(self, expr, enclosing_prec):
+    def map_sum(self, expr, enclosing_prec, *args, **kwargs):
         def get_neg_product(expr):
             from pymbolic.primitives import is_zero, Product
 
@@ -500,9 +501,9 @@ class SimplifyingSortingStringifyMapper(StringifyMapper):
         for ch in expr.children:
             neg_prod = get_neg_product(ch)
             if neg_prod is not None:
-                negatives.append(self.rec(neg_prod, PREC_PRODUCT))
+                negatives.append(self.rec(neg_prod, PREC_PRODUCT, *args, **kwargs))
             else:
-                positives.append(self.rec(ch, PREC_SUM))
+                positives.append(self.rec(ch, PREC_SUM, *args, **kwargs))
 
         positives.sort(reverse=self.reverse)
         positives = " + ".join(positives)
@@ -514,7 +515,7 @@ class SimplifyingSortingStringifyMapper(StringifyMapper):
 
         return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM)
 
-    def map_product(self, expr, enclosing_prec):
+    def map_product(self, expr, enclosing_prec, *args, **kwargs):
         entries = []
         i = 0
         from pymbolic.primitives import is_zero
@@ -526,10 +527,10 @@ class SimplifyingSortingStringifyMapper(StringifyMapper):
                 # Otherwise two unary minus signs merge into a pre-decrement.
                 entries.append(
                         self.format(
-                            "- %s", self.rec(expr.children[i+1], PREC_UNARY)))
+                            "- %s", self.rec(expr.children[i+1], PREC_UNARY, *args, **kwargs)))
                 i += 2
             else:
-                entries.append(self.rec(child, PREC_PRODUCT))
+                entries.append(self.rec(child, PREC_PRODUCT, *args, **kwargs))
                 i += 1
 
         entries.sort(reverse=self.reverse)