diff --git a/dagrt/builtins_python.py b/dagrt/builtins_python.py index 9f506cb9c521de285c5c9471c456871f69d601d6..623776ca20d2d203abf4ae478759d56f1197f58e 100644 --- a/dagrt/builtins_python.py +++ b/dagrt/builtins_python.py @@ -85,6 +85,19 @@ def builtin_matmul(a, b, a_cols, b_cols): return res_mat.reshape(-1, order="F") +def builtin_transpose(a, a_cols): + import numpy as np + if a_cols != np.floor(a_cols): + raise ValueError("transpose() argument a_cols is not an integer") + a_cols = int(a_cols) + + a_mat = a.reshape(-1, a_cols, order="F") + + res_mat = np.transpose(a_mat) + + return res_mat.reshape(-1, order="F") + + def builtin_linear_solve(a, b, a_cols, b_cols): import numpy as np if a_cols != np.floor(a_cols): @@ -130,6 +143,7 @@ builtins = { "dot_product": builtin_dot_product, "array": builtin_array, "matmul": builtin_matmul, + "transpose": builtin_transpose, "linear_solve": builtin_linear_solve, "svd": builtin_svd, "print": builtin_print, diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index 9c2397fdb0b09f91345094ab6aad2a51cf99aea9..934b06539d89688dd24b7d71899682e883509ac0 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -2167,7 +2167,7 @@ class CodeGenerator(StructuredCodeGenerator): inst.as_expression()) assignee_fortran_names = [ - self.name_manager[assignee_sym] for a in inst.assignees] + self.name_manager[assignee_sym] for assignee_sym in inst.assignees] function = self.function_registry[inst.function_id] @@ -2417,7 +2417,7 @@ UTIL_MACROS = """ if (${rows_var} * int(${cols_var}) .ne. size(${mat_array})) then write(dagrt_stderr,*) & 'size of argument ' // & - '${mat_array}' // & + '${mat_array} ' // & 'to ${func_name} ' // & 'not divisible by ' // & '${cols_var}' @@ -2480,6 +2480,31 @@ builtin_matmul = CallCode(UTIL_MACROS + """ (/${res_size}/)) """) + +builtin_transpose = CallCode(UTIL_MACROS + """ + <% + a_rows = declare_new("integer", "a_rows") + res_size = declare_new("integer", "res_size") + %> + + ${check_matrix(a, a_cols, a_rows, "transpose")} + + ${a_rows} = size(${a}) / int(${a_cols}) + ${res_size} = ${a_rows} * int(${a_cols}) + + if (allocated(${result})) then + deallocate(${result}) + endif + + allocate(${result}(0:${res_size}-1)) + + ${result} = reshape( & + transpose( & + reshape(${a}, (/${a_rows}, int(${a_cols})/))), & + (/${res_size}/)) + """) + + builtin_linear_solve = CallCode(UTIL_MACROS + """ <% a_rows = declare_new("integer", "a_rows") @@ -2543,6 +2568,74 @@ builtin_linear_solve = CallCode(UTIL_MACROS + """ """) + +builtin_svd = CallCode(UTIL_MACROS + """ + <% + sigma_size = declare_new("integer", "res_size") + a_rows = declare_new("integer", "a_rows") + + %> + + ${check_matrix(a, a_cols, a_rows, "svd")} + ${sigma_size} = min(int(${a_cols}),int(${a_rows})) + + <% + ltr = get_lapack_letter(a_kind) + + a_temp = declare_new( + kind_to_fortran(a_kind)+", dimension(:), allocatable" + , "a_temp") + work = declare_new( + kind_to_fortran(a_kind)+", dimension(:), allocatable" + , "work") + info = declare_new("integer", "info") + lwork = declare_new("integer", "lwork") + lda = declare_new("integer", "lda") + ldu = declare_new("integer", "ldu") + ldvt = declare_new("integer", "ldvt") + jobu = declare_new("character*1", "jobu") + jobvt = declare_new("character*1", "jobvt") + %> + + allocate(${a_temp}(0:size(${a})-1)) + ${jobu} = "S" + ${jobvt} = "S" + ${lda} = max(1,int(${a_rows})) + ${ldu} = int(${a_rows}) + ${ldvt} = min(int(${a_rows}),int(${a_rows})) + + ${a_temp} = ${a} + ${lwork} = max(1, & + 3*min(int(${a_rows}), & + int(${a_cols})) + max(int(${a_rows}), & + int(${a_cols})), & + 5*min(int(${a_rows}), int(${a_cols}))) + + if (allocated(${sigma})) then + deallocate(${sigma}) + endif + + allocate(${sigma}(0:${sigma_size}-1)) + allocate(${work}(0:${lwork}-1)) + allocate(${u}(0:int(${a_rows}*${a_rows})-1)) + allocate(${vt}(0:int(${a_rows}*${a_cols})-1)) + + call ${ltr}gesvd(${jobu}, ${jobvt}, & + int(${a_rows}), int(${a_cols}), ${a_temp}, ${lda}, ${sigma}, & + ${u}, ${ldu}, ${vt}, ${ldvt}, ${work}, ${lwork}, ${info}) + + if (${info}.ne.0) then + write(dagrt_stderr,*) & + 'gesvd on ${a} failed with info=', ${info} + stop + endif + + deallocate(${a_temp}) + deallocate(${work}) + + """) + + builtin_print = CallCode(UTIL_MACROS + """ write(*,*) ${arg} """) diff --git a/dagrt/function_registry.py b/dagrt/function_registry.py index 1581a07f85dd9a1a3677f67da5520511e2be21df..5082151604d98103f64128b07582622a6865cb2c 100644 --- a/dagrt/function_registry.py +++ b/dagrt/function_registry.py @@ -321,6 +321,33 @@ class _MatMul(Function): return (Array(is_real_valued),) +class _Transpose(Function): + """``transpose(a, a_cols)`` returns a 1D array containing the + matrix resulting from transposing the array a + """ + + result_names = ("result",) + identifier = "transpose" + arg_names = ("a", "a_cols") + default_dict = {} + + def get_result_kinds(self, arg_kinds, check): + a_kind, a_cols_kind = self.resolve_args(arg_kinds) + + if a_kind is None: + raise UnableToInferKind( + "transpose needs to know both arguments to infer result kind") + + if check and not isinstance(a_kind, Array): + raise TypeError("argument 'a' of 'transpose' is not an array") + if check and not isinstance(a_cols_kind, Scalar): + raise TypeError("argument 'a_cols' of 'transpose' is not a scalar") + + is_real_valued = a_kind.is_real_valued + + return (Array(is_real_valued),) + + class _LinearSolve(Function): """``linear_solve(a, b, a_cols, b_cols)`` returns a 1D array containing the matrix resulting from multiplying the matrix inverse of a by b (both interpreted @@ -354,7 +381,7 @@ class _LinearSolve(Function): class _SVD(Function): - """``linear_solve(a, a_cols)`` returns a 2D array ``u``, a 1D array ``sigma``, and + """``SVD(a, a_cols)`` returns a 2D array ``u``, a 1D array ``sigma``, and a 2D array ``vt``, representing the (reduced) SVD of ``a``. """ @@ -371,9 +398,9 @@ class _SVD(Function): "svd needs to know its argument to infer result kind") if check and not isinstance(a_kind, Array): - raise TypeError("argument 'a' of 'linear_solve' is not an array") + raise TypeError("argument 'a' of 'svd' is not an array") if check and not isinstance(a_cols_kind, Scalar): - raise TypeError("argument 'a_cols' of 'linear_solve' is not a scalar") + raise TypeError("argument 'a_cols' of 'svd' is not a scalar") is_real_valued = a_kind.is_real_valued @@ -424,6 +451,7 @@ def _make_bfr(): (_IsNaN(), "{numpy}.isnan({args})"), (_Array(), "self._builtin_array({args})"), (_MatMul(), "self._builtin_matmul({args})"), + (_Transpose(), "self._builtin_transpose({args})"), (_LinearSolve(), "self._builtin_linear_solve({args})"), (_Print(), "self._builtin_print({args})"), (_SVD(), "self._builtin_svd({args})"), @@ -448,8 +476,12 @@ def _make_bfr(): f.builtin_array) bfr = bfr.register_codegen(_MatMul.identifier, "fortran", f.builtin_matmul) + bfr = bfr.register_codegen(_Transpose.identifier, "fortran", + f.builtin_transpose) bfr = bfr.register_codegen(_LinearSolve.identifier, "fortran", f.builtin_linear_solve) + bfr = bfr.register_codegen(_SVD.identifier, "fortran", + f.builtin_svd) bfr = bfr.register_codegen(_Print.identifier, "fortran", f.builtin_print)