diff --git a/sumpy/kernel.py b/sumpy/kernel.py index aebaa9fe76da1088f33dc38b3b950591732d3287..3936c518246aa8f14e54dba2f0b33d5361da04c6 100644 --- a/sumpy/kernel.py +++ b/sumpy/kernel.py @@ -93,6 +93,22 @@ class KernelArgument(object): def name(self): return self.loopy_arg.name + def __eq__(self, other): + if id(self) == id(other): + return True + if not type(self) == KernelArgument: + return NotImplemented + if not type(other) == KernelArgument: + return NotImplemented + return self.loopy_arg == other.loopy_arg + + def __ne__(self, other): + # Needed for python2 + return not self == other + + def __hash__(self): + return (type(self), self.loopy_arg) + # {{{ basic kernel interface @@ -853,17 +869,6 @@ class DirectionalDerivative(DerivativeBase): self.inner_kernel, self.dir_vec_name) - def get_source_args(self): - return [ - KernelArgument( - loopy_arg=lp.GlobalArg( - self.dir_vec_name, - None, - shape=(self.dim, "nsources"), - dim_tags="sep,C"), - ) - ] + self.inner_kernel.get_source_args() - class DirectionalTargetDerivative(DirectionalDerivative): directional_kind = "tgt" @@ -892,6 +897,17 @@ class DirectionalTargetDerivative(DirectionalDerivative): return sum(dir_vec[axis]*expr.diff(bvec[axis]) for axis in range(dim)) + def get_source_args(self): + return [ + KernelArgument( + loopy_arg=lp.GlobalArg( + self.dir_vec_name, + None, + shape=(self.dim, "ntargets"), + dim_tags="sep,C"), + ) + ] + self.inner_kernel.get_source_args() + mapper_method = "map_directional_target_derivative" @@ -923,6 +939,17 @@ class DirectionalSourceDerivative(DirectionalDerivative): return sum(-dir_vec[axis]*expr.diff(avec[axis]) for axis in range(dimensions)) + def get_source_args(self): + return [ + KernelArgument( + loopy_arg=lp.GlobalArg( + self.dir_vec_name, + None, + shape=(self.dim, "nsources"), + dim_tags="sep,C"), + ) + ] + self.inner_kernel.get_source_args() + mapper_method = "map_directional_source_derivative" # }}} diff --git a/sumpy/tools.py b/sumpy/tools.py index 540208f342e429326b7571168b754f02c0041d79..4d1098429d1c9db43af4a6e26fbb63509438a375 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -205,12 +205,18 @@ def vector_from_device(queue, vec): return with_object_array_or_scalar(from_dev, vec) +def _merge_kernel_arguments(dictionary, arg): + # Check for strict equality until there's a usecase + if dictionary.setdefault(arg.name, arg) != arg: + msg = "Merging two different kernel arguments {} and {} with the same name" + raise ValueError(msg.format(arg.loopy_arg, dictionary[arg].loopy_arg)) + + def gather_arguments(kernel_likes): result = {} for knl in kernel_likes: for arg in knl.get_args(): - result[arg.name] = arg - # FIXME: possibly check that arguments match before overwriting + _merge_kernel_arguments(result, arg) return sorted(six.itervalues(result), key=lambda arg: arg.name) @@ -219,8 +225,7 @@ def gather_source_arguments(kernel_likes): result = {} for knl in kernel_likes: for arg in knl.get_args() + knl.get_source_args(): - result[arg.name] = arg - # FIXME: possibly check that arguments match before overwriting + _merge_kernel_arguments(result, arg) return sorted(six.itervalues(result), key=lambda arg: arg.name)