To check out this repository, please run `hg clone` on the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Tag: | Revision:

root / pitch-track-align / pitch-track-align.sml

History | View | Annotate | Download (8.71 KB)

1

    
2
(* Direction of the pitch change into a frame, relative to the most recent
   voiced frame. The real payload is the size of the move (non-negative as
   built by preprocess); PITCH_NONE marks a frame with no pitch. *)
datatype pitch_direction = PITCH_NONE | PITCH_UP of real | PITCH_DOWN of real

(* Element type of the series being aligned. *)
type value = pitch_direction

(* Local and cumulative alignment costs. *)
type cost = real
9

    
10
(* Select the cumulative cost for a DTW cell from its candidate predecessors,
   given as (from previous row, from previous column, from diagonal). NONE
   means the predecessor does not exist.

   - No orthogonal predecessor at all: 0.0.
   - Exactly one orthogonal predecessor: that one.
   - Both orthogonal predecessors: the smallest of the three, with the
     diagonal winning ties.
   Callers are expected to supply a diagonal whenever both orthogonal
   candidates are present; otherwise Fail "Internal error" is raised. *)
fun choose (NONE, NONE, _) = 0.0
  | choose (SOME up, NONE, _) = up
  | choose (NONE, SOME left, _) = left
  | choose (SOME _, SOME _, NONE) = raise Fail "Internal error"
  | choose (SOME up, SOME left, SOME diag) =
    let
        (* cheaper of the two orthogonal moves; ties go to "left" *)
        val best = if up < left then up else left
    in
        if diag <= best then diag else best
    end
21

    
22
(* Local (per-cell) DTW cost of matching two pitch movements; lower is
   better, and a negative cost rewards a close match. Intervals are the
   differences of the pitch numbers produced by preprocess. *)
local
    (* Both frames move the same way: reward intervals less than 1 apart,
       penalise intervals more than 3 apart, neutral otherwise. *)
    fun sameWay (x, y) =
        let
            val gap = Real.abs (x - y)
        in
            if gap < 1.0 then ~1.0
            else if gap > 3.0 then 1.0
            else 0.0
        end

    (* Frames move in opposite directions: always penalised, more heavily
       once the combined interval reaches 2. *)
    fun oppositeWays (x, y) =
        if x + y < 2.0 then 1.0 else 2.0
in
fun cost (PITCH_NONE, PITCH_NONE) = 0.0
  | cost (PITCH_UP x, PITCH_UP y) = sameWay (x, y)
  | cost (PITCH_DOWN x, PITCH_DOWN y) = sameWay (x, y)
  | cost (PITCH_UP x, PITCH_DOWN y) = oppositeWays (x, y)
  | cost (PITCH_DOWN x, PITCH_UP y) = oppositeWays (x, y)
  | cost _ = 1.0 (* exactly one side has no pitch *)
end
41
       
42
(* Build the cumulative DTW cost matrix for two pitch-movement series.

   The result has one row per element of s1 and one column per element of
   s2. Entry (j, i) is the local cost `cost (s1[j], s2[i])` plus the
   cheapest of its predecessors (j-1, i), (j, i-1) and (j-1, i-1), as picked
   by `choose`. The origin (0, 0) has no predecessor, so its entry is just
   its own local cost. Previously `choose (NONE, NONE, _)` returned 0.0 there
   and dropped that cost; because every path starts at the origin, this fix
   shifts all entries by the same constant and leaves any comparison between
   entries, and hence the traceback, unchanged.

   Rows are produced one at a time. The most recently finished row is at the
   head of rowsDone, and the previous cell of the row in progress is at the
   head of cellsDone. Both lists are in reverse order. *)
fun costSeries (s1 : value vector) (s2 : value vector) : cost vector vector =
    let
        val rows = Vector.length s1
        val cols = Vector.length s2

        (* Cumulative costs for row j, given the rows completed so far. *)
        fun buildRow (rowsDone : cost vector list) j (cellsDone : cost list) i =
            if i = cols
            then Vector.fromList (rev cellsDone)
            else
                let
                    val c = cost (Vector.sub (s1, j), Vector.sub (s2, i))
                    val prevRow = case rowsDone of r :: _ => SOME r | [] => NONE
                    val up = Option.map (fn r => c + Vector.sub (r, i)) prevRow
                    val left = case cellsDone of l :: _ => SOME (c + l) | [] => NONE
                    val diag =
                        if i = 0
                        then NONE
                        else Option.map (fn r => c + Vector.sub (r, i - 1)) prevRow
                    val here =
                        case (up, left, diag) of
                            (NONE, NONE, NONE) => c (* origin: only its own cost *)
                          | options => choose options
                in
                    buildRow rowsDone j (here :: cellsDone) (i + 1)
                end

        fun buildRows (rowsDone : cost vector list) j =
            if j = rows
            then Vector.fromList (rev rowsDone)
            else buildRows (buildRow rowsDone j [] 0 :: rowsDone) (j + 1)
    in
        buildRows [] 0
    end
70

    
71
(* Align two pitch-movement series with DTW.

   Returns a vector with one entry per element of s1. Entry j is the index
   of the element of s2 matched to s1[j]: the smallest column of row j on
   the optimal warping path. If either series is empty, the result is empty.

   The path is traced back from the bottom-right corner of the cumulative
   cost matrix. At each cell it steps to the cheapest neighbour among up
   (j-1, i), left (j, i-1) and diagonal (j-1, i-1), preferring the diagonal
   on ties. This is the same rule `choose` used to fill the matrix. A column
   is recorded whenever the path leaves a row, and at the origin. *)
fun alignSeries s1 s2 =
    let
        val matrix = costSeries s1 s2
        fun cell (row, col) = Vector.sub (Vector.sub (matrix, row), col)

        fun walk (0, 0) path = 0 :: path
          | walk (row, 0) path = walk (row - 1, 0) (0 :: path)
          | walk (0, col) path = walk (0, col - 1) path
          | walk (row, col) path =
            let
                val up = cell (row - 1, col)
                val left = cell (row, col - 1)
                val diag = cell (row - 1, col - 1)
                val best = if up < left then up else left
            in
                if diag <= best then walk (row - 1, col - 1) (col :: path)
                else if up < left then walk (row - 1, col) (col :: path)
                else walk (row, col - 1) path
            end

        val rows = Vector.length s1
        val cols = Vector.length s2
    in
        if rows = 0 orelse cols = 0
        then Vector.fromList []
        else Vector.fromList (walk (rows - 1, cols - 1) [])
    end
112

    
113
(* Convert raw pitch-track columns into the series used for alignment.

   times       - frame timestamps, passed through unchanged.
   frequencies - one frequency per frame; values <= 0.0 mean "no pitch".

   Returns (times, values, pitches) as vectors, where:
   - pitches holds each frequency as a rounded note number,
     12 * log2 (f / 220) + 57 (so 220 Hz maps to 57, as in MIDI numbering),
     or 0.0 for a frame with no pitch;
   - values holds the movement from the most recent voiced pitch:
     PITCH_NONE for a non-positive pitch, PITCH_UP 0.0 for the first voiced
     frame, otherwise PITCH_UP / PITCH_DOWN with the interval size.
   As a debugging aid, the value series is also written to stderr. *)
fun preprocess (times : real list, frequencies : real list) :
    real vector * value vector * real vector =
    let
        (* A frequency of exactly 0.0 must also count as "no pitch". Taking
           log10 0.0 would give a ~inf pitch, which then leaks into later
           arithmetic such as mean pitch differences. *)
        fun toPitch f =
            if f <= 0.0
            then 0.0
            else Real.realRound (12.0 * (Math.log10 (f / 220.0) /
                                         Math.log10 2.0) + 57.0)
        val pitches = map toPitch frequencies

        (* Classify each pitch against the last voiced one (prev, which stays
           0.0 until the first voiced frame). Unvoiced frames leave prev
           unchanged, so movement is measured across gaps. *)
        fun classify (p, (acc, prev)) =
            if p <= 0.0 then (PITCH_NONE :: acc, prev)
            else if prev <= 0.0 then (PITCH_UP 0.0 :: acc, p)
            else if p >= prev then (PITCH_UP (p - prev) :: acc, p)
            else (PITCH_DOWN (prev - p) :: acc, p)
        val values = rev (#1 (foldl classify ([], 0.0) pitches))

        (* Debug trace of the value series on stderr. *)
        fun describe PITCH_NONE = "=0"
          | describe (PITCH_UP d) = "+" ^ Real.toString d
          | describe (PITCH_DOWN d) = "-" ^ Real.toString d
        val _ = app (fn v => TextIO.output (TextIO.stdErr, describe v ^ " "))
                    values
        val _ = TextIO.output (TextIO.stdErr, " (end)\n")
    in
        (Vector.fromList times,
         Vector.fromList values,
         Vector.fromList pitches)
    end
161
    
162
(* Read a pitch-track CSV file and preprocess it.

   Each non-blank line must have at least two comma-separated fields. The
   first is parsed as a time and the second as a frequency; any further
   columns are ignored. Blank or whitespace-only lines are skipped. Raises
   Fail, quoting the offending line, if a line has fewer than two columns or
   a field does not parse as a real. The file is closed whether or not
   parsing succeeds. *)
fun read csvFile =
    let
        fun toNumberPair line =
            case String.fields (fn c => c = #",") line of
                a :: b :: _ =>
                (case (Real.fromString a, Real.fromString b) of
                     (SOME r1, SOME r2) => (r1, r2)
                   | _ => raise Fail ("Failed to parse numbers: " ^ line))
              | _ => raise Fail ("Not enough columns: " ^ line)

        (* Strip the line terminator and any other trailing whitespace,
           such as the \r left by CRLF line endings. *)
        fun chomp line =
            Substring.string (Substring.dropr Char.isSpace (Substring.full line))

        fun read' s acc =
            case TextIO.inputLine s of
                NONE => rev acc
              | SOME raw =>
                let
                    val line = chomp raw
                in
                    if line = ""
                    then read' s acc
                    else read' s (toNumberPair line :: acc)
                end

        val stream = TextIO.openIn csvFile
        (* close the stream even when a line fails to parse *)
        val pairs = read' stream []
                    handle e => (TextIO.closeIn stream; raise e)
        val _ = TextIO.closeIn stream
        val (timeList, freqList) = ListPair.unzip pairs
    in
        preprocess (timeList, freqList)
    end
186

    
187
(* Mean pitch difference between aligned frames of two tracks.

   mapping[i] is the index into pitches2 aligned with frame i of pitches1.
   Returns the average of pitches1[i] - pitches2[mapping[i]] over the pairs
   in which both pitches are voiced (> 0.0), or 0.0 if there are none.

   Pitches <= 0.0 are placeholders for frames without a pitch, not real
   notes. Including them would swamp the average with differences of tens
   of semitones, or make it infinite or NaN. *)
fun meanDiff pitches1 pitches2 mapping =
    let
        fun voiced p = p > 0.0
        fun step (i, j, acc as (sum, count)) =
            let
                val p1 = Vector.sub (pitches1, i)
                val p2 = Vector.sub (pitches2, j)
            in
                if voiced p1 andalso voiced p2
                then (sum + (p1 - p2), count + 1)
                else acc
            end
        val (sumDiff, n) = Vector.foldli step (0.0, 0) mapping
    in
        if n = 0 then 0.0
        else sumDiff / Real.fromInt n
    end
199
        
200
(* Align the pitch track in csv1 (the reference) against the one in csv2.

   Returns one (time in csv1, matched time in csv2) pair per entry of the
   DTW mapping, in order. The raw mapping and the mean pitch difference
   between the aligned tracks are reported on stderr. *)
fun alignFiles csv1 csv2 =
    let
        fun err text = TextIO.output (TextIO.stdErr, text)
        val (refTimes, refValues, refPitches) = read csv1
        val (otherTimes, otherValues, otherPitches) = read csv2
        (* mapping[i] is the index into csv2's frames for frame i of csv1 *)
        val mapping = alignSeries refValues otherValues
        val () = err "DTW output:\n"
        val () = Vector.app (fn j => err (Int.toString j ^ " ")) mapping
        val () = err "\n"
        val () = err ("Mean pitch difference: reference " ^
                      Real.toString (meanDiff refPitches otherPitches mapping)
                      ^ " semitones higher than other track\n")
    in
        Vector.foldri
            (fn (i, j, pairs) =>
                (Vector.sub (refTimes, i), Vector.sub (otherTimes, j)) :: pairs)
            [] mapping
    end
220

    
221
(* Write an alignment to stdout as CSV, one "from,to" line per pair. *)
fun printAlignment alignment =
    let
        fun line (from, to) = Real.toString from ^ "," ^ Real.toString to ^ "\n"
    in
        List.app (print o line) alignment
    end
225
        
226
(* Describe the expected command line on stderr. *)
fun usage () =
    let
        val message = "Usage: pitch-track-align pitch1.csv pitch2.csv\n"
    in
        TextIO.output (TextIO.stdErr, message)
    end
229

    
230
(* Program entry point: align the two pitch-track CSV files named on the
   command line and print the alignment to stdout.

   With the wrong number of arguments, prints usage and exits with failure
   status. Previously this path fell through and exited successfully. Any
   exception is reported on stderr and also exits with failure status. *)
fun main () =
    let
        fun fail () = OS.Process.exit OS.Process.failure
    in
        (case CommandLine.arguments () of
             [csv1, csv2] => printAlignment (alignFiles csv1 csv2)
           | _ => (usage (); fail ()))
        handle exn =>
               (TextIO.output (TextIO.stdErr, "Error: " ^ (exnMessage exn) ^ "\n");
                fail ())
    end