diff --git a/pymbolic/geometric_algebra/__init__.py b/pymbolic/geometric_algebra/__init__.py index 6eca83a1f48c34cfa160db945b3fef14af8c6f40..ac9df36cf27a52c516ed64c6cb6411a5aebd968c 100644 --- a/pymbolic/geometric_algebra/__init__.py +++ b/pymbolic/geometric_algebra/__init__.py @@ -542,7 +542,10 @@ class MultiVector(Generic[CoeffT]): def __init__( self, - data: Mapping[tuple[int, ...] | int, CoeffT] | np.ndarray | CoeffT, + data: (Mapping[int, CoeffT] + | Mapping[tuple[int, ...], CoeffT] + | np.ndarray + | CoeffT), space: Space | None = None ) -> None: """ @@ -562,36 +565,39 @@ class MultiVector(Generic[CoeffT]): works when a :class:`numpy.ndarray` is being passed for *data*. """ - dimensions = None - + data_dict: Mapping if isinstance(data, np.ndarray): if len(data.shape) != 1: - raise ValueError("only numpy vectors (not higher-rank objects) " - "are supported for 'data'") + raise ValueError( + "Only numpy vectors (not higher-rank objects) " + f"are supported for 'data': shape {data.shape}") + dimensions, = data.shape - data = {(i,): xi for i, xi in enumerate(data)} - elif isinstance(data, dict): - pass + data_dict = {(i,): cast(CoeffT, xi) for i, xi in enumerate(data)} + + if space is None: + space = get_euclidean_space(dimensions) + + if space.dimensions != dimensions: + raise ValueError( + "Dimension of 'space' does not match that of 'data': " + f"got {space.dimensions}d space but expected {dimensions}d") + elif isinstance(data, Mapping): + data_dict = data else: - data = {0: cast(CoeffT, data)} + data_dict = {0: cast(CoeffT, data)} if space is None: - assert isinstance(dimensions, int) - space = get_euclidean_space(dimensions) - else: - if dimensions is not None and space.dimensions != dimensions: - raise ValueError( - "dimension count of 'space' does not match that of 'data'") + raise ValueError("No 'space' provided") # {{{ normalize data to bitmaps, if needed from pytools import single_valued - from pymbolic.primitives import is_zero - if data and single_valued(isinstance(k, tuple) for k in data.keys()): + if data_dict and single_valued(isinstance(k, tuple) for k in data_dict.keys()): # data is in non-normalized non-bits tuple form new_data: dict[int, CoeffT] = {} - for basis_indices, coeff in data.items(): + for basis_indices, coeff in data_dict.items(): assert isinstance(basis_indices, tuple) bits, sign = space.bits_and_sign(basis_indices) @@ -604,7 +610,7 @@ class MultiVector(Generic[CoeffT]): else: new_data[bits] = new_coeff else: - new_data = cast(dict[int, CoeffT], data) + new_data = cast(dict[int, CoeffT], data_dict) # }}}