Skip to content
Snippets Groups Projects
test_grudge.py 35.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • Matt Wala's avatar
    Matt Wala committed
        from grudge.symbolic.mappers import BoundOperatorCollector
    
        class TestBoundOperatorCollector(BoundOperatorCollector):
    
            def map_test_operator(self, expr):
                return self.map_operator(expr)
    
        v0 = sym.var("v0")
        ob0 = sym.OperatorBinding(TestOperator(), v0)
    
        v1 = sym.var("v1")
        ob1 = sym.OperatorBinding(TestOperator(), v1)
    
        # The output order isn't significant, but it should always be the same.
        assert list(TestBoundOperatorCollector(TestOperator)(ob0 + ob1)) == [ob0, ob1]
    
    
    Matt Wala's avatar
    Matt Wala committed
    
    
    def test_bessel(ctx_factory):
        """Check the Bessel recurrence J_{n+1}(r) + J_{n-1}(r) = (2n/r) J_n(r).

        The residual of DLMF 10.6.1 is evaluated symbolically on a 2D volume
        discretization, and its L2 norm is required to be round-off sized.
        """
        cl_ctx = ctx_factory()
        queue = cl.CommandQueue(cl_ctx)
        actx = PyOpenCLArrayContext(queue)

        dims = 2

        from meshmode.mesh.generation import generate_regular_rect_mesh
        # The lower corner is at 0.1 (not 0) so that r > 0 and 2*n/r is finite.
        mesh = generate_regular_rect_mesh(
                a=(0.1,)*dims,
                b=(1.0,)*dims,
                n=(8,)*dims)

        discr = DGDiscretizationWithBoundaries(actx, mesh, order=3)

        nodes = sym.nodes(dims)
        r = sym.cse(sym.sqrt(nodes[0]**2 + nodes[1]**2))

        # https://dlmf.nist.gov/10.6.1
        n = 3
        bessel_zero = (
                sym.bessel_j(n+1, r)
                + sym.bessel_j(n-1, r)
                - 2*n/r * sym.bessel_j(n, r))

        z = bind(discr, sym.norm(2, bessel_zero))(actx)

        # The recurrence holds identically, so only floating-point round-off
        # should remain; without this check the test could never fail.
        assert z < 1e-15


    def test_external_call(ctx_factory):
        """Evaluate an operator that invokes a user-registered external function.

        A Python function is registered under the name ``"double"`` on the
        volume domain, and ``ones * 3 + double(ones)`` is bound and evaluated.
        Every resulting nodal value is expected to equal 5.
        """
        queue = cl.CommandQueue(ctx_factory())
        actx = PyOpenCLArrayContext(queue)

        # The first argument is accepted but unused here; the signature is
        # what this test hands to register_external_function below.
        def double(queue, x):
            return 2 * x

        from meshmode.mesh.generation import generate_regular_rect_mesh
        from grudge.function_registry import (
                base_function_registry, register_external_function)

        dims = 2
        mesh = generate_regular_rect_mesh(
                a=(0,) * dims, b=(1,) * dims, n=(4,) * dims)
        discr = DGDiscretizationWithBoundaries(actx, mesh, order=1)

        ones = sym.Ones(sym.DD_VOLUME)
        double_sym = sym.FunctionSymbol("double")
        op = ones * 3 + double_sym(ones)

        freg = register_external_function(
                base_function_registry,
                "double",
                implementation=double,
                dd=sym.DD_VOLUME)

        result = bind(discr, op, function_registry=freg)(actx, double=double)

        host_values = actx.to_numpy(flatten(result))
        assert (host_values == 5).all()


    @pytest.mark.parametrize("array_type", ["scalar", "vector"])
    def test_function_symbol_array(ctx_factory, array_type):
        """Check that the L2 norm of a scalar or vector variable is a float.

        :arg array_type: ``"scalar"`` binds *x* to one random DOF array;
            ``"vector"`` binds it to an object array of *dim* such arrays.
        """
        ctx = ctx_factory()
        queue = cl.CommandQueue(ctx)
        actx = PyOpenCLArrayContext(queue)

        from meshmode.mesh.generation import generate_regular_rect_mesh
        dim = 2
        mesh = generate_regular_rect_mesh(
                a=(-0.5,)*dim, b=(0.5,)*dim,
                n=(8,)*dim, order=4)

        discr = DGDiscretizationWithBoundaries(actx, mesh, order=4)
        volume_discr = discr.discr_from_dd(sym.DD_VOLUME)
        ndofs = sum(grp.ndofs for grp in volume_discr.groups)

        import pyopencl.clrandom        # noqa: F401

        def rand_dof_array():
            # Uniform random values on the volume discretization. np.float was
            # merely an alias of the builtin float (float64) and was removed in
            # NumPy 1.24, so spell out the dtype.
            return unflatten(actx, volume_discr,
                    cl.clrandom.rand(queue, ndofs, dtype=np.float64))

        if array_type == "scalar":
            sym_x = sym.var("x")
            x = rand_dof_array()
        elif array_type == "vector":
            sym_x = sym.make_sym_array("x", dim)
            x = make_obj_array([rand_dof_array() for _ in range(dim)])
        else:
            raise ValueError("unknown array type")

        # Pass the array context explicitly, as the other tests in this file do.
        norm = bind(discr, sym.norm(2, sym_x))(actx, x=x)

        assert isinstance(norm, float)


    # Individual test routines can be run directly, e.g.
    #
    # $ python test_grudge.py 'test_routine()'

    if __name__ == "__main__":
        import sys
        if len(sys.argv) <= 1:
            # No argument given: run the whole module under pytest.
            pytest.main([__file__])
        else:
            # NOTE: the argument is executed as Python source; this is meant
            # only for interactive developer use from a trusted command line.
            exec(sys.argv[1])
    
    
    # vim: fdm=marker