Chris@87
|
1 from __future__ import division, absolute_import, print_function
|
Chris@87
|
2
|
Chris@87
|
3 import numpy as np
|
Chris@87
|
4 from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal
|
Chris@87
|
5 from numpy.testing import assert_array_equal
|
Chris@87
|
6 import threading
|
Chris@87
|
7 import sys
|
Chris@87
|
8 if sys.version_info[0] >= 3:
|
Chris@87
|
9 import queue
|
Chris@87
|
10 else:
|
Chris@87
|
11 import Queue as queue
|
Chris@87
|
12
|
Chris@87
|
13
|
Chris@87
|
def fft1(x):
    """Return the discrete Fourier transform of the 1-D sequence *x*.

    This is a direct O(n**2) evaluation: it builds the full n-by-n matrix of
    exponents ``-2j*pi*k*m/n`` (row k, column m), exponentiates it, and sums
    ``x[m] * exp(...)`` along each row. It serves as an independent reference
    against which ``np.fft.fft`` is checked.
    """
    n = len(x)
    k = np.arange(n)
    # Row vector of per-sample angular factors: -2j*pi*m/n for m in 0..n-1.
    freqs = -2j * np.pi * (k / float(n))
    # Outer product via broadcasting: exponent[k, m] = k * (-2j*pi*m/n).
    exponent = k.reshape(-1, 1) * freqs
    return np.sum(x * np.exp(exponent), axis=1)
|
Chris@87
|
19
|
Chris@87
|
20
|
Chris@87
|
class TestFFTShift(TestCase):
    """Argument-validation test for ``np.fft.fft``.

    NOTE(review): despite the class name, nothing here exercises
    ``np.fft.fftshift``; the name is kept so existing test selectors still
    match.
    """

    def test_fft_n(self):
        # A requested transform length of zero must be rejected.
        with self.assertRaises(ValueError):
            np.fft.fft([1, 2, 3], 0)
|
Chris@87
|
25
|
Chris@87
|
26
|
Chris@87
|
class TestFFT1D(TestCase):
    """Cross-check ``np.fft.fft`` against the reference DFT ``fft1``."""

    def test_basic(self):
        # Random complex input: the real part is drawn first, then the
        # imaginary part.
        n = 30
        signal = np.random.random(n) + 1j * np.random.random(n)
        assert_array_almost_equal(fft1(signal), np.fft.fft(signal))
|
Chris@87
|
33
|
Chris@87
|
34
|
Chris@87
|
class TestFFTThreadSafe(TestCase):
    """Check that the ``np.fft`` transforms give identical results when the
    same call is made concurrently from several threads."""

    # Number of worker threads started for each check.
    threads = 16
    # Shape of the input array passed to every transform.
    input_shape = (800, 200)

    def _test_mtsame(self, func, *args):
        """Assert that ``func(*args)`` is consistent under concurrency.

        The call is first made once in the current thread to obtain the
        expected result, then made from ``self.threads`` threads at the same
        time. Every threaded result must equal the expected one.

        Raises:
            AssertionError: if any thread returned a different value.
            Exception: any exception raised by ``func`` inside a worker
                thread is re-raised here, so the test reports the real
                error instead of timing out on an empty queue.
        """
        def worker(results):
            # Report either (True, value) or (False, exception); without this
            # an exception would kill the thread without queueing anything.
            try:
                results.put((True, func(*args)))
            except Exception as exc:
                results.put((False, exc))

        q = queue.Queue()
        expected = func(*args)

        # Spin off a bunch of threads to call the same function simultaneously
        workers = [threading.Thread(target=worker, args=(q,))
                   for _ in range(self.threads)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        # Make sure all threads returned the correct value
        for _ in range(self.threads):
            ok, result = q.get(timeout=5)
            if not ok:
                raise result
            assert_array_equal(result, expected,
                               'Function returned wrong value in multithreaded context')

    def test_fft(self):
        # Complex ones (equivalent to the former ``ones * 1+0j``).
        a = np.ones(self.input_shape, dtype=complex)
        self._test_mtsame(np.fft.fft, a)

    def test_ifft(self):
        a = np.ones(self.input_shape, dtype=complex)
        self._test_mtsame(np.fft.ifft, a)

    def test_rfft(self):
        # rfft takes real input.
        a = np.ones(self.input_shape)
        self._test_mtsame(np.fft.rfft, a)

    def test_irfft(self):
        a = np.ones(self.input_shape, dtype=complex)
        self._test_mtsame(np.fft.irfft, a)
|
Chris@87
|
72
|
Chris@87
|
73
|
Chris@87
|
if __name__ == "__main__":
    # Executing this file directly runs every test case defined in it.
    run_module_suite()
|