diff --git a/pytential/qbx/target_assoc.py b/pytential/qbx/target_assoc.py index 22b2351ba82ce692cca49b5b73e1d1730604924b..5e1bcb1eb88fb7ed97ace3175c451beb8631dac8 100644 --- a/pytential/qbx/target_assoc.py +++ b/pytential/qbx/target_assoc.py @@ -35,6 +35,7 @@ import pyopencl.array # noqa from boxtree.tools import DeviceDataRecord from boxtree.area_query import AreaQueryElementwiseTemplate from boxtree.tools import InlineBinarySearch +from cgen import Enum from pytential.qbx.utils import ( QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS) @@ -85,34 +86,21 @@ logger = logging.getLogger(__name__) # {{{ kernels -TARGET_ASSOC_DEFINES = r""" -enum TargetStatus -{ - UNMARKED, - MARKED_QBX_CENTER_PENDING, - MARKED_QBX_CENTER_FOUND -}; - -enum TargetFlag -{ - INTERIOR_OR_EXTERIOR_VOLUME_TARGET = 0, - INTERIOR_SURFACE_TARGET = -1, - EXTERIOR_SURFACE_TARGET = +1, - INTERIOR_VOLUME_TARGET = -2, - EXTERIOR_VOLUME_TARGET = +2 -}; -""" - +class target_status_enum(Enum): # noqa + c_name = "TargetStatus" + dtype = np.int32 + c_value_prefix = "" -class target_status_enum(object): # noqa - # NOTE: Must match "enum TargetStatus" above UNMARKED = 0 MARKED_QBX_CENTER_PENDING = 1 MARKED_QBX_CENTER_FOUND = 2 -class target_flag_enum(object): # noqa - # NOTE: Must match "enum TargetFlag" above +class target_flag_enum(Enum): # noqa + c_name = "TargetFlag" + dtype = np.int32 + c_value_prefix = "" + INTERIOR_OR_EXTERIOR_VOLUME_TARGET = 0 INTERIOR_SURFACE_TARGET = -1 EXTERIOR_SURFACE_TARGET = +1 @@ -120,6 +108,16 @@ class target_flag_enum(object): # noqa EXTERIOR_VOLUME_TARGET = +2 +def _generate_enum_code(enum): + return "\n".join(enum.generate()) + + +TARGET_ASSOC_DEFINES = "".join([ + _generate_enum_code(target_status_enum), + _generate_enum_code(target_flag_enum), +]) + + QBX_TARGET_MARKER = AreaQueryElementwiseTemplate( extra_args=r""" /* input */