diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 0ebe90fbca0d31c05eaee64321e2b73709292331..36fbb49f4bb77c959877fb0bd21e1de6fb49c74b 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -594,6 +594,10 @@ def get_simple_strides(bset, key_by="name"): """ result = {} + comp_div_set_pieces = convexify(bset.compute_divs()).get_basic_sets() + assert len(comp_div_set_pieces) == 1 + bset, = comp_div_set_pieces + lspace = bset.get_local_space() for idiv in range(lspace.dim(dim_type.div)): div = lspace.get_div(idiv) diff --git a/loopy/statistics.py b/loopy/statistics.py index fde8643bf92b7ad56bb47975fa7ede1bda9b399c..cb15eb55498bcafe4ae537747e387e47ddbd8254 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -66,16 +66,17 @@ class ToCountMap(object): """ - def __init__(self, init_dict=None): + def __init__(self, init_dict=None, val_type=isl.PwQPolynomial): if init_dict is None: init_dict = {} self.count_map = init_dict + self.val_type = val_type def __add__(self, other): result = self.count_map.copy() for k, v in six.iteritems(other.count_map): result[k] = self.count_map.get(k, 0) + v - return ToCountMap(result) + return ToCountMap(result, self.val_type) def __radd__(self, other): if other != 0: @@ -101,7 +102,11 @@ class ToCountMap(object): try: return self.count_map[index] except KeyError: - return isl.PwQPolynomial('{ 0 }') + #TODO what is the best way to handle this? + if self.val_type is isl.PwQPolynomial: + return isl.PwQPolynomial('{ 0 }') + else: + return 0 def __setitem__(self, index, value): self.count_map[index] = value @@ -112,6 +117,9 @@ class ToCountMap(object): def __len__(self): return len(self.count_map) + def get(self, key, default=None): + return self.count_map.get(key, default) + def items(self): return self.count_map.items() @@ -122,7 +130,7 @@ class ToCountMap(object): return self.count_map.pop(item) def copy(self): - return ToCountMap(dict(self.count_map)) + return ToCountMap(dict(self.count_map), self.val_type) def filter_by(self, **kwargs): """Remove items without specified key fields. @@ -149,7 +157,7 @@ class ToCountMap(object): """ - result_map = ToCountMap() + result_map = ToCountMap(val_type=self.val_type) from loopy.types import to_loopy_type if 'dtype' in kwargs.keys(): @@ -197,7 +205,7 @@ class ToCountMap(object): """ - result_map = ToCountMap() + result_map = ToCountMap(val_type=self.val_type) # for each item in self.count_map, call func on the key for self_key, self_val in self.items(): @@ -252,7 +260,7 @@ class ToCountMap(object): """ - result_map = ToCountMap() + result_map = ToCountMap(val_type=self.val_type) # make sure all item keys have same type if self.count_map: @@ -315,23 +323,36 @@ class ToCountMap(object): bytes_processed = int(key.dtype.itemsize) * val result[key] = bytes_processed + #TODO again, is this okay? + result.val_type = int + return result def sum(self): """Add all counts in ToCountMap. - :return: A :class:`islpy.PwQPolynomial` containing the sum of counts. + :return: A :class:`islpy.PwQPolynomial` or :class:`int` containing the sum of + counts. """ - total = isl.PwQPolynomial('{ 0 }') + + if self.val_type is isl.PwQPolynomial: + total = isl.PwQPolynomial('{ 0 }') + else: + total = 0 + for k, v in self.items(): - if not isinstance(v, isl.PwQPolynomial): - raise ValueError("ToCountMap: sum() encountered type {0} but " - "may only be used on PwQPolynomials." - .format(type(v))) total += v return total + #TODO test and document + def eval(self, params): + result = self.copy() + for key, val in self.items(): + result[key] = val.eval_with_dict(params) + result.val_type = int + return result + def eval_and_sum(self, params): """Add all counts in :class:`ToCountMap` and evaluate with provided parameter dict. diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index a19e06ecdf7c9966501ebb9600ea4e01614363f4..6077332c4fc4322ac7ffb02ade4a0e24c7066245 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -681,12 +681,18 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, dt, dim_idx = var_dict[primed_non1_saxis_names[i]] mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis) + def add_assumptions(d): + assumption_non_param = isl.BasicSet.from_params(kernel.assumptions) + assumptions, domain = isl.align_two(assumption_non_param, d) + return assumptions & domain + # {{{ check that we got the desired domain - check_domain = check_domain.project_out_except( - primed_non1_saxis_names, [isl.dim_type.set]) + check_domain = add_assumptions( + check_domain.project_out_except( + primed_non1_saxis_names, [isl.dim_type.set])) - mod_check_domain = mod_domain + mod_check_domain = add_assumptions(mod_domain) # re-add the prime from the new variable var_dict = mod_check_domain.get_var_dict(isl.dim_type.set) @@ -716,10 +722,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # project out the new names from the modified domain orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set)) - mod_check_domain = mod_domain.project_out_except( - orig_domain_inames, [isl.dim_type.set]) + mod_check_domain = add_assumptions( + mod_domain.project_out_except( + orig_domain_inames, [isl.dim_type.set])) - check_domain = domch.domain + check_domain = add_assumptions(domch.domain) mod_check_domain, check_domain = isl.align_two( mod_check_domain, check_domain)