Skip to content
Snippets Groups Projects
Commit 0c903968 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge branch 'targetderivative' into 'master'

Fix get_source_args for DirectionalTargetDerivative

See merge request !119
parents fa84c3a9 8b759635
No related branches found
No related tags found
1 merge request!119Fix get_source_args for DirectionalTargetDerivative
Pipeline #20217 passed with warnings
......@@ -93,6 +93,22 @@ class KernelArgument(object):
def name(self):
return self.loopy_arg.name
def __eq__(self, other):
if id(self) == id(other):
return True
if not type(self) == KernelArgument:
return NotImplemented
if not type(other) == KernelArgument:
return NotImplemented
return self.loopy_arg == other.loopy_arg
def __ne__(self, other):
# Needed for python2
return not self == other
def __hash__(self):
return (type(self), self.loopy_arg)
# {{{ basic kernel interface
......@@ -853,17 +869,6 @@ class DirectionalDerivative(DerivativeBase):
self.inner_kernel,
self.dir_vec_name)
def get_source_args(self):
return [
KernelArgument(
loopy_arg=lp.GlobalArg(
self.dir_vec_name,
None,
shape=(self.dim, "nsources"),
dim_tags="sep,C"),
)
] + self.inner_kernel.get_source_args()
class DirectionalTargetDerivative(DirectionalDerivative):
directional_kind = "tgt"
......@@ -892,6 +897,17 @@ class DirectionalTargetDerivative(DirectionalDerivative):
return sum(dir_vec[axis]*expr.diff(bvec[axis])
for axis in range(dim))
def get_source_args(self):
return [
KernelArgument(
loopy_arg=lp.GlobalArg(
self.dir_vec_name,
None,
shape=(self.dim, "ntargets"),
dim_tags="sep,C"),
)
] + self.inner_kernel.get_source_args()
mapper_method = "map_directional_target_derivative"
......@@ -923,6 +939,17 @@ class DirectionalSourceDerivative(DirectionalDerivative):
return sum(-dir_vec[axis]*expr.diff(avec[axis])
for axis in range(dimensions))
def get_source_args(self):
return [
KernelArgument(
loopy_arg=lp.GlobalArg(
self.dir_vec_name,
None,
shape=(self.dim, "nsources"),
dim_tags="sep,C"),
)
] + self.inner_kernel.get_source_args()
mapper_method = "map_directional_source_derivative"
# }}}
......
......@@ -205,12 +205,18 @@ def vector_from_device(queue, vec):
return with_object_array_or_scalar(from_dev, vec)
def _merge_kernel_arguments(dictionary, arg):
# Check for strict equality until there's a usecase
if dictionary.setdefault(arg.name, arg) != arg:
msg = "Merging two different kernel arguments {} and {} with the same name"
raise ValueError(msg.format(arg.loopy_arg, dictionary[arg].loopy_arg))
def gather_arguments(kernel_likes):
result = {}
for knl in kernel_likes:
for arg in knl.get_args():
result[arg.name] = arg
# FIXME: possibly check that arguments match before overwriting
_merge_kernel_arguments(result, arg)
return sorted(six.itervalues(result), key=lambda arg: arg.name)
......@@ -219,8 +225,7 @@ def gather_source_arguments(kernel_likes):
result = {}
for knl in kernel_likes:
for arg in knl.get_args() + knl.get_source_args():
result[arg.name] = arg
# FIXME: possibly check that arguments match before overwriting
_merge_kernel_arguments(result, arg)
return sorted(six.itervalues(result), key=lambda arg: arg.name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment