demo/text2annotation.py @ 0:90155bdd5dd6

first commit

author: Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date:   Wed, 16 May 2018 18:27:05 +0100
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 28 14:17:15 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and annotates it based on:

    characters,
    places,
    saywords,
    character_lines,
    spatial_indicators

@output:
    .ann file with the same name
    .json file with the extracted character lines

"""

import os
import argparse
from sklearn.externals import joblib
import ner
import spacy
import re
import logging
import json
from difflib import SequenceMatcher
from neuralcoref import Coref
from rel import *

def pronoun2gender(word):
    """ Maps a pronoun to 'Male'/'Female', or 'neutral' if unknown. """
    gender_map = {
        'he': 'Male',
        'him': 'Male',
        'she': 'Female',
        'her': 'Female',
        'his': 'Male',
        'hers': 'Female',
        'himself': 'Male',
        'herself': 'Female',
    }

    if word in gender_map:
        return gender_map[word]
    else:
        return 'neutral'


logging.basicConfig(level=logging.INFO)

# given an iterable of pairs return the key corresponding to the greatest value
def argmax(pairs):
    # https://stackoverflow.com/questions/5098580/implementing-argmax-in-python
    return max(pairs, key=lambda x: x[1])[0]

# given an iterable of values return the index of the greatest value
def argmax_index(values):
    return argmax(enumerate(values))

# given an iterable of keys and a function f, return the key with largest f(key)
def argmax_f(keys, f):
    return max(keys, key=f)
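# Example: argmax([('a', 3), ('b', 7)]) == 'b' and argmax_index([3, 7]) == 1.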

def similar(a, b):
    """ Returns string similarity between a and b """
    # https://stackoverflow.com/questions/17388213/find-the-similarity-metric-between-two-strings
    return SequenceMatcher(None, a, b).ratio()
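# Example: the ratio is 2 * M / T for M matched characters over total length T,
# so similar('abc', 'abd') == 2 * 2 / (3 + 3) ≈ 0.67; identical strings give 1.0.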


def get_resolved_clusters(coref):
    """ Gets a coref object (from neuralcoref) and
    returns the clusters as words """

    mentions = coref.get_mentions()
    clusters = coref.get_clusters()[0]
    result = []
    for c in clusters:
        result.append([mentions[r] for r in clusters[c]])
    return result


def cluster_word(word, clusters):
    """ Gets a word and a list of clusters of mentions
    and figures out which cluster the word matches best, based on
    string similarity """

    similarities = []
    for rc in clusters:
        similarity = [similar(word.lower(), c.text.lower()) for c in rc]
        similarities.append(similarity)
    max_similarities = [max(s) for s in similarities]
    if max(max_similarities) > 0.75:
        return argmax_index(max_similarities)
    else:
        return -1
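# Example (hypothetical spaCy mentions): with clusters like
# [[<Alice>, <she>], [<the palace>]], cluster_word('alice', clusters) returns 0;
# a word that is at most 0.75 similar to every mention returns -1.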

def quotes2dict(text):
    """ Replaces every "..." quote in text with a <clineN>. tag and returns
    the new text together with a dictionary holding the quoted lines
    (<clineN>. keys) and the narration between them (<nlineN>. keys). """
    new_text = text
    is_open = False

    quote_no = 0
    quote = []
    narrator = []
    quote_dict = {}

    for c in text:
        if c == '"' and not is_open:
            is_open = True
            quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
            narrator = []
            quote_no += 1
            continue

        elif c == '"' and is_open:
            is_open = False
            quote_dict["<cline{}>.".format(quote_no)] = ''.join(quote)
            new_text = new_text.replace('"' + ''.join(quote) + '"', "<cline{}>.".format(quote_no))
            quote = []
            quote_no += 1
            continue

        if is_open:
            quote.append(c)
        else:
            narrator.append(c)

    return new_text, quote_dict
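# Example: quotes2dict('He said "hi" to her.') returns
#     ('He said <cline1>. to her.',
#      {'<nline0>.': 'He said ', '<cline1>.': 'hi'})
# Narration after the final quote stays in the returned text but gets no
# <nline> entry.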

def figure_gender(word, clusters, character_lut):
    """ Guesses the gender of a character mention from the character LUT,
    falling back to the pronouns in its coreference cluster. """
    for c in character_lut:
        if c.lower() in [w.lower() for w in word.split()] and \
                character_lut[c].get('gender') in ['Male', 'Female']:
            return character_lut[c]['gender']

    cluster_idx = cluster_word(word, clusters)
    if cluster_idx == -1:
        return 'neutral'
    genders = [pronoun2gender(c.text) for c in clusters[cluster_idx]]
    if 'Male' in genders and 'Female' not in genders:
        return 'Male'
    if 'Female' in genders and 'Male' not in genders:
        return 'Female'
    return 'neutral'

def annotate(text,
             ner_model,
             rel_model,
             character_lut,
             saywords_lut,
             spind_lut,
             places_lut,
             do_coreference_resolution=True):
    """
    Annotates entities in text using the given NER and relation models.

    returns: a ner.Document object with tokens labelled via the LUTs
    provided and the NER model, plus the extracted quote dictionary.
    """

    # Find and store character lines in a dictionary
    logging.info('Swapping character lines for character line tags')
    processed_text, quotes = quotes2dict(text)

    # Create the spacy nlp engine
    logging.info("Loading 'en' spacy model")
    nlp = spacy.load('en')

    # Coreference resolution (the model is only loaded if we need it)
    if do_coreference_resolution:
        logging.info("Doing one-shot coreference resolution (this might take some time)")
        coref = Coref()
        coref.one_shot_coref(processed_text)
        resolved_clusters = get_resolved_clusters(coref)
        processed_text = coref.get_resolved_utterances()[0]

    # Parse to spacy document
    logging.info("Parsing document to spacy")
    doc = nlp(processed_text)

    # Parse to our custom Document object
    logging.info("Parsing document to our object format for Named Entity Recognition")
    mDoc = ner.Document(doc)

    # Label <cline[0-9]+> tags as character lines
    logging.info("Labeling character lines")
    spans = [r.span() for r in re.finditer(r'<cline[0-9]+>\.', mDoc.text)]
    for span in spans:
        mDoc.assign_label_to_tokens(span[0], span[1], 'Character_Line')

    # Parse using LUTs

    # *- Characters

    # Sort by number of words so that entries with more words override
    # entries with fewer words in labelling. For example, if you have
    # `man' and `an old man' as characters, the character labelled is going to
    # be `an old man' and not the included `man'.
    logging.info("Labeling characters from LUT")
    cLUT = [c.lower() for c in sorted(character_lut, key=lambda x: len(x.split()))]

    # Find literals in document that match a character in cLUT
    # (escape the entry and ignore case, since LUT entries are lowercased)
    for c in cLUT:
        spans = [r.span() for r in re.finditer(re.escape(c), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Character')

    # *- Saywords

    # Assign labels to saywords. Here saywords contain only one token. In
    # addition, we check against the sayword's lemma and not the sayword itself.
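    # For example, 'said' and 'says' should both reduce to the lemma 'say',
    # so a single LUT entry covers the different inflections.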
    logging.info("Labeling saywords from LUT")
    swLUT = [nlp(sw)[0].lemma_ for sw in saywords_lut]
    for sw in swLUT:
        mDoc.assign_label_to_tokens_by_matching_lemma(sw, 'Says')

    # *- Places
    logging.info("Labeling places from LUT")
    plLUT = [pl.lower() for pl in sorted(places_lut, key=lambda x: len(x.split()))]

    # Find literals in document that match a place in plLUT
    for pl in plLUT:
        spans = [r.span() for r in re.finditer(re.escape(pl), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Place')

    # *- Spatial indicators
    logging.info("Labeling spatial indicators from LUT")
    spLUT = [sp.lower() for sp in sorted(spind_lut, key=lambda x: len(x.split()))]
    for sp in spLUT:
        spans = [r.span() for r in re.finditer(re.escape(sp), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Spatial_Signal')

    logging.info("Extracting token features")
    features, labels = mDoc.get_token_features_labels()

    logging.info("Predicting labels")
    new_labels = ner_model.predict(features)

    logging.info("Assigning labels based on the NER model")
    # If a label is not already assigned by a LUT, assign it using the model
    for m, sent in enumerate(mDoc.token_sentences):
        for n, token in enumerate(sent):
            if token.label == 'O':
                token.label = new_labels[m][n]

    # Assign gender attributes to character tokens
    if do_coreference_resolution:
        logging.info('Figuring out character genders')
        character_tok_sent = mDoc.get_tokens_with_label('Character')
        for sent in character_tok_sent:
            for character in sent:
                raw_string = " ".join([c.text for c in character])
                gender = figure_gender(raw_string, resolved_clusters, character_lut)
                if gender in ['Male', 'Female']:
                    for tok in character:
                        tok.set_attribute('gender', gender)

    logging.info('Predicting the correct label for all possible relations in Document')
    mDoc.predict_relations(rel_model)

    return mDoc, quotes


def doc2brat(mDoc):
    """ Returns a brat .ann file str based on mDoc """

    # Dictionary that maps text span -> variable (to be used when
    # adding relations)
    span2var = {}

    # Variable generator for entities (T in brat format)
    tvar = ner.var_generator('T')

    # Variable generator for relations (E in brat format)
    rvar = ner.var_generator('E')

    # Variable generator for attributes (A in brat format)
    avar = ner.var_generator('A')

    ann_str = ""
    # Extract labelled entities in the format
    #     T1  Character START END  character string
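    # Relations and attributes are emitted further down as, e.g. (hypothetical
    # spans and variable numbers):
    #     E1  Says:T2 WHO:T1 WHAT:T3
    #     A1  Gender T1 Male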

    labels = ['Character', 'Says', 'Place', 'Spatial_Signal', 'Character_Line']

    for label in labels:
        token_sentences = mDoc.get_tokens_with_label(label)
        for tlist in token_sentences:
            if len(tlist) == 0:
                continue

            for tokens in tlist:
                start = tokens[0].start
                end = tokens[-1].end
                txt = mDoc.text[start:end]
                var = next(tvar)
                ann_str += "{}\t{} {} {}\t{}\n".format(var, label, start, end, txt)
                if 'gender' in tokens[0].attributes:
                    ann_str += "{}\t{} {} {}\n".format(next(avar), 'Gender', var, tokens[0].attributes['gender'])

                span2var[(start, end)] = var

    # Map relations
    for r in mDoc.relations:
        var = next(rvar)
        trigger = r.trigger
        trigger_label = trigger[0].label[2:]  # strip the BIO prefix, e.g. 'B-Says' -> 'Says'
        trigger_start = trigger[0].start
        trigger_end = trigger[-1].end
        trigger_var = span2var[(trigger_start, trigger_end)]

        # If the trigger is a Spatial_Signal, the arguments are a
        # Trajector and a Landmark
        if trigger_label == 'Spatial_Signal':
            arg1_label = 'Trajector'
            arg2_label = 'Landmark'

        # If the trigger is Says, the arguments are WHO and WHAT
        elif trigger_label == 'Says':
            arg1_label = 'WHO'
            arg2_label = 'WHAT'

        # Skip any relation type we do not know how to serialize
        # (previously the argument labels would have been undefined here)
        else:
            continue

        # Span for the first argument
        arg1_start = r.arg1[0].start
        arg1_end = r.arg1[-1].end

        # Variable for the first argument
        arg1_var = span2var[(arg1_start, arg1_end)]

        # Span for the second argument
        arg2_start = r.arg2[0].start
        arg2_end = r.arg2[-1].end

        # Variable for the second argument
        arg2_var = span2var[(arg2_start, arg2_end)]

        annot_line = "{}\t{}:{} {}:{} {}:{}\n".format(var,
                                                      trigger_label,
                                                      trigger_var,
                                                      arg1_label,
                                                      arg1_var,
                                                      arg2_label,
                                                      arg2_var)

        ann_str += annot_line

    return ann_str

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_path', help='.txt file to parse')
    argparser.add_argument('ner_model_path', help='.pkl file containing NER model')
    argparser.add_argument('rel_model_path', help='.pkl file containing relational model')
    argparser.add_argument('--say-lut', help='.txt file with list of saywords')
    argparser.add_argument('--char-lut', help='.txt file with known characters')
    argparser.add_argument('--place-lut', help='.txt file with known places')
    argparser.add_argument('--spatial-indicator-lut', help='.txt file with known spatial indicators')
    argparser.add_argument('--force', action='store_true', help='overwrite existing output files without asking')
    argparser.add_argument('--no-coreference-resolution', action='store_true', help='omit coreference resolution step')

    args = argparser.parse_args()

    # Load text file
    with open(args.input_path) as f:
        text = " ".join(f.read().split())

    output_dir = os.path.dirname(args.input_path)
    output_text_path = args.input_path[:-4] + '_processed.txt'
    output_quotes_path = args.input_path[:-4] + '_quotes.json'
    output_annotation_path = args.input_path[:-4] + '_processed.ann'

    # Load NER model file
    ner_model = joblib.load(args.ner_model_path)

    # Load REL model file
    rel_model = joblib.load(args.rel_model_path)

    # Load saywords
    if args.say_lut:
        saylut_path = args.say_lut
    else:
        saylut_path = 'saywords.txt'

    with open(saylut_path) as f:
        saylut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load places LUT
    if args.place_lut:
        placelut_path = args.place_lut
    else:
        placelut_path = 'places.txt'

    with open(placelut_path) as f:
        placelut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load spatial indicators LUT
    if args.spatial_indicator_lut:
        spatial_indicator_lut_path = args.spatial_indicator_lut
    else:
        spatial_indicator_lut_path = 'spatial_indicators.txt'

    with open(spatial_indicator_lut_path) as f:
        spatial_indicator_lut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load character LUT
    if args.char_lut:
        charlut_path = args.char_lut
    else:
        charlut_path = 'characters.txt'

    with open(charlut_path) as f:
        charlist = [s for s in f.read().split('\n') if s.strip() != '']  # One character per line

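    # Each line of the character LUT is assumed to look like, e.g.:
    #     Snow White: Female, young
    # i.e. a name, a colon, then comma-separated gender/age attributes.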
    character_lut = {}  # Stores character attributes indexed by name
    for l in charlist:
        name, attributes = l.split(':')
        name = name.strip()

        gender = None
        age = None

        for a in attributes.split(','):
            a = a.strip()
            # Note: 'male' is a substring of 'female', so this test catches both
            if 'male' in a.lower():
                # Normalize to the 'Male'/'Female' spelling figure_gender() expects
                gender = a.capitalize()
            elif a.lower() in ['young', 'old']:
                age = a

        character_lut[name] = {}
        if gender:
            character_lut[name]['gender'] = gender
        if age:
            character_lut[name]['age'] = age

    corefres = not args.no_coreference_resolution

    mDoc, quotes = annotate(text, ner_model, rel_model, character_lut,
                            saylut, spatial_indicator_lut, placelut, corefres)

    annotation_text = doc2brat(mDoc)

    to_save = {
        output_text_path: mDoc.text,
        output_quotes_path: json.dumps(quotes),
        output_annotation_path: annotation_text
    }

    for path in to_save:
        if not os.path.exists(path) or args.force:
            with open(path, 'w') as f:
                f.write(to_save[path])
        else:
            overwrite = input('Path {} exists, overwrite? (y/N) '.format(path))
            # Guard against empty input before checking the answer
            if overwrite.strip().lower().startswith('y'):
                with open(path, 'w') as f:
                    f.write(to_save[path])