From c50ce18f0604ced92d897ccaf446c993f20aecab Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Fri, 28 Nov 2014 13:35:30 -0600 Subject: [PATCH] Allow overriding variable class in make_sym_vector --- pymbolic/primitives.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 24be3ec..f9caa48 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -1571,17 +1571,21 @@ def make_common_subexpression(field, prefix=None, scope=None): return CommonSubexpression(field, prefix, scope) -def make_sym_vector(name, components): +def make_sym_vector(name, components, var_class=None): """Return an object array of *components* subscripted - :class:`Variable` instances. + :class:`Variable` (or subclass) instances. - :param components: The number of components in the vector. + :arg components: The number of components in the vector. + :arg var_class: The :class:`Variable` subclass to use for instantiating + the scalar variables. """ if isinstance(components, int): components = list(range(components)) + if var_class is None: + var_class = Variable from pytools.obj_array import join_fields - vfld = Variable(name) + vfld = var_class(name) return join_fields(*[vfld.index(i) for i in components]) -- GitLab