max@0: // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) max@0: // Copyright (C) 2008-2011 Conrad Sanderson max@0: // max@0: // This file is part of the Armadillo C++ library. max@0: // It is provided without any warranty of fitness max@0: // for any purpose. You can redistribute this file max@0: // and/or modify it under the terms of the GNU max@0: // Lesser General Public License (LGPL) as published max@0: // by the Free Software Foundation, either version 3 max@0: // of the License or (at your option) any later version. max@0: // (see http://www.opensource.org/licenses for more info) max@0: max@0: max@0: //! \addtogroup op_sort max@0: //! @{ max@0: max@0: max@0: max@0: template max@0: class arma_ascend_sort_helper max@0: { max@0: public: max@0: max@0: arma_inline max@0: bool max@0: operator() (eT a, eT b) const max@0: { max@0: return (a < b); max@0: } max@0: }; max@0: max@0: max@0: max@0: template max@0: class arma_descend_sort_helper max@0: { max@0: public: max@0: max@0: arma_inline max@0: bool max@0: operator() (eT a, eT b) const max@0: { max@0: return (a > b); max@0: } max@0: }; max@0: max@0: max@0: max@0: template max@0: class arma_ascend_sort_helper< std::complex > max@0: { max@0: public: max@0: max@0: typedef typename std::complex eT; max@0: max@0: inline max@0: bool max@0: operator() (const eT& a, const eT& b) const max@0: { max@0: return (std::abs(a) < std::abs(b)); max@0: } max@0: }; max@0: max@0: max@0: max@0: template max@0: class arma_descend_sort_helper< std::complex > max@0: { max@0: public: max@0: max@0: typedef typename std::complex eT; max@0: max@0: inline max@0: bool max@0: operator() (const eT& a, const eT& b) const max@0: { max@0: return (std::abs(a) > std::abs(b)); max@0: } max@0: }; max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: op_sort::direct_sort(eT* X, const uword n_elem, const uword sort_type) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: if(sort_type == 0) max@0: { max@0: arma_ascend_sort_helper comparator; max@0: max@0: std::sort(&X[0], &X[n_elem], comparator); max@0: } max@0: else max@0: { max@0: arma_descend_sort_helper comparator; max@0: max@0: std::sort(&X[0], &X[n_elem], comparator); max@0: } max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: op_sort::copy_row(eT* X, const Mat& A, const uword row) max@0: { max@0: const uword N = A.n_cols; max@0: max@0: uword i,j; max@0: max@0: for(i=0, j=1; j max@0: inline max@0: void max@0: op_sort::copy_row(Mat& A, const eT* X, const uword row) max@0: { max@0: const uword N = A.n_cols; max@0: max@0: uword i,j; max@0: max@0: for(i=0, j=1; j max@0: inline max@0: void max@0: op_sort::apply(Mat& out, const Op& in) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const unwrap tmp(in.m); max@0: const Mat& X = tmp.M; max@0: max@0: const uword sort_type = in.aux_uword_a; max@0: const uword dim = in.aux_uword_b; max@0: max@0: arma_debug_check( (sort_type > 1), "sort(): incorrect usage. sort_type must be 0 or 1"); max@0: arma_debug_check( (dim > 1), "sort(): incorrect usage. dim must be 0 or 1" ); max@0: arma_debug_check( (X.is_finite() == false), "sort(): given object has non-finite elements" ); max@0: max@0: if( (X.n_rows * X.n_cols) <= 1 ) max@0: { max@0: out = X; max@0: return; max@0: } max@0: max@0: max@0: if(dim == 0) // sort the contents of each column max@0: { max@0: arma_extra_debug_print("op_sort::apply(), dim = 0"); max@0: max@0: out = X; max@0: max@0: const uword n_rows = out.n_rows; max@0: const uword n_cols = out.n_cols; max@0: max@0: for(uword col=0; col < n_cols; ++col) max@0: { max@0: op_sort::direct_sort( out.colptr(col), n_rows, sort_type ); max@0: } max@0: } max@0: else max@0: if(dim == 1) // sort the contents of each row max@0: { max@0: if(X.n_rows == 1) // a row vector max@0: { max@0: arma_extra_debug_print("op_sort::apply(), dim = 1, vector specific"); max@0: max@0: out = X; max@0: op_sort::direct_sort(out.memptr(), out.n_elem, sort_type); max@0: } max@0: else // not a row vector max@0: { max@0: arma_extra_debug_print("op_sort::apply(), dim = 1, generic"); max@0: max@0: out.copy_size(X); max@0: max@0: const uword n_rows = out.n_rows; max@0: const uword n_cols = out.n_cols; max@0: max@0: podarray tmp_array(n_cols); max@0: max@0: for(uword row=0; row < n_rows; ++row) max@0: { max@0: op_sort::copy_row(tmp_array.memptr(), X, row); max@0: max@0: op_sort::direct_sort( tmp_array.memptr(), n_cols, sort_type ); max@0: max@0: op_sort::copy_row(out, tmp_array.memptr(), row); max@0: } max@0: } max@0: } max@0: max@0: } max@0: max@0: max@0: //! @}