changeset 649:461d4374b6d9

Test SAI with multi-channel input.
author ronw@google.com
date Tue, 11 Jun 2013 22:05:10 +0000
parents 1c2a5868f23a
children f926e0892dee
files carfac/sai.h carfac/sai_test.cc
diffstat 2 files changed, 33 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/carfac/sai.h	Tue Jun 11 21:41:53 2013 +0000
+++ b/carfac/sai.h	Tue Jun 11 22:05:10 2013 +0000
@@ -51,7 +51,7 @@
  public:
   explicit SAI(const SAIParams& params);
 
-  // Fill output_frame with a params_.n_ch by params_.width SAI frame
+  // Fills output_frame with a params_.n_ch by params_.width SAI frame
   // computed from the given input frames.
   //
   // The input should have dimensionality of params_.window_width by
@@ -61,7 +61,7 @@
                   ArrayXX* output_output_frame);
 
  private:
-  // Process successive windows within input_buffer, choose trigger
+  // Processes successive windows within input_buffer, choose trigger
   // points, and blend each window into output_buffer.
   void StabilizeSegment(const ArrayXX& input_buffer,
                         ArrayXX* output_buffer) const;
--- a/carfac/sai_test.cc	Tue Jun 11 21:41:53 2013 +0000
+++ b/carfac/sai_test.cc	Tue Jun 11 22:05:10 2013 +0000
@@ -35,7 +35,7 @@
   return segment;
 }
 
-bool HasPeakAt(const ArrayXX& frame, int index) {
+bool HasPeakAt(const ArrayX& frame, int index) {
   if (index == 0) {
     return frame(index) > frame(index + 1);
   } else if (index == frame.size() - 1) {
@@ -49,22 +49,26 @@
  protected:
   void SetUp() {
     period_ = std::tr1::get<0>(GetParam());
-    phase_ = std::tr1::get<1>(GetParam());
+    n_ch_ = std::tr1::get<1>(GetParam());
   }
 
   int period_;
-  int phase_;
+  int n_ch_;
 };
 
-TEST_P(SAIPeriodicInputTest, SingleChannelPulseTrain) {
-  vector<ArrayX> segment = CreateZeroSegment(1, 38);
-  for (int i = phase_; i < segment.size(); i += period_) {
-    segment[i](0) = 1;
+TEST_P(SAIPeriodicInputTest, MultiChannelPulseTrain) {
+  vector<ArrayX> segment = CreateZeroSegment(n_ch_, 38);
+  for (int i = 0; i < n_ch_; ++i) {
+    // Begin each channel at a different phase.
+    const int phase = i;
+    for (int j = phase; j < segment.size(); j += period_) {
+      segment[j](i) = 1;
+    }
   }
 
   SAIParams sai_params;
   sai_params.window_width = segment.size();
-  sai_params.n_ch = 1;
+  sai_params.n_ch = n_ch_;
   sai_params.width = 15;
   // Half of the SAI should come from the future.
   // sai_params.future_lags = sai_params.width / 2;
@@ -77,19 +81,30 @@
 
   // The output should have peaks at the same positions, regardless of
   // input phase.
-  for (int i = sai_frame.size() - 1; i >= 0 ; i -= period_) {
-    EXPECT_TRUE(HasPeakAt(sai_frame, i));
+  for (int i = 0; i < n_ch_; ++i) {
+    const ArrayX& sai_channel = sai_frame.row(i);
+    for (int j = sai_channel.size() - 1; j >= 0; j -= period_) {
+      EXPECT_TRUE(HasPeakAt(sai_channel, j));
+    }
   }
 
-  for (int i = 0; i < segment.size(); ++i) {
-    std::cout << segment[i](0) << " ";
+  std::cout << "Input:\n";
+  for (int i = 0; i < n_ch_; ++i) {
+    for (int j = 0; j < segment.size(); ++j) {
+      std::cout << segment[j](i) << " ";
+    }
+    std::cout << "\n";
   }
-  std::cout << "\n";
-  for (int i = 0; i < sai_frame.size(); ++i) {
-    std::cout << sai_frame(i) << " ";
+
+  std::cout << "Output:\n";
+  for (int i = 0; i < sai_frame.rows(); ++i) {
+    for (int j = 0; j < sai_frame.cols(); ++j) {
+      std::cout << sai_frame(i, j) << " ";
+    }
+    std::cout << "\n";
   }
   std::cout << "\n";
 }
 INSTANTIATE_TEST_CASE_P(PeriodicInputVariations, SAIPeriodicInputTest,
                         testing::Combine(Values(25, 10, 5, 2),  // periods.
-                                         Values(0, 3)));  // phases.
+                                         Values(1, 2, 15)));  // n_ch.