Chris@87
|
1 """
|
Chris@87
|
2 Utilities that manipulate strides to achieve desirable effects.
|
Chris@87
|
3
|
Chris@87
|
4 An explanation of strides can be found in the "ndarray.rst" file in the
|
Chris@87
|
5 NumPy reference guide.
|
Chris@87
|
6
|
Chris@87
|
7 """
|
Chris@87
|
8 from __future__ import division, absolute_import, print_function
|
Chris@87
|
9
|
Chris@87
|
10 import numpy as np
|
Chris@87
|
11
|
Chris@87
|
# Only broadcast_arrays is exported by ``from <this module> import *``;
# the other names defined here must be accessed through the module.
__all__ = ['broadcast_arrays']
|
Chris@87
|
13
|
Chris@87
|
class DummyArray(object):
    """Minimal holder for an ``__array_interface__`` dictionary.

    Passing an instance to ``np.asarray`` lets numpy build an ndarray from
    the stored interface. The optional ``base`` attribute exists so the
    instance can keep a reference to the object that owns the memory the
    interface points at.
    """

    def __init__(self, interface, base=None):
        # Store the owner reference first, then publish the interface dict
        # (stored as-is, not copied).
        self.base = base
        self.__array_interface__ = interface
|
Chris@87
|
22
|
Chris@87
|
def as_strided(x, shape=None, strides=None):
    """
    Make an ndarray from the given array with the given shape and strides.

    The result is built from ``x``'s ``__array_interface__`` with the
    requested shape and/or strides substituted. It therefore shares memory
    with ``x`` instead of copying it.

    .. warning:: No bounds checking is performed. A shape/strides
       combination that reaches outside ``x``'s buffer produces an array
       that reads (and can write) arbitrary memory.

    Parameters
    ----------
    x : array_like
        Source of the data buffer. It is converted with ``np.asarray``
        first, so lists and other array_likes are accepted as well as
        ndarrays.
    shape : sequence of int, optional
        Shape of the new array. Defaults to the shape of ``x``.
    strides : sequence of int, optional
        Byte strides of the new array. Defaults to the ``'strides'`` entry
        of ``x``'s array interface. NumPy reports that entry as None for a
        C-contiguous ``x``, which means the result is C-contiguous for the
        chosen shape.

    Returns
    -------
    view : ndarray
        A new array viewing the memory of ``x``.
    """
    # Normalise to an ndarray. Previously, input without an
    # __array_interface__ (e.g. a plain list) failed with AttributeError.
    # An exact ndarray is returned unchanged by np.asarray.
    x = np.asarray(x)
    interface = dict(x.__array_interface__)
    if shape is not None:
        interface['shape'] = tuple(shape)
    if strides is not None:
        interface['strides'] = tuple(strides)
    # DummyArray carries the edited interface. Passing base=x keeps a
    # reference to the array that owns the memory being viewed.
    array = np.asarray(DummyArray(interface, base=x))
    # Make sure dtype is correct in case of custom dtype: a dtype that comes
    # back through the interface as raw void ('V') may have lost
    # information, so restore the dtype of x.
    if array.dtype.kind == 'V':
        array.dtype = x.dtype
    return array
|
Chris@87
|
36
|
Chris@87
|
def broadcast_arrays(*args):
    """
    Broadcast any number of arrays against each other.

    Parameters
    ----------
    `*args` : array_likes
        The arrays to broadcast.

    Returns
    -------
    broadcasted : list of arrays
        These arrays are views on the original arrays. They are typically
        not contiguous. Furthermore, more than one element of a
        broadcasted array may refer to a single memory location. If you
        need to write to the arrays, make copies first.
        If every input already has the same shape, the inputs (converted
        with ``np.asarray``) are returned as-is. If no arrays are given,
        an empty list is returned.

    Raises
    ------
    ValueError
        If two or more arrays have different lengths, none of them equal
        to 1, along the same axis (axes are aligned from the right).

    Examples
    --------
    >>> x = np.array([[1,2,3]])
    >>> y = np.array([[1],[2],[3]])
    >>> np.broadcast_arrays(x, y)
    [array([[1, 2, 3],
           [1, 2, 3],
           [1, 2, 3]]), array([[1, 1, 1],
           [2, 2, 2],
           [3, 3, 3]])]

    Here is a useful idiom for getting contiguous copies instead of
    non-contiguous views.

    >>> [np.array(a) for a in np.broadcast_arrays(x, y)]
    [array([[1, 2, 3],
           [1, 2, 3],
           [1, 2, 3]]), array([[1, 1, 1],
           [2, 2, 2],
           [3, 3, 3]])]

    """
    args = [np.asarray(_m) for _m in args]
    if not args:
        # Nothing to broadcast. Without this guard, max() below would be
        # called on an empty sequence and raise a confusing ValueError.
        return []
    shapes = [x.shape for x in args]
    if len(set(shapes)) == 1:
        # Common case where nothing needs to be broadcasted.
        return args
    shapes = [list(s) for s in shapes]
    strides = [list(x.strides) for x in args]
    biggest = max(len(s) for s in shapes)
    # Prepend length-1 dimensions so every array has `biggest` dimensions.
    # The new axes get stride 0. Their length is 1 (or they are stretched
    # below with stride 0), so they never step through memory.
    for i, (shape, stride) in enumerate(zip(shapes, strides)):
        diff = biggest - len(shape)
        if diff > 0:
            shapes[i] = [1] * diff + shape
            strides[i] = [0] * diff + stride
    # Check each dimension for compatibility. A dimension length of 1 is
    # accepted as compatible with any other length.
    for axis in range(biggest):
        lengths = {s[axis] for s in shapes}
        lengths.discard(1)
        if len(lengths) > 1:
            # There are at least two distinct non-1 lengths on this axis.
            raise ValueError("shape mismatch: two or more arrays have "
                             "incompatible dimensions on axis %r." % (axis,))
        if lengths:
            # Exactly one non-1 length exists. Stretch every length-1 axis
            # to it, using stride 0 so that the single element repeats.
            (new_length,) = lengths
            for shape, stride in zip(shapes, strides):
                if shape[axis] == 1:
                    shape[axis] = new_length
                    stride[axis] = 0
        # Otherwise every array has length 1 here; strides are left alone.

    # Construct the new arrays as strided views of the inputs.
    return [as_strided(x, shape=sh, strides=st)
            for (x, sh, st) in zip(args, shapes, strides)]
|