diff --git a/course/content.py b/course/content.py
index ec237219940664c4b54a5edf9c5c727321e31bbc..c8cdd493bc83733b111f121ba3c3d9cbaad0752c 100644
--- a/course/content.py
+++ b/course/content.py
@@ -330,10 +330,12 @@ def markup_to_html(course, repo, commit_sha, text):
         template = env.from_string(text)
         text = template.render()
 
+    from course.mdx_mathjax import MathJaxExtension
     import markdown
     return markdown.markdown(text,
         extensions=[
             LinkFixerExtension(course, commit_sha),
+            MathJaxExtension(),
             "extra",
             ],
         output_format="html5")
diff --git a/course/mdx_mathjax.py b/course/mdx_mathjax.py
new file mode 100644
index 0000000000000000000000000000000000000000..89de7904ea7e7498ccbf9449b2e27bbafcd81fd1
--- /dev/null
+++ b/course/mdx_mathjax.py
@@ -0,0 +1,32 @@
+# Downloaded from https://github.com/mayoff/python-markdown-mathjax/issues/3
+
+import markdown
+from markdown.postprocessors import Postprocessor
+
+
+class MathJaxPattern(markdown.inlinepatterns.Pattern):
+    def __init__(self):
+        markdown.inlinepatterns.Pattern.__init__(self, r'(?<!\\)(\$\$?)(.+?)\2')
+
+    def handleMatch(self, m):
+        node = markdown.util.etree.Element('mathjax')
+        node.text = markdown.util.AtomicString(m.group(2) + m.group(3) + m.group(2))
+        return node
+
+
+class MathJaxPostprocessor(Postprocessor):
+    def run(self, text):
+        text = text.replace('<mathjax>', '')
+        text = text.replace('</mathjax>', '')
+        return text
+
+
+class MathJaxExtension(markdown.Extension):
+    def extendMarkdown(self, md, md_globals):
+        # Needs to come before escape matching because \ is pretty important in LaTeX
+        md.inlinePatterns.add('mathjax', MathJaxPattern(), '<escape')
+        md.postprocessors['mathjax'] = MathJaxPostprocessor(md)
+
+
+def makeExtension(configs=None):
+    return MathJaxExtension(configs)