Skip to content
Snippets Groups Projects

test_opencl_counter_pytential.py

  • Clone with SSH
  • Clone with HTTPS
  • Embed
  • Share
    The snippet can be accessed without any authentication.
    Authored by Hao Gao
    snippetfile1.txt 2.06 KiB
    import numpy as np
    import pyopencl as cl
    import pytest
    
    # Order handed to make_curve_mesh (order=) and to the element group factory.
    TARGET_ORDER = 8
    # Multiplied by TARGET_ORDER to form the second positional argument of
    # QBXLayerPotentialSource.
    OVSMP_FACTOR = 5
    # Passed to QBXLayerPotentialSource as qbx_order=.
    QBX_ORDER = 5
    # Passed to QBXLayerPotentialSource as fmm_order=.
    FMM_ORDER = 10
    # Passed to QBXLayerPotentialSource as _expansion_stick_out_factor=.
    TCF = 0.9
    
    # Created once at import time; used by the test below for both the layer
    # potential source and the command queue.
    cl_context = cl.create_some_context()
    
    
    def get_starfish_mesh(npoints):
        """Return a curve mesh of meshmode's ``starfish`` curve.

        *npoints* equispaced values on [0, 1] are handed to
        ``make_curve_mesh`` together with ``order=TARGET_ORDER``.
        """
        from meshmode.mesh.generation import make_curve_mesh, starfish

        param_values = np.linspace(0, 1, npoints)
        return make_curve_mesh(starfish, param_values, order=TARGET_ORDER)
    
    
    def get_lpot_source(context, mesh):
        """Return a refined QBX layer potential source built on *mesh*.

        A ``Discretization`` of *mesh* is created in *context* using an
        ``InterpolatoryQuadratureSimplexGroupFactory`` of order
        ``TARGET_ORDER``. It is wrapped in a ``QBXLayerPotentialSource``
        configured from the module-level constants, and
        ``with_refinement()`` is called on the result. Only the first item
        of the pair returned by ``with_refinement()`` is handed back.
        """
        from meshmode.discretization import Discretization
        from meshmode.discretization.poly_element import (
            InterpolatoryQuadratureSimplexGroupFactory as GroupFactory,
        )

        base_discr = Discretization(context, mesh, GroupFactory(TARGET_ORDER))

        from pytential.qbx import QBXLayerPotentialSource

        unrefined_source = QBXLayerPotentialSource(
            base_discr,
            OVSMP_FACTOR * TARGET_ORDER,
            qbx_order=QBX_ORDER,
            fmm_order=FMM_ORDER,
            _expansion_stick_out_factor=TCF,
        )
        refined_source, _ = unrefined_source.with_refinement()

        return refined_source
    
    
    @pytest.mark.parametrize("off_surface", (True, False))
    def test_opencl_implementation_against_python(mesh, off_surface):
        """Build QBX FMM geometry data for off-surface or on-surface targets.

        :arg mesh: mesh handed to :func:`get_lpot_source`.
            NOTE(review): no ``mesh`` fixture or parametrization is visible
            in this file; when collected by pytest it has to be supplied
            from elsewhere (e.g. a ``conftest.py``) -- confirm.
        :arg off_surface: if true, the targets are ``10 ** 3`` points produced
            by ``make_uniform_particle_array`` and wrapped in a
            ``PointsTarget``, paired with QBX side ``0``; otherwise the
            targets are the source's own ``density_discr``, paired with QBX
            side ``1``.
        """
        lpot_source = get_lpot_source(cl_context, mesh)

        # {{{ Construct targets

        if off_surface:
            from pytential.target import PointsTarget
            from boxtree.tools import make_uniform_particle_array

            ntargets = 10 ** 3
            dim = mesh.ambient_dim

            queue = cl.CommandQueue(cl_context)
            # np.float was only an alias of the builtin float (i.e. float64);
            # it was deprecated in NumPy 1.20 and removed in 1.24, so the
            # explicit dtype is spelled out here.
            targets = PointsTarget(
                make_uniform_particle_array(queue, ntargets, dim, np.float64)
            )

            target_discrs_and_qbx_sides = ((targets, 0),)

        else:
            targets = lpot_source.density_discr
            target_discrs_and_qbx_sides = ((targets, 1),)

        # }}}

        # The returned geometry data is not inspected yet; the test passes as
        # long as this call completes without raising.
        lpot_source.qbx_fmm_geometry_data(
            target_discrs_and_qbx_sides=target_discrs_and_qbx_sides
        )
    
    
    def main():
        """Run the off-surface test case on ``get_starfish_mesh(100)``."""
        test_opencl_implementation_against_python(
            get_starfish_mesh(100), off_surface=True
        )
    
    
    # Script entry point: with a command-line argument, that argument is
    # executed as Python source in this module's namespace; without one,
    # main() runs.
    if __name__ == "__main__":
        import sys
        if len(sys.argv) > 1:
            # NOTE(review): exec() runs whatever string is passed on the
            # command line -- only invoke this script with trusted arguments.
            exec(sys.argv[1])
        else:
            main()
    
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or sign in to comment