view armadillo-2.4.4/include/armadillo_bits/op_reshape_meat.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
line wrap: on
line source
// Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
// Copyright (C) 2008-2011 Conrad Sanderson
// 
// This file is part of the Armadillo C++ library.
// It is provided without any warranty of fitness
// for any purpose. You can redistribute this file
// and/or modify it under the terms of the GNU
// Lesser General Public License (LGPL) as published
// by the Free Software Foundation, either version 3
// of the License or (at your option) any later version.
// (see http://www.opensource.org/licenses for more info)



//! \addtogroup op_reshape
//! @{



//! Evaluate a reshape operation on a matrix expression, storing the result in 'out'.
//! 
//! The requested shape comes from the Op object:
//!   aux_uword_a -> new number of rows
//!   aux_uword_b -> new number of columns
//!   aux_uword_c -> dim (0: take source elements in storage order;
//!                       non-zero: read source elements row by row)
//! Elements are always written into 'out' in its storage order.
//! 
//! If the new shape holds more elements than the source, the surplus positions
//! are set to zero; if it holds fewer, the trailing source elements are dropped.
//! Special case: when 'out' is the source object itself, dim == 0 and the element
//! count is unchanged, only the stored dimensions are rewritten (no data copy).
template<typename T1>
inline
void
op_reshape::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_reshape>& in)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type eT;
  
  const unwrap<T1> U(in.m);
  const Mat<eT>&   X = U.M;
  
  const uword new_n_rows = in.aux_uword_a;
  const uword new_n_cols = in.aux_uword_b;
  const uword in_dim     = in.aux_uword_c;
  
  const uword new_n_elem = new_n_rows * new_n_cols;
  
  // Fast path: storage-order reshape that keeps the element count.
  if( (in_dim == 0) && (X.n_elem == new_n_elem) )
    {
    if(&out == &X)
      {
      // In-place: the memory block already holds the right number of elements,
      // so only the dimension fields need to change.
      const bool dims_differ = (out.n_rows != new_n_rows) || (out.n_cols != new_n_cols);
      
      if(dims_differ)
        {
        arma_debug_check
          (
          (out.mem_state == 3),
          "reshape(): size can't be changed as template based size specification is in use"
          );
        
        access::rw(out.n_rows) = new_n_rows;
        access::rw(out.n_cols) = new_n_cols;
        }
      }
    else
      {
      out.set_size(new_n_rows, new_n_cols);
      arrayops::copy( out.memptr(), X.memptr(), out.n_elem );
      }
    
    return;
    }
  
  // General path: row-wise reading and/or a change in the element count.
  // unwrap_check guards against 'out' aliasing the source before it is resized.
  const unwrap_check< Mat<eT> > UC(X, out);
  const Mat<eT>& src = UC.M;
  
  const uword n_keep = (std::min)(src.n_elem, new_n_elem);
  
  out.set_size(new_n_rows, new_n_cols);
  
  eT* dest = out.memptr();
  
  if(in_dim == 0)
    {
    arrayops::copy( dest, src.memptr(), n_keep );
    }
  else
    {
    // Walk the source one row at a time (column index varies fastest),
    // stopping once n_keep elements have been written.
    const uword src_n_cols = src.n_cols;
    
    uword r = 0;
    uword c = 0;
    
    for(uword k=0; k < n_keep; ++k)
      {
      dest[k] = src.at(r, c);
      
      ++c;
      
      if(c >= src_n_cols)
        {
        c = 0;
        ++r;
        }
      }
    }
  
  // Pad any positions not covered by the source with zeros.
  for(uword k=n_keep; k < new_n_elem; ++k)
    {
    dest[k] = eT(0);
    }
  }



//! Evaluate a reshape operation on a cube expression, storing the result in 'out'.
//! 
//! The requested shape comes from the OpCube object:
//!   aux_uword_a -> new number of rows
//!   aux_uword_b -> new number of columns
//!   aux_uword_c -> new number of slices
//!   aux_uword_d -> dim (0: take source elements in storage order;
//!                       non-zero: read each source slice row by row, slice after slice)
//! Elements are always written into 'out' in its storage order.
//! 
//! If the new shape holds more elements than the source, the surplus positions
//! are set to zero; if it holds fewer, the trailing source elements are dropped.
//! Special case: when 'out' is the source object itself, dim == 0 and the element
//! count is unchanged, the dimension fields are rewritten in place (no data copy).
template<typename T1>
inline
void
op_reshape::apply(Cube<typename T1::elem_type>& out, const OpCube<T1,op_reshape>& in)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type eT;
  
  const unwrap_cube<T1> U(in.m);
  const Cube<eT>&       X = U.M;
  
  const uword new_n_rows   = in.aux_uword_a;
  const uword new_n_cols   = in.aux_uword_b;
  const uword new_n_slices = in.aux_uword_c;
  const uword in_dim       = in.aux_uword_d;
  
  const uword new_n_elem = new_n_rows * new_n_cols * new_n_slices;
  
  // Fast path: storage-order reshape that keeps the element count.
  if( (in_dim == 0) && (X.n_elem == new_n_elem) )
    {
    if(&out == &X)
      {
      const bool dims_differ =
           (out.n_rows   != new_n_rows)
        || (out.n_cols   != new_n_cols)
        || (out.n_slices != new_n_slices);
      
      if(dims_differ)
        {
        arma_debug_check
          (
          (out.mem_state == 3),
          "reshape(): size can't be changed as template based size specification is in use"
          );
        
        // delete_mat()/create_mat() bracket the dimension update; presumably they
        // release and rebuild per-slice bookkeeping sized by the old/new shape.
        out.delete_mat();
        
        access::rw(out.n_rows)       = new_n_rows;
        access::rw(out.n_cols)       = new_n_cols;
        access::rw(out.n_elem_slice) = new_n_rows * new_n_cols;
        access::rw(out.n_slices)     = new_n_slices;
        
        out.create_mat();
        }
      }
    else
      {
      out.set_size(new_n_rows, new_n_cols, new_n_slices);
      arrayops::copy( out.memptr(), X.memptr(), out.n_elem );
      }
    
    return;
    }
  
  // General path: row-wise reading and/or a change in the element count.
  // unwrap_cube_check guards against 'out' aliasing the source before it is resized.
  const unwrap_cube_check< Cube<eT> > UC(X, out);
  const Cube<eT>& src = UC.M;
  
  const uword n_keep = (std::min)(src.n_elem, new_n_elem);
  
  out.set_size(new_n_rows, new_n_cols, new_n_slices);
  
  eT* dest = out.memptr();
  
  if(in_dim == 0)
    {
    arrayops::copy( dest, src.memptr(), n_keep );
    }
  else
    {
    // Column index varies fastest, then row, then slice;
    // stop once n_keep elements have been written.
    const uword src_n_rows = src.n_rows;
    const uword src_n_cols = src.n_cols;
    
    uword r = 0;
    uword c = 0;
    uword s = 0;
    
    for(uword k=0; k < n_keep; ++k)
      {
      dest[k] = src.at(r, c, s);
      
      ++c;
      
      if(c >= src_n_cols)
        {
        c = 0;
        ++r;
        
        if(r >= src_n_rows)
          {
          r = 0;
          ++s;
          }
        }
      }
    }
  
  // Pad any positions not covered by the source with zeros.
  for(uword k=n_keep; k < new_n_elem; ++k)
    {
    dest[k] = eT(0);
    }
  }



//! @}