Skip to content
Snippets Groups Projects
Commit 53d0aacf authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Implement loopy define blocks

parent 851debc1
No related branches found
No related tags found
No related merge requests found
!$loopy begin define
! define("factor 4.0")
! define("real_type real*8")
!$loopy end define
subroutine fill(out, a, n)
implicit none
real*8 a, out(n)
real_type a, out(n)
integer n, i
do i = 1, n
out(i) = a
end do
do i = 1, n
out(i) = out(i) * 2
out(i) = out(i) * factor
end do
end
......
......@@ -25,10 +25,56 @@ THE SOFTWARE.
from loopy.diagnostic import LoopyError
def _extract_define_lines(source):
lines = source.split("\n")
import re
comment_re = re.compile(r"^\s*\!(.*)$")
remaining_lines = []
define_lines = []
in_define_code = False
for l in lines:
comment_match = comment_re.match(l)
if comment_match is None:
if in_define_code:
raise LoopyError("non-comment source line in define block")
remaining_lines.append(l)
continue
cmt = comment_match.group(1)
cmt_stripped = cmt.strip()
if cmt_stripped == "$loopy begin define":
if in_define_code:
raise LoopyError("can't enter transform code twice")
in_define_code = True
elif cmt_stripped == "$loopy end define":
if not in_define_code:
raise LoopyError("can't leave transform code twice")
in_define_code = False
elif in_define_code:
define_lines.append(cmt)
else:
remaining_lines.append(l)
return "\n".join(remaining_lines), "\n".join(define_lines)
def f2loopy(source, free_form=True, strict=True,
pre_transform_code=None, pre_transform_code_context=None,
use_c_preprocessor=False,
pre_transform_code=None, transform_code_context=None,
use_c_preprocessor=False, preprocessor_defines=None,
file_name="<floopy code>"):
"""
:arg preprocessor_defines: a list of strings as they might occur after a
C-style ``#define`` directive, for example ``deg2rad(x) (x/180d0 * 3.14d0)``.
"""
if use_c_preprocessor:
try:
import ply.lex as lex
......@@ -40,6 +86,28 @@ def f2loopy(source, free_form=True, strict=True,
from ply.cpp import Preprocessor
p = Preprocessor(lexer)
if preprocessor_defines:
for d in preprocessor_defines:
p.define(d)
source, define_code = _extract_define_lines(source)
if define_code is not None:
from loopy.tools import remove_common_indentation
define_code = remove_common_indentation(
define_code,
require_leading_newline=False)
def_dict = {}
def_dict["define"] = p.define
if pre_transform_code is not None:
def_dict["_MODULE_SOURCE_CODE"] = pre_transform_code
exec(compile(pre_transform_code,
"<loopy pre-transform code>", "exec"), def_dict)
def_dict["_MODULE_SOURCE_CODE"] = define_code
exec(compile(define_code, "<loopy defines>", "exec"), def_dict)
p.parse(source, file_name)
tokens = []
......@@ -65,6 +133,6 @@ def f2loopy(source, free_form=True, strict=True,
f2loopy(tree)
return f2loopy.make_kernels(pre_transform_code=pre_transform_code,
pre_transform_code_context=pre_transform_code_context)
transform_code_context=transform_code_context)
# vim: foldmethod=marker
......@@ -194,24 +194,6 @@ class Scope(object):
# }}}
def remove_common_indentation(lines):
while lines and lines[0].strip() == "":
lines.pop(0)
while lines and lines[-1].strip() == "":
lines.pop(-1)
if lines:
base_indent = 0
while lines[0][base_indent] in " \t":
base_indent += 1
for line in lines[1:]:
if line[:base_indent].strip():
raise ValueError("inconsistent indentation")
return "\n".join(line[base_indent:] for line in lines)
# {{{ translator
class F2LoopyTranslator(FTreeWalkerBase):
......@@ -641,15 +623,15 @@ class F2LoopyTranslator(FTreeWalkerBase):
# }}}
def make_kernels(self, pre_transform_code=None, pre_transform_code_context=None):
def make_kernels(self, pre_transform_code=None, transform_code_context=None):
kernel_names = [
sub.subprogram_name
for sub in self.kernels]
if pre_transform_code_context is None:
if transform_code_context is None:
proc_dict = {}
else:
proc_dict = pre_transform_code_context.copy()
proc_dict = transform_code_context.copy()
proc_dict["lp"] = lp
proc_dict["np"] = np
......@@ -707,13 +689,15 @@ class F2LoopyTranslator(FTreeWalkerBase):
proc_dict[sub.subprogram_name] = lp.fold_constants(knl)
from loopy.tools import remove_common_indentation
transform_code = remove_common_indentation(
self.transform_code_lines)
"\n".join(self.transform_code_lines),
require_leading_newline=False)
if pre_transform_code is not None:
proc_dict["_MODULE_SOURCE_CODE"] = pre_transform_code
exec(compile(pre_transform_code,
"<loopy transforms>", "exec"), proc_dict)
"<loopy pre-transform code>", "exec"), proc_dict)
proc_dict["_MODULE_SOURCE_CODE"] = transform_code
exec(compile(transform_code,
......
......@@ -164,14 +164,14 @@ class PicklableDtype(object):
# {{{ remove common indentation
def remove_common_indentation(code):
def remove_common_indentation(code, require_leading_newline=True):
if "\n" not in code:
return code
# accommodate pyopencl-ish syntax highlighting
code = code.lstrip("//CL//")
if not code.startswith("\n"):
if require_leading_newline and not code.startswith("\n"):
return code
lines = code.split("\n")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment