ronw@642
|
1 // Copyright 2013, Google, Inc.
|
ronw@642
|
2 // Author: Ron Weiss <ronw@google.com>
|
ronw@642
|
3 //
|
ronw@642
|
4 // This C++ file is part of an implementation of Lyon's cochlear model:
|
ronw@642
|
5 // "Cascade of Asymmetric Resonators with Fast-Acting Compression"
|
ronw@642
|
6 // to supplement Lyon's upcoming book "Human and Machine Hearing"
|
ronw@642
|
7 //
|
ronw@642
|
8 // Licensed under the Apache License, Version 2.0 (the "License");
|
ronw@642
|
9 // you may not use this file except in compliance with the License.
|
ronw@642
|
10 // You may obtain a copy of the License at
|
ronw@642
|
11 //
|
ronw@642
|
12 // http://www.apache.org/licenses/LICENSE-2.0
|
ronw@642
|
13 //
|
ronw@642
|
14 // Unless required by applicable law or agreed to in writing, software
|
ronw@642
|
15 // distributed under the License is distributed on an "AS IS" BASIS,
|
ronw@642
|
16 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
ronw@642
|
17 // See the License for the specific language governing permissions and
|
ronw@642
|
18 // limitations under the License.
|
ronw@642
|
19
|
ronw@642
|
#include "sai.h"

#include <iostream>
#include <tuple>
#include <vector>

#include "gtest/gtest.h"
|
ronw@642
|
26
|
ronw@642
|
27 using testing::Values;
|
ronw@642
|
28 using std::vector;
|
ronw@642
|
29
|
ronw@642
|
30 vector<FloatArray> CreateZeroSegment(int n_ch, int length) {
|
ronw@642
|
31 vector<FloatArray> segment;
|
ronw@642
|
32 for (int i = 0; i < length; ++i) {
|
ronw@642
|
33 segment.push_back(FloatArray::Zero(n_ch));
|
ronw@642
|
34 }
|
ronw@642
|
35 return segment;
|
ronw@642
|
36 }
|
ronw@642
|
37
|
ronw@642
|
38 bool HasPeakAt(const Float2dArray& frame, int index) {
|
ronw@642
|
39 if (index == 0) {
|
ronw@642
|
40 return frame(index) > frame(index + 1);
|
ronw@642
|
41 } else if (index == frame.size() - 1) {
|
ronw@642
|
42 return frame(index) > frame(index - 1);
|
ronw@642
|
43 }
|
ronw@642
|
44 return frame(index) > frame(index + 1) && frame(index) > frame(index - 1);
|
ronw@642
|
45 }
|
ronw@642
|
46
|
ronw@642
|
47 class SAIPeriodicInputTest
|
ronw@642
|
48 : public testing::TestWithParam<std::tr1::tuple<int, int>> {
|
ronw@642
|
49 protected:
|
ronw@642
|
50 void SetUp() {
|
ronw@642
|
51 period_ = std::tr1::get<0>(GetParam());
|
ronw@642
|
52 phase_ = std::tr1::get<1>(GetParam());
|
ronw@642
|
53 }
|
ronw@642
|
54
|
ronw@642
|
55 int period_;
|
ronw@642
|
56 int phase_;
|
ronw@642
|
57 };
|
ronw@642
|
58
|
ronw@642
|
59 TEST_P(SAIPeriodicInputTest, SingleChannelPulseTrain) {
|
ronw@642
|
60 vector<FloatArray> segment = CreateZeroSegment(1, 38);
|
ronw@642
|
61 for (int i = phase_; i < segment.size(); i += period_) {
|
ronw@642
|
62 segment[i](0) = 1;
|
ronw@642
|
63 }
|
ronw@642
|
64
|
ronw@642
|
65 SAIParams sai_params;
|
ronw@642
|
66 sai_params.window_width = segment.size();
|
ronw@642
|
67 sai_params.n_ch = 1;
|
ronw@642
|
68 sai_params.width = 15;
|
ronw@642
|
69 // Half of the SAI should come from the future.
|
ronw@642
|
70 // sai_params.future_lags = sai_params.width / 2;
|
ronw@642
|
71 sai_params.future_lags = 0;
|
ronw@642
|
72 sai_params.n_window_pos = 2;
|
ronw@642
|
73
|
ronw@642
|
74 SAI sai(sai_params);
|
ronw@642
|
75 Float2dArray sai_frame;
|
ronw@642
|
76 sai.RunSegment(segment, &sai_frame);
|
ronw@642
|
77
|
ronw@642
|
78 // The output should have peaks at the same positions, regardless of
|
ronw@642
|
79 // input phase.
|
ronw@642
|
80 for (int i = sai_frame.size() - 1; i >= 0 ; i -= period_) {
|
ronw@642
|
81 EXPECT_TRUE(HasPeakAt(sai_frame, i));
|
ronw@642
|
82 }
|
ronw@642
|
83
|
ronw@642
|
84 for (int i = 0; i < segment.size(); ++i) {
|
ronw@642
|
85 std::cout << segment[i](0) << " ";
|
ronw@642
|
86 }
|
ronw@642
|
87 std::cout << "\n";
|
ronw@642
|
88 for (int i = 0; i < sai_frame.size(); ++i) {
|
ronw@642
|
89 std::cout << sai_frame(i) << " ";
|
ronw@642
|
90 }
|
ronw@642
|
91 std::cout << "\n";
|
ronw@642
|
92 }
|
ronw@642
|
93 INSTANTIATE_TEST_CASE_P(PeriodicInputVariations, SAIPeriodicInputTest,
|
ronw@642
|
94 testing::Combine(Values(25, 10, 5, 2), // periods.
|
ronw@642
|
95 Values(0, 3))); // phases.
|