From 9943f855cb9966656e3c999af8c2d18be77b73c8 Mon Sep 17 00:00:00 2001
From: xywei <wxy0516@gmail.com>
Date: Sat, 23 May 2020 19:22:52 -0500
Subject: [PATCH] Add basic contracts

---
 pytato/array.py    | 25 +++++++++++++++++++------
 pytato/contract.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
 setup.py           |  1 +
 3 files changed, 64 insertions(+), 6 deletions(-)
 create mode 100644 pytato/contract.py

diff --git a/pytato/array.py b/pytato/array.py
index 4a52ca8..beb6737 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -34,6 +34,9 @@ is referenced from :class:`DataArray`.
 
 import collections.abc
 from pytools import single_valued, is_single_valued
+from contracts import contract
+from pymbolic.primitives import Expression  # noqa
+from pytato.contract import c_identifier, ArrayInterface  # noqa
 
 
 class DottedName:
@@ -41,13 +44,15 @@ class DottedName:
     .. attribute:: name_parts
 
         A tuple of strings, each of which is a valid
-        Python identifier.
+        C identifier (non-Unicode Python identifier).
 
     The name (at least morally) exists in the
     name space defined by the Python module system.
     It need not necessarily identify an importable
     object.
     """
+
+    @contract(name_parts='list[>0](str,c_identifier)')
     def __init__(self, name_parts):
         self.name_parts = name_parts
 
@@ -65,13 +70,14 @@ class Namespace:
     def __init__(self):
         self.symbol_table = {}
 
+    @contract(name='str,c_identifier', value=ArrayInterface)
     def assign(self, name, value):
         if name in self.symbol_table:
             raise ValueError(f"'{name}' is already assigned")
         self.symbol_table[name] = value
 
 
-class Array:
+class Array(ArrayInterface):
     """
     A base class (abstract interface +
     supplemental functionality) for lazily
@@ -150,6 +156,7 @@ class Array:
         purposefully so.
     """
 
+    @contract(namespace=Namespace, name='str,c_identifier')
     def __init__(self, namespace, name, tags=None):
         if tags is None:
             tags = {}
@@ -172,6 +179,7 @@ class Array:
     def ndim(self):
         return len(self.shape)
 
+    @contract(dotted_name=DottedName)
     def with_tag(self, dotted_name, args=None):
         """
         Returns a copy of *self* tagged with *dotted_name*
@@ -182,9 +190,11 @@ class Array:
         if args is None:
             pass
 
+    @contract(dotted_name=DottedName)
     def without_tag(self, dotted_name):
         pass
 
+    @contract(name='str,c_identifier')
     def with_name(self, name):
         self.namespace.assign_name(name, self)
         return self.copy(name=name)
@@ -196,7 +206,7 @@ class Array:
 
 
 class DictOfNamedArrays(collections.abc.Mapping):
-    """A container that maps valid Python identifiers
+    """A container that maps valid C identifiers
     to instances of :class:`Array`. May occur as a result
     type of array computations.
 
@@ -211,12 +221,10 @@ class DictOfNamedArrays(collections.abc.Mapping):
         arithmetic.
     """
 
+    @contract(data='dict((str,c_identifier):$ArrayInterface)')
     def __init__(self, data):
         self._data = data
-        # TODO: Check that keys are valid Python identifiers
 
-        if not is_single_valued(ary.target for ary in data.values()):
-            raise ValueError("arrays do not have same target")
         if not is_single_valued(ary.namespace for ary in data.values()):
             raise ValueError("arrays do not have same namespace")
 
@@ -224,9 +232,11 @@ class DictOfNamedArrays(collections.abc.Mapping):
     def namespace(self):
         return single_valued(ary.namespace for ary in self._data.values())
 
+    @contract(name='str,c_identifier')
     def __contains__(self, name):
         return name in self._data
 
+    @contract(name='str,c_identifier')
     def __getitem__(self, name):
         return self._data[name]
 
@@ -301,6 +311,9 @@ class Placeholder(Array):
         # Not tied to this, open for discussion about how to implement this.
         return self._shape
 
+    # TODO: proper contracts for the shape
+    @contract(namespace=Namespace, name='str,c_identifier',
+              shape=tuple, tags='None | dict($DottedName:$DottedName)')
     def __init__(self, namespace, name, shape, tags=None):
         if name is None:
             raise ValueError("PlaceholderArray instances must have a name")
diff --git a/pytato/contract.py b/pytato/contract.py
new file mode 100644
index 0000000..941bad9
--- /dev/null
+++ b/pytato/contract.py
@@ -0,0 +1,44 @@
+__copyright__ = """
+Copyright (C) 2020 Andreas Kloeckner
+Copyright (C) 2020 Matt Wala
+Copyright (C) 2020 Xiaoyu Wei
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import re
+from contracts import new_contract, ContractsMeta
+
+
+C_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
+
+
+@new_contract
+def c_identifier(x):
+    """Tests whether a string is a valid variable name in C.
+    """
+    return re.match(C_IDENTIFIER, x) is not None
+
+
+class ArrayInterface():
+    """Abstract class for types implementing the Array interface.
+    """
+    __metaclass__ = ContractsMeta
diff --git a/setup.py b/setup.py
index c56a780..153fe38 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@ setup(name="pytato",
 
       install_requires=[
           "loo.py",
+          "pycontracts",
           ],
 
       author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei",
-- 
GitLab