diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 785feafba74cc88b4b98bf13daf8487e57e03a62..18e6cefc2c07ef8bfe614076be5f5728e101f98b 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -26,7 +26,11 @@ from functools import reduce import pymbolic.primitives as prim from pymbolic import parse from pytools.lex import ParseError -from pymbolic.mapper import IdentityMapper +from testlib import generate_random_expression + + +from pymbolic.mapper import IdentityMapper, WalkMapper +from pymbolic.mapper.dependency import DependencyMapper, CachedDependencyMapper import logging logger = logging.getLogger(__name__) @@ -332,8 +336,6 @@ def test_mappers(): for expr in [ f(x, (y, z), name=z**2) ]: - from pymbolic.mapper import WalkMapper - from pymbolic.mapper.dependency import DependencyMapper str(expr) IdentityMapper()(expr) WalkMapper()(expr) @@ -347,7 +349,6 @@ def test_mappers(): def test_func_dep_consistency(): from pymbolic import var - from pymbolic.mapper.dependency import DependencyMapper f = var("f") x = var("x") dep_map = DependencyMapper(include_calls="descend_args") @@ -863,6 +864,70 @@ def test_equality_complexity(): # }}} +# {{{ test_cached_mapper_memoizes + +class InCacheVerifier(WalkMapper): + def __init__(self, cached_mapper, walk_call_functions=True): + super().__init__() + self.cached_mapper = cached_mapper + self.walk_call_functions = walk_call_functions + + def post_visit(self, expr): + if isinstance(expr, prim.Expression): + assert (self.cached_mapper.get_cache_key(expr) + in self.cached_mapper._cache) + + def map_call(self, expr): + if not self.visit(expr): + return + + if self.walk_call_functions: + self.rec(expr.function) + + for child in expr.parameters: + self.rec(child) + + self.post_visit(expr) + + +def test_cached_mapper_memoizes(): + from testlib import (AlwaysFlatteningIdentityMapper, + AlwaysFlatteningCachedIdentityMapper) + ntests = 40 + for i in range(ntests): + expr = generate_random_expression(seed=(5+i)) + + # {{{ always flattening identity mapper + + # Note: Prefer AlwaysFlatteningIdentityMapper over IdentityMapper as + # the flattening logic in IdentityMapper checks for identity across + # traversal results => leading to discrepancy b/w + # 'CachedIdentityMapper' and 'IdentityMapper' + + cached_mapper = AlwaysFlatteningCachedIdentityMapper() + uncached_mapper = AlwaysFlatteningIdentityMapper() + assert uncached_mapper(expr) == cached_mapper(expr) + verifier = InCacheVerifier(cached_mapper) + verifier(expr) + + # }}} + + # {{{ dependency mapper + + mapper = DependencyMapper(include_calls="descend_args") + cached_mapper = CachedDependencyMapper(include_calls="descend_args") + assert cached_mapper(expr) == mapper(expr) + verifier = InCacheVerifier(cached_mapper, + # dep. mapper does not go over functions + walk_call_functions=False + ) + verifier(expr) + + # }}} + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: diff --git a/test/testlib.py b/test/testlib.py new file mode 100644 index 0000000000000000000000000000000000000000..bd75c444e56dcf1670ea6ba7376cbb354c07ce0d --- /dev/null +++ b/test/testlib.py @@ -0,0 +1,125 @@ +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +import numpy as np + +import pymbolic.primitives as prim +from dataclasses import dataclass, replace +from pytools import UniqueNameGenerator +from pymbolic.mapper import IdentityMapper, CachedIdentityMapper + + +@dataclass(frozen=True, eq=True) +class RandomExpressionGeneratorContext: + rng: np.random.Generator + vng: UniqueNameGenerator + current_depth: int + max_depth: int + + def with_increased_depth(self): + return replace(self, current_depth=self.current_depth+1) + + +def _generate_random_expr_inner( + context: RandomExpressionGeneratorContext) -> prim.Expression: + + if context.current_depth >= context.max_depth: + # force expression to be a leaf type + return context.rng.integers(0, 42) + + bucket = context.rng.integers(0, 100) / 100.0 + + # {{{ set some distribution of expression types + + # 'weight' is proportional to the probability of seeing an expression type + weights = [1, 1, 1, 1, 1] + expr_types = [prim.Variable, prim.Sum, prim.Product, prim.Quotient, + prim.Call] + assert len(weights) == len(expr_types) + + # }}} + + buckets = np.cumsum(weights, dtype="float64")/np.sum(weights) + + expr_type = expr_types[np.searchsorted(buckets, bucket)] + + if expr_type == prim.Variable: + return prim.Variable(context.vng("x")) + elif expr_type in [prim.Sum, prim.Product]: + left = _generate_random_expr_inner(context.with_increased_depth()) + right = _generate_random_expr_inner(context.with_increased_depth()) + return expr_type((left, right)) + elif expr_type == prim.Quotient: + num = _generate_random_expr_inner(context.with_increased_depth()) + den = _generate_random_expr_inner(context.with_increased_depth()) + return prim.Quotient(num, den) + elif expr_type == prim.Quotient: + num = _generate_random_expr_inner(context.with_increased_depth()) + den = _generate_random_expr_inner(context.with_increased_depth()) + return prim.Quotient(num, den) + elif expr_type == prim.Call: + nargs = 3 + return prim.Variable(context.vng("f"))( + *[_generate_random_expr_inner(context.with_increased_depth()) + for _ in range(nargs)]) + else: + raise NotImplementedError(expr_type) + + +def generate_random_expression(seed: int, max_depth: int = 8) -> prim.Expression: + from numpy.random import default_rng + rng = default_rng(seed) + vng = UniqueNameGenerator() + + context = RandomExpressionGeneratorContext(rng, + vng=vng, + max_depth=max_depth, + current_depth=0) + + return _generate_random_expr_inner(context) + + +# {{{ custom mappers for tests + +class _AlwaysFlatteningMixin: + def map_sum(self, expr, *args, **kwargs): + children = [self.rec(child, *args, **kwargs) for child in expr.children] + from pymbolic.primitives import flattened_sum + return flattened_sum(children) + + def map_product(self, expr, *args, **kwargs): + children = [self.rec(child, *args, **kwargs) for child in expr.children] + from pymbolic.primitives import flattened_product + return flattened_product(children) + + +class AlwaysFlatteningIdentityMapper(_AlwaysFlatteningMixin, + IdentityMapper): + pass + + +class AlwaysFlatteningCachedIdentityMapper(_AlwaysFlatteningMixin, + CachedIdentityMapper): + pass + +# }}}