diff --git a/pytools/lex.py b/pytools/lex.py
index dfef0208144bbbea22e792c03525586a62d96387..2d65ce9fa388efdcc31120a1124d178a77a2107b 100644
--- a/pytools/lex.py
+++ b/pytools/lex.py
@@ -34,7 +34,7 @@ class ParseError(RuntimeError):
                            self.string[self.Token[2]:self.Token[2]+20])
 
 
-class RE:
+class RE(object):
     def __init__(self, s, flags=0):
         self.Content = s
         self.RE = re.compile(s, flags)
@@ -43,7 +43,7 @@ class RE:
         return "RE(%s)" % self.Content
 
 
-def lex(lex_table, s, debug=False):
+def lex(lex_table, s, debug=False, match_objects=False):
     rule_dict = dict(lex_table)
 
     def matches_rule(rule, s, start):
@@ -52,27 +52,28 @@ def lex(lex_table, s, debug=False):
         if isinstance(rule, tuple):
             if rule[0] == "|":
                 for subrule in rule[1:]:
-                    length = matches_rule(subrule, s, start)
+                    length, match_obj = matches_rule(
+                            subrule, s, start)
                     if length:
-                        return length
+                        return length, match_obj
             else:
                 my_match_length = 0
                 for subrule in rule:
-                    length = matches_rule(subrule, s, start)
+                    length, _ = matches_rule(subrule, s, start)
                     if length:
                         my_match_length += length
                         start += length
                     else:
-                        return 0
-                return my_match_length
+                        return 0, None
+                return my_match_length, None
         elif isinstance(rule, basestring):
             return matches_rule(rule_dict[rule], s, start)
         elif isinstance(rule, RE):
             match_obj = rule.RE.match(s, start)
             if match_obj:
-                return match_obj.end()-start
+                return match_obj.end()-start, match_obj
             else:
-                return 0
+                return 0, None
         else:
             raise RuleError(rule)
 
@@ -81,9 +82,12 @@ def lex(lex_table, s, debug=False):
     while i < len(s):
         rule_matched = False
         for name, rule in lex_table:
-            length = matches_rule(rule, s, i)
+            length, match_obj = matches_rule(rule, s, i)
             if length:
-                result.append((name, s[i:i+length], i))
+                if match_objects:
+                    result.append((name, s[i:i+length], i, match_obj))
+                else:
+                    result.append((name, s[i:i+length], i))
                 i += length
                 rule_matched = True
                 break
@@ -113,6 +117,9 @@ class LexIterator(object):
     def next_str(self, i=0):
         return self.lexed[self.index + i][1]
 
+    def next_match_obj(self):
+        return self.lexed[self.index][3]
+
     def next_str_and_advance(self):
         result = self.next_str()
         self.advance()