Skip to content
Snippets Groups Projects
Commit 5d57dd29 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

make_sym_{array,vector}: var_class -> var_factory

parent 87b759df
No related branches found
No related tags found
No related merge requests found
Pipeline #347 passed with stage
...@@ -160,6 +160,7 @@ In particular, :mod:`numpy` object arrays are useful for capturing ...@@ -160,6 +160,7 @@ In particular, :mod:`numpy` object arrays are useful for capturing
vectors and matrices of :mod:`pymbolic` objects. vectors and matrices of :mod:`pymbolic` objects.
.. autofunction:: make_sym_vector .. autofunction:: make_sym_vector
.. autofunction:: make_sym_array
""" """
...@@ -1648,26 +1649,30 @@ def make_common_subexpression(field, prefix=None, scope=None): ...@@ -1648,26 +1649,30 @@ def make_common_subexpression(field, prefix=None, scope=None):
return CommonSubexpression(field, prefix, scope) return CommonSubexpression(field, prefix, scope)
def make_sym_vector(name, components, var_class=None): def make_sym_vector(name, components, var_factory=None):
"""Return an object array of *components* subscripted """Return an object array of *components* subscripted
:class:`Variable` (or subclass) instances. :class:`Variable` (or subclass) instances.
:arg components: The number of components in the vector. :arg components: The number of components in the vector.
:arg var_class: The :class:`Variable` subclass to use for instantiating :arg var_factory: The :class:`Variable` subclass to use for instantiating
the scalar variables. the scalar variables.
""" """
if var_factory is None:
var_factory = Variable
if isinstance(components, int): if isinstance(components, int):
components = list(range(components)) components = list(range(components))
if var_class is None:
var_class = Variable
from pytools.obj_array import join_fields from pytools.obj_array import join_fields
vfld = var_class(name) vfld = var_factory(name)
return join_fields(*[vfld.index(i) for i in components]) return join_fields(*[vfld.index(i) for i in components])
def make_sym_array(name, shape): def make_sym_array(name, shape, var_factory=None):
vfld = Variable(name) if var_factory is None:
var_factory = Variable
vfld = var_factory(name)
if shape == (): if shape == ():
return vfld return vfld
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment