diff armadillo-2.4.4/include/armadillo_bits/gemm_mixed.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/armadillo-2.4.4/include/armadillo_bits/gemm_mixed.hpp	Wed Apr 11 09:27:06 2012 +0100
@@ -0,0 +1,453 @@
+// Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
+// Copyright (C) 2008-2011 Conrad Sanderson
+// 
+// This file is part of the Armadillo C++ library.
+// It is provided without any warranty of fitness
+// for any purpose. You can redistribute this file
+// and/or modify it under the terms of the GNU
+// Lesser General Public License (LGPL) as published
+// by the Free Software Foundation, either version 3
+// of the License or (at your option) any later version.
+// (see http://www.opensource.org/licenses for more info)
+
+
+//! \addtogroup gemm_mixed
+//! @{
+
+
+
+//! \brief
+//! Matrix multplication where the matrices have differing element types.
+//! Uses caching for speedup.
+//! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
+
+template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
+class gemm_mixed_large
+  {
+  public:
+  
+  template<typename out_eT, typename in_eT1, typename in_eT2>
+  arma_hot
+  inline
+  static
+  void
+  apply
+    (
+          Mat<out_eT>& C,
+    const Mat<in_eT1>& A,
+    const Mat<in_eT2>& B,
+    const out_eT alpha = out_eT(1),
+    const out_eT beta  = out_eT(0)
+    )
+    {
+    arma_extra_debug_sigprint();
+    
+    const uword A_n_rows = A.n_rows;
+    const uword A_n_cols = A.n_cols;
+    
+    const uword B_n_rows = B.n_rows;
+    const uword B_n_cols = B.n_cols;
+    
+    if( (do_trans_A == false) && (do_trans_B == false) )
+      {
+      podarray<in_eT1> tmp(A_n_cols);
+      in_eT1* A_rowdata = tmp.memptr();
+      
+      for(uword row_A=0; row_A < A_n_rows; ++row_A)
+        {
+        tmp.copy_row(A, row_A);
+        
+        for(uword col_B=0; col_B < B_n_cols; ++col_B)
+          {
+          const in_eT2* B_coldata = B.colptr(col_B);
+          
+          out_eT acc = out_eT(0);
+          for(uword i=0; i < B_n_rows; ++i)
+            {
+            acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
+            }
+        
+          if( (use_alpha == false) && (use_beta == false) )
+            {
+            C.at(row_A,col_B) = acc;
+            }
+          else
+          if( (use_alpha == true) && (use_beta == false) )
+            {
+            C.at(row_A,col_B) = alpha * acc;
+            }
+          else
+          if( (use_alpha == false) && (use_beta == true) )
+            {
+            C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
+            }
+          else
+          if( (use_alpha == true) && (use_beta == true) )
+            {
+            C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
+            }
+          
+          }
+        }
+      }
+    else
+    if( (do_trans_A == true) && (do_trans_B == false) )
+      {
+      for(uword col_A=0; col_A < A_n_cols; ++col_A)
+        {
+        // col_A is interpreted as row_A when storing the results in matrix C
+        
+        const in_eT1* A_coldata = A.colptr(col_A);
+        
+        for(uword col_B=0; col_B < B_n_cols; ++col_B)
+          {
+          const in_eT2* B_coldata = B.colptr(col_B);
+          
+          out_eT acc = out_eT(0);
+          for(uword i=0; i < B_n_rows; ++i)
+            {
+            acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
+            }
+        
+          if( (use_alpha == false) && (use_beta == false) )
+            {
+            C.at(col_A,col_B) = acc;
+            }
+          else
+          if( (use_alpha == true) && (use_beta == false) )
+            {
+            C.at(col_A,col_B) = alpha * acc;
+            }
+          else
+          if( (use_alpha == false) && (use_beta == true) )
+            {
+            C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
+            }
+          else
+          if( (use_alpha == true) && (use_beta == true) )
+            {
+            C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
+            }
+          
+          }
+        }
+      }
+    else
+    if( (do_trans_A == false) && (do_trans_B == true) )
+      {
+      Mat<in_eT2> B_tmp;
+      
+      op_strans::apply_noalias(B_tmp, B);
+      
+      gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
+      }
+    else
+    if( (do_trans_A == true) && (do_trans_B == true) )
+      {
+      // mat B_tmp = trans(B);
+      // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
+      
+      
+      // By using the trans(A)*trans(B) = trans(B*A) equivalency,
+      // transpose operations are not needed
+      
+      podarray<in_eT2> tmp(B_n_cols);
+      in_eT2* B_rowdata = tmp.memptr();
+      
+      for(uword row_B=0; row_B < B_n_rows; ++row_B)
+        {
+        tmp.copy_row(B, row_B);
+        
+        for(uword col_A=0; col_A < A_n_cols; ++col_A)
+          {
+          const in_eT1* A_coldata = A.colptr(col_A);
+          
+          out_eT acc = out_eT(0);
+          for(uword i=0; i < A_n_rows; ++i)
+            {
+            acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
+            }
+        
+          if( (use_alpha == false) && (use_beta == false) )
+            {
+            C.at(col_A,row_B) = acc;
+            }
+          else
+          if( (use_alpha == true) && (use_beta == false) )
+            {
+            C.at(col_A,row_B) = alpha * acc;
+            }
+          else
+          if( (use_alpha == false) && (use_beta == true) )
+            {
+            C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
+            }
+          else
+          if( (use_alpha == true) && (use_beta == true) )
+            {
+            C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
+            }
+          
+          }
+        }
+      
+      }
+    }
+    
+  };
+
+
+
+//! Matrix multplication where the matrices have different element types.
+//! Simple version (no caching).
+//! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
+template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
+class gemm_mixed_small
+  {
+  public:
+  
+  template<typename out_eT, typename in_eT1, typename in_eT2>
+  arma_hot
+  inline
+  static
+  void
+  apply
+    (
+          Mat<out_eT>& C,
+    const Mat<in_eT1>& A,
+    const Mat<in_eT2>& B,
+    const out_eT alpha = out_eT(1),
+    const out_eT beta  = out_eT(0)
+    )
+    {
+    arma_extra_debug_sigprint();
+    
+    const uword A_n_rows = A.n_rows;
+    const uword A_n_cols = A.n_cols;
+    
+    const uword B_n_rows = B.n_rows;
+    const uword B_n_cols = B.n_cols;
+    
+    if( (do_trans_A == false) && (do_trans_B == false) )
+      {
+      for(uword row_A = 0; row_A < A_n_rows; ++row_A)
+        {
+        for(uword col_B = 0; col_B < B_n_cols; ++col_B)
+          {
+          const in_eT2* B_coldata = B.colptr(col_B);
+          
+          out_eT acc = out_eT(0);
+          for(uword i = 0; i < B_n_rows; ++i)
+            {
+            const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
+            const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
+            acc += val1 * val2;
+            //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
+            }
+          
+          if( (use_alpha == false) && (use_beta == false) )
+            {
+            C.at(row_A,col_B) = acc;
+            }
+          else
+          if( (use_alpha == true) && (use_beta == false) )
+            {
+            C.at(row_A,col_B) = alpha * acc;
+            }
+          else
+          if( (use_alpha == false) && (use_beta == true) )
+            {
+            C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
+            }
+          else
+          if( (use_alpha == true) && (use_beta == true) )
+            {
+            C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
+            }
+          }
+        }
+      }
+    else
+    if( (do_trans_A == true) && (do_trans_B == false) )
+      {
+      for(uword col_A=0; col_A < A_n_cols; ++col_A)
+        {
+        // col_A is interpreted as row_A when storing the results in matrix C
+        
+        const in_eT1* A_coldata = A.colptr(col_A);
+        
+        for(uword col_B=0; col_B < B_n_cols; ++col_B)
+          {
+          const in_eT2* B_coldata = B.colptr(col_B);
+          
+          out_eT acc = out_eT(0);
+          for(uword i=0; i < B_n_rows; ++i)
+            {
+            acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
+            }
+        
+          if( (use_alpha == false) && (use_beta == false) )
+            {
+            C.at(col_A,col_B) = acc;
+            }
+          else
+          if( (use_alpha == true) && (use_beta == false) )
+            {
+            C.at(col_A,col_B) = alpha * acc;
+            }
+          else
+          if( (use_alpha == false) && (use_beta == true) )
+            {
+            C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
+            }
+          else
+          if( (use_alpha == true) && (use_beta == true) )
+            {
+            C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
+            }
+          
+          }
+        }
+      }
+    else
+    if( (do_trans_A == false) && (do_trans_B == true) )
+      {
+      for(uword row_A = 0; row_A < A_n_rows; ++row_A)
+        {
+        for(uword row_B = 0; row_B < B_n_rows; ++row_B)
+          {
+          out_eT acc = out_eT(0);
+          for(uword i = 0; i < B_n_cols; ++i)
+            {
+            acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
+            }
+          
+          if( (use_alpha == false) && (use_beta == false) )
+            {
+            C.at(row_A,row_B) = acc;
+            }
+          else
+          if( (use_alpha == true) && (use_beta == false) )
+            {
+            C.at(row_A,row_B) = alpha * acc;
+            }
+          else
+          if( (use_alpha == false) && (use_beta == true) )
+            {
+            C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
+            }
+          else
+          if( (use_alpha == true) && (use_beta == true) )
+            {
+            C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
+            }
+          }
+        }
+      }
+    else
+    if( (do_trans_A == true) && (do_trans_B == true) )
+      {
+      for(uword row_B=0; row_B < B_n_rows; ++row_B)
+        {
+        
+        for(uword col_A=0; col_A < A_n_cols; ++col_A)
+          {
+          const in_eT1* A_coldata = A.colptr(col_A);
+          
+          out_eT acc = out_eT(0);
+          for(uword i=0; i < A_n_rows; ++i)
+            {
+            acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
+            }
+        
+          if( (use_alpha == false) && (use_beta == false) )
+            {
+            C.at(col_A,row_B) = acc;
+            }
+          else
+          if( (use_alpha == true) && (use_beta == false) )
+            {
+            C.at(col_A,row_B) = alpha * acc;
+            }
+          else
+          if( (use_alpha == false) && (use_beta == true) )
+            {
+            C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
+            }
+          else
+          if( (use_alpha == true) && (use_beta == true) )
+            {
+            C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
+            }
+          
+          }
+        }
+      
+      }
+    }
+    
+  };
+
+
+
+
+
+//! \brief
+//! Matrix multplication where the matrices have differing element types.
+
+template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
+class gemm_mixed
+  {
+  public:
+  
+  //! immediate multiplication of matrices A and B, storing the result in C
+  template<typename out_eT, typename in_eT1, typename in_eT2>
+  inline
+  static
+  void
+  apply
+    (
+          Mat<out_eT>& C,
+    const Mat<in_eT1>& A,
+    const Mat<in_eT2>& B,
+    const out_eT alpha = out_eT(1),
+    const out_eT beta  = out_eT(0)
+    )
+    {
+    arma_extra_debug_sigprint();
+    
+    Mat<in_eT1> tmp_A;
+    Mat<in_eT2> tmp_B;
+    
+    const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
+    const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
+    
+    if(do_trans_A)
+      {
+      op_htrans::apply_noalias(tmp_A, A);
+      }
+    
+    if(do_trans_B)
+      {
+      op_htrans::apply_noalias(tmp_B, B);
+      }
+     
+    const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
+    const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
+    
+    if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
+      {
+      gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
+      }
+    else
+      {
+      gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
+      }
+    }
+  
+  
+  };
+
+
+
+//! @}