diff --git a/loopy/check.py b/loopy/check.py index 01a6e52c268a5c4e76f94653490d05c345e1a7d0..354c6bf3542e17367d602702b046180a5ccc85db 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -1,7 +1,4 @@ -from __future__ import division -from __future__ import absolute_import -import six -from six.moves import range +from __future__ import absolute_import, division, print_function __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" @@ -25,6 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six +from six.moves import range from islpy import dim_type import islpy as isl @@ -484,10 +483,14 @@ def check_implemented_domains(kernel, implemented_domains, code=None): .project_out_except(insn_inames, [dim_type.set])) insn_domain = kernel.get_inames_domain(insn_inames) + insn_parameters = frozenset(insn_domain.get_var_names(dim_type.param)) assumptions, insn_domain = align_two(assumption_non_param, insn_domain) desired_domain = ((insn_domain & assumptions) - .project_out_except(insn_inames, [dim_type.set])) + .project_out_except(insn_inames, [dim_type.set]) + .project_out_except(insn_parameters, [dim_type.param])) + insn_impl_domain = (insn_impl_domain + .project_out_except(insn_parameters, [dim_type.param])) insn_impl_domain, desired_domain = align_two( insn_impl_domain, desired_domain)