diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 921337583f9898aa102b7f9f6cda70d06e1cb52b..dbfd8e9fe9bb0d469d0ad612157e14495c240b3f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,14 +26,14 @@ Python 2.7 Conda: reports: junit: test/pytest.xml -Python 3.7: +Python 3: script: - - py_version=3.7 + - PY_EXE=python3 - EXTRA_INSTALL="numpy sympy pexpect" - curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/master/build-and-test-py-project.sh - ". ./build-and-test-py-project.sh" tags: - - python3.7 + - python3 - maxima except: - tags @@ -41,7 +41,7 @@ Python 3.7: reports: junit: test/pytest.xml -Python 3.6 Conda: +Python 3 Conda: script: - CONDA_ENVIRONMENT=.test-py3.yml - curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/master/build-and-test-py-project-within-miniconda.sh @@ -54,7 +54,7 @@ Python 3.6 Conda: reports: junit: test/pytest.xml -Python 3.6 Apple: +Python 3 Conda Apple: script: - CONDA_ENVIRONMENT=.test-py3.yml - curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/master/build-and-test-py-project-within-miniconda.sh @@ -70,11 +70,11 @@ Python 3.6 Apple: Pylint: script: - EXTRA_INSTALL="numpy sympy symengine scipy pexpect" - - py_version=3.7 + - PY_EXE=python3 - curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh - ". ./prepare-and-run-pylint.sh pymbolic test/test_*.py" tags: - - python3.7 + - python3 except: - tags @@ -93,6 +93,6 @@ Flake8: - curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-flake8.sh - ". ./prepare-and-run-flake8.sh pymbolic test" tags: - - python3.5 + - python3 except: - tags diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 018cbaae75ac93009fbb62901df9baf7b8ce653a..8bc804249a427014faedbd5a0a48779ad6b699ef 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -1696,18 +1696,24 @@ def make_common_subexpression(field, prefix=None, scope=None): return CommonSubexpression(field, prefix, scope) -def make_sym_vector(name, components, var_factory=None): +def make_sym_vector(name, components, var_factory=Variable): """Return an object array of *components* subscripted :class:`Variable` (or subclass) instances. - :arg components: The number of components in the vector. - :arg var_factory: The :class:`Variable` subclass to use for instantiating - the scalar variables. - """ - if var_factory is None: - var_factory = Variable + :arg components: Either a list of indices, or an integer representing the + number of indices. + :arg var_factory: The :class:`Variable` subclass to + use for instantiating the scalar variables. + + For example, this creates a vector with three components:: + + >>> make_sym_vector("vec", 3) + array([Subscript(Variable('vec'), 0), Subscript(Variable('vec'), 1), + Subscript(Variable('vec'), 2)], dtype=object) - if isinstance(components, int): + """ + from numbers import Integral + if isinstance(components, Integral): components = list(range(components)) from pytools.obj_array import join_fields @@ -1715,10 +1721,7 @@ def make_sym_vector(name, components, var_factory=None): return join_fields(*[vfld.index(i) for i in components]) -def make_sym_array(name, shape, var_factory=None): - if var_factory is None: - var_factory = Variable - +def make_sym_array(name, shape, var_factory=Variable): vfld = var_factory(name) if shape == (): return vfld diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 2cc3c82738d5a34c28b7f3ff61b1351cb7e0d581..4b358ab9f95e9702459cb5942e6935ccf6590549 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -574,6 +574,15 @@ def test_flop_counter(): assert CSEAwareFlopCounter()(expr) == 4 + 2 +def test_make_sym_vector(): + numpy = pytest.importorskip("numpy") + from pymbolic.primitives import make_sym_vector + + assert len(make_sym_vector("vec", 2)) == 2 + assert len(make_sym_vector("vec", numpy.int32(2))) == 2 + assert len(make_sym_vector("vec", [1, 2, 3])) == 3 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: