diff --git a/loopy/kernel.py b/loopy/kernel.py index 5add79c782e12ca14ca223903c5f759f2398e4c9..d1eb4edddd77f809b84bc6e4a4aa0cf78d1b826f 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -902,13 +902,23 @@ class LoopKernel(Record): "instance or a parseable string. got '%s' instead." % type(insn)) - for insn in expand_defines(insn, defines, single_valued=False): - parse_insn(insn) + for insn in insn.split("\n"): + insn = insn.strip() + + if not insn: + continue + if insn.startswith("#"): + continue + + for sub_insn in expand_defines(insn, defines, single_valued=False): + parse_insn(sub_insn) parsed_instructions = [] substitutions = substitutions.copy() + if isinstance(instructions, str): + instructions = [instructions] for insn in instructions: # must construct list one-by-one to facilitate unique id generation parse_if_necessary(insn) @@ -979,13 +989,15 @@ class LoopKernel(Record): processed_args = [] for arg in args: - if isinstance(arg, _ShapedArg): - if arg.shape is not None: - arg = arg.copy(shape=expand_defines_in_expr(arg.shape, defines)) - if arg.strides is not None: - arg = arg.copy(strides=expand_defines_in_expr(arg.strides, defines)) - - processed_args.append(arg) + for name in arg.name.split(","): + new_arg = arg.copy(name=name) + if isinstance(arg, _ShapedArg): + if arg.shape is not None: + new_arg = new_arg.copy(shape=expand_defines_in_expr(arg.shape, defines)) + if arg.strides is not None: + new_arg = new_arg.copy(strides=expand_defines_in_expr(arg.strides, defines)) + + processed_args.append(new_arg) # }}} diff --git a/test/test_linalg.py b/test/test_linalg.py index 84878d97221fd3b062bd6d7393f216e8edae7b82..b6087f4429534d11748a6b307b1ae14ed237385e 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -273,9 +273,10 @@ def test_variable_size_matrix_mul(ctx_factory): def test_rank_one(ctx_factory): dtype = np.float32 ctx = ctx_factory() - order = "C" + order = "F" - n = int(get_suitable_size(ctx)**(2.7/2)) + #n = int(get_suitable_size(ctx)**(2.7/2)) + n = 16**3 knl = lp.make_kernel(ctx.devices[0], "[n] -> {[i,j]: 0<=i,j