diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 018cbaae75ac93009fbb62901df9baf7b8ce653a..5855f5d7158d1dc4883981ca1a7990d9bbbb528b 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -1700,14 +1700,23 @@ def make_sym_vector(name, components, var_factory=None): """Return an object array of *components* subscripted :class:`Variable` (or subclass) instances. - :arg components: The number of components in the vector. - :arg var_factory: The :class:`Variable` subclass to use for instantiating - the scalar variables. + :arg components: Either a list of indices, or an integer representing the + number of indices. + :arg var_factory: The :class:`Variable` subclass to + use for instantiating the scalar variables. + + For example, this creates a vector with three components:: + + >>> make_sym_vector("vec", 3) + array([Subscript(Variable('vec'), 0), Subscript(Variable('vec'), 1), + Subscript(Variable('vec'), 2)], dtype=object) + """ if var_factory is None: var_factory = Variable - if isinstance(components, int): + from numbers import Integral + if isinstance(components, Integral): components = list(range(components)) from pytools.obj_array import join_fields diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 2cc3c82738d5a34c28b7f3ff61b1351cb7e0d581..4b358ab9f95e9702459cb5942e6935ccf6590549 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -574,6 +574,15 @@ def test_flop_counter(): assert CSEAwareFlopCounter()(expr) == 4 + 2 +def test_make_sym_vector(): + numpy = pytest.importorskip("numpy") + from pymbolic.primitives import make_sym_vector + + assert len(make_sym_vector("vec", 2)) == 2 + assert len(make_sym_vector("vec", numpy.int32(2))) == 2 + assert len(make_sym_vector("vec", [1, 2, 3])) == 3 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: