demo/text2annotation.py @ 0:90155bdd5dd6

first commit
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Wed, 16 May 2018 18:27:05 +0100
parents
children
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 28 14:17:15 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and annotates it based on:

characters,
places,
saywords,
character_lines,
spatial_indicators

@output:
    .ann file with the same name
    .json file with the extracted character lines

"""

import os
import argparse
from sklearn.externals import joblib
import ner
import spacy
import re
import logging
import json
from difflib import SequenceMatcher
from neuralcoref import Coref
from rel import *

def pronoun2gender(word):
    pronoun2gender = {
        'he': 'Male',
        'him': 'Male',
        'she': 'Female',
        'her': 'Female',
        'his': 'Male',
        'hers': 'Female',
        'himself': 'Male',
        'herself': 'Female',
    }

    if word in pronoun2gender:
        return pronoun2gender[word]
    else:
        return 'neutral'
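
# Illustrative examples (added for clarity, not from the original source):
#   pronoun2gender('she') -> 'Female'
#   pronoun2gender('him') -> 'Male'
#   pronoun2gender('it')  -> 'neutral'  (any unmapped word falls through)
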
logging.basicConfig(level=logging.INFO)

# Given an iterable of (key, value) pairs, return the key corresponding
# to the greatest value.
def argmax(pairs):
    # https://stackoverflow.com/questions/5098580/implementing-argmax-in-python
    return max(pairs, key=lambda x: x[1])[0]

# Given an iterable of values, return the index of the greatest value.
def argmax_index(values):
    return argmax(enumerate(values))

# Given an iterable of keys and a function f, return the key with largest f(key).
def argmax_f(keys, f):
    return max(keys, key=f)
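
# Quick sanity examples (added for illustration):
#   argmax([('a', 3), ('b', 7), ('c', 5)]) -> 'b'
#   argmax_index([3, 7, 5])                -> 1
#   argmax_f(['a', 'bbb', 'cc'], len)      -> 'bbb'
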
def similar(a, b):
    """ Returns string similarity between a and b """
    # https://stackoverflow.com/questions/17388213/find-the-similarity-metric-between-two-strings
    return SequenceMatcher(None, a, b).ratio()
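
# For example, similar('john', 'johnny') -> 0.8
# (SequenceMatcher ratio: 2 * matching chars / total length = 2 * 4 / 10).
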
def get_resolved_clusters(coref):
    """ Takes a Coref object (from neuralcoref) and
    returns the clusters as lists of mentions """

    mentions = coref.get_mentions()
    clusters = coref.get_clusters()[0]
    result = []
    for c in clusters:
        result.append([mentions[r] for r in clusters[c]])
    return result
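
# Assumed output shape (inferred from how the clusters are consumed in
# cluster_word below): a list of clusters, each a list of mention objects
# exposing a .text attribute, e.g. [[<John>, <he>, <him>], [<Mary>, <she>]].
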
def cluster_word(word, clusters):
    """ Takes a word and a list of clusters of mentions and returns
    the index of the cluster the word matches best, based on
    string similarity (-1 if nothing is similar enough) """

    similarities = []
    for rc in clusters:
        similarity = [similar(word.lower(), c.text.lower()) for c in rc]
        similarities.append(similarity)
    max_similarities = [max(s) for s in similarities]
    if max(max_similarities) > 0.75:
        return argmax_index(max_similarities)
    else:
        return -1
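
# Illustrative call (mentions shown by their .text): with clusters
# [[<John>, <he>], [<Mary>, <she>]], cluster_word('john', clusters) -> 0,
# while cluster_word('dragon', clusters) -> -1 (no similarity above 0.75).
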
def quotes2dict(text):
    """ Replaces each double-quoted span in text with a <clineN>. tag,
    records the narration spans as <nlineN>. entries, and returns
    (new_text, {tag: original content}) """
    new_text = text
    is_open = False

    quote_no = 0
    quote = []
    narrator = []
    quote_dict = {}

    for c in text:
        if c == '"' and not is_open:
            is_open = True
            quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
            narrator = []
            quote_no += 1
            continue

        elif c == '"' and is_open:
            is_open = False
            quote_dict["<cline{}>.".format(quote_no)] = ''.join(quote)
            new_text = new_text.replace('"' + ''.join(quote) + '"', "<cline{}>.".format(quote_no))
            quote = []
            quote_no += 1
            continue

        if is_open:
            quote.append(c)
        else:
            narrator.append(c)

    return new_text, quote_dict
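
# Worked example (added for illustration):
#   quotes2dict('Tom said "hello there" and left.')
#   -> ('Tom said <cline1>. and left.',
#       {'<nline0>.': 'Tom said ', '<cline1>.': 'hello there'})
# Note that narration after the final quote is not recorded in the dict.
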
def figure_gender(word, clusters, character_lut):
    # First try the character LUT: if any word of the mention names a
    # known character with a recorded gender, use that.
    for c in character_lut:
        if c.lower() in word.lower().split() and character_lut[c].get('gender') in ['Male', 'Female']:
            return character_lut[c]['gender']

    # Otherwise fall back on coreference: look for gendered pronouns
    # among the mentions that co-refer with the word.
    cluster_idx = cluster_word(word, clusters)
    if cluster_idx == -1:
        return 'neutral'
    genders = [pronoun2gender(c.text) for c in clusters[cluster_idx]]
    if 'Male' in genders and 'Female' not in genders:
        return 'Male'
    if 'Female' in genders and 'Male' not in genders:
        return 'Female'
    return 'neutral'
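
# Illustrative behaviour: with a cluster whose mention texts are
# ['Mary', 'she'] and an empty LUT, figure_gender('Mary', clusters, {})
# returns 'Female'; a cluster mixing 'he' and 'she' yields 'neutral'.
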
def annotate(text,
             ner_model,
             rel_model,
             character_lut,
             saywords_lut,
             spind_lut,
             places_lut,
             do_coreference_resolution=True):
    """
    Annotates entities in text using the given models.

    returns: a ner.Document object with tokens labelled via the
    provided LUTs and the NER model, plus the extracted quotes dict
    """

    # Find and store character lines in a dictionary
    logging.info('Swapping character lines for character line tags')
    processed_text, quotes = quotes2dict(text)

    # Create the spacy nlp engine for the resulting text
    logging.info("Loading 'en' spacy model")
    nlp = spacy.load('en')

    # Load the coreference model
    coref = Coref()

    # Do coreference resolution
    if do_coreference_resolution:
        logging.info("Doing one-shot coreference resolution (this might take some time)")
        coref.one_shot_coref(processed_text)
        resolved_clusters = get_resolved_clusters(coref)
        processed_text = coref.get_resolved_utterances()[0]

    # Parse to spacy document
    logging.info("Parsing document to spacy")
    doc = nlp(processed_text)

    # Parse to our custom Document object
    logging.info("Parsing document to our object format for Named Entity Recognition")
    mDoc = ner.Document(doc)
    # Label <cline[0-9]+> tags as character lines
    logging.info("Labeling character lines")
    spans = [r.span() for r in re.finditer(r'<cline[0-9]+>\.', mDoc.text)]
    for span in spans:
        mDoc.assign_label_to_tokens(span[0], span[1], 'Character_Line')

    # Parse using LUTs

    # *- Characters

    # Sort by number of words so that entries with more words override
    # entries with fewer words in labelling. For example, if you have
    # `man' and `an old man' as characters, the character labelled is going
    # to be `an old man' and not the included `man'.
    logging.info("Labeling characters from LUT")
    cLUT = [c.lower() for c in sorted(character_lut, key=lambda x: len(x.split()))]

    # Find literals in the document that match a character in cLUT;
    # escape each entry (they are literals, not regexes) and ignore case,
    # since the LUT entries have been lowercased
    for c in cLUT:
        spans = [r.span() for r in re.finditer(re.escape(c), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Character')
    # *- Saywords

    # Assign labels to saywords. Here each sayword is a single token, and we
    # match against the sayword's lemma rather than its surface form.
    logging.info("Labeling saywords from LUT")
    swLUT = [nlp(sw)[0].lemma_ for sw in saywords_lut]
    for sw in swLUT:
        mDoc.assign_label_to_tokens_by_matching_lemma(sw, 'Says')

    # *- Places
    logging.info("Labeling places from LUT")
    plLUT = [pl.lower() for pl in sorted(places_lut, key=lambda x: len(x.split()))]

    # Find literals in the document that match a place in plLUT
    for pl in plLUT:
        spans = [r.span() for r in re.finditer(re.escape(pl), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Place')

    # *- Spatial indicators
    logging.info("Labeling spatial indicators from LUT")
    spLUT = [sp.lower() for sp in sorted(spind_lut, key=lambda x: len(x.split()))]
    for sp in spLUT:
        spans = [r.span() for r in re.finditer(re.escape(sp), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Spatial_Signal')
    logging.info("Extracting token features")
    features, labels = mDoc.get_token_features_labels()

    logging.info("Predicting labels")
    new_labels = ner_model.predict(features)

    logging.info("Assigning labels based on the NER model")
    # If a label is not already assigned by a LUT, assign it using the model
    for m, sent in enumerate(mDoc.token_sentences):
        for n, token in enumerate(sent):
            if token.label == 'O':
                token.label = new_labels[m][n]

    # Assign gender attributes to characters
    if do_coreference_resolution:
        logging.info('Figuring out character genders')
        character_tok_sent = mDoc.get_tokens_with_label('Character')
        for sent in character_tok_sent:
            for character in sent:
                raw_string = " ".join([c.text for c in character])
                gender = figure_gender(raw_string, resolved_clusters, character_lut)
                for tok in character:
                    if gender in ['Male', 'Female']:
                        tok.set_attribute('gender', gender)

    logging.info('Predicting the correct label for all possible relations in Document')
    mDoc.predict_relations(rel_model)

    return mDoc, quotes


def doc2brat(mDoc):
    """ Returns a brat .ann file str based on mDoc """

    # Dictionary that maps text span -> variable (to be used when
    # adding relations)
    span2var = {}

    # Variable generator for entities (T in brat format)
    tvar = ner.var_generator('T')

    # Variable generator for relations (written as brat events, E)
    rvar = ner.var_generator('E')

    # Variable generator for attributes (A in brat format)
    avar = ner.var_generator('A')

    ann_str = ""
    # Extract entities in the format
    # T1 Character START END character string
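    # For instance, with hypothetical offsets, the entity and attribute
    # lines built below would look like:
    #   T1    Character 10 18    the wolf
    #   A1    Gender T1 Male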

    labels = ['Character', 'Says', 'Place', 'Spatial_Signal', 'Character_Line']

    for label in labels:
        token_sentences = mDoc.get_tokens_with_label(label)
        for tlist in token_sentences:
            if len(tlist) == 0:
                continue

            for tokens in tlist:
                start = tokens[0].start
                end = tokens[-1].end
                txt = mDoc.text[start:end]
                var = next(tvar)
                ann_str += "{}\t{} {} {}\t{}\n".format(var, label, start, end, txt)
                if 'gender' in tokens[0].attributes:
                    ann_str += "{}\t{} {} {}\n".format(next(avar), 'Gender', var, tokens[0].attributes['gender'])

                span2var[(start, end)] = var
    # Map relations
    for r in mDoc.relations:
        var = next(rvar)
        trigger = r.trigger
        trigger_label = trigger[0].label[2:]
        trigger_start = trigger[0].start
        trigger_end = trigger[-1].end
        trigger_var = span2var[(trigger_start, trigger_end)]

        # If the trigger is a Spatial_Signal then the
        # arguments are a Trajector and a Landmark
        if trigger_label == 'Spatial_Signal':
            arg1_label = 'Trajector'
            arg2_label = 'Landmark'

        # If the trigger is Says then the
        # arguments are WHO and WHAT
        elif trigger_label == 'Says':
            arg1_label = 'WHO'
            arg2_label = 'WHAT'

        # Skip any other trigger type; the argument labels would
        # otherwise be undefined
        else:
            continue

        # Span for the first argument
        arg1_start = r.arg1[0].start
        arg1_end = r.arg1[-1].end

        # Variable for the first argument
        arg1_var = span2var[(arg1_start, arg1_end)]

        # Span for the second argument
        arg2_start = r.arg2[0].start
        arg2_end = r.arg2[-1].end

        # Variable for the second argument
        arg2_var = span2var[(arg2_start, arg2_end)]

        annot_line = "{}\t{}:{} {}:{} {}:{}\n".format(var,
                                                      trigger_label,
                                                      trigger_var,
                                                      arg1_label,
                                                      arg1_var,
                                                      arg2_label,
                                                      arg2_var)

        ann_str += annot_line

    return ann_str

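
# An emitted event line looks like (hypothetical variables):
#   E1    Says:T2 WHO:T1 WHAT:T3
# i.e. the say-word trigger T2 links the speaker T1 to the character line T3.
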
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_path', help='.txt file to parse')
    argparser.add_argument('ner_model_path', help='.pkl file containing NER model')
    argparser.add_argument('rel_model_path', help='.pkl file containing relational model')
    argparser.add_argument('--say-lut', help='.txt file with list of saywords')
    argparser.add_argument('--char-lut', help='.txt file with known characters')
    argparser.add_argument('--place-lut', help='.txt file with known places')
    argparser.add_argument('--spatial-indicator-lut', help='.txt file with known spatial indicators')
    argparser.add_argument('--force', action='store_true', help='force overwrite when there is a file to be overwritten')
    argparser.add_argument('--no-coreference-resolution', action='store_true', help='omit coreference resolution step')

    args = argparser.parse_args()
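
    # Example invocation (file names are illustrative):
    #   python3 text2annotation.py story.txt ner_model.pkl rel_model.pkl \
    #       --char-lut characters.txt --force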

    # Load text file
    with open(args.input_path) as f:
        text = " ".join(f.read().split())

    # Derive output paths from the input path without its extension
    base_path = os.path.splitext(args.input_path)[0]
    output_text_path = base_path + '_processed.txt'
    output_quotes_path = base_path + '_quotes.json'
    output_annotation_path = base_path + '_processed.ann'

    # Load NER model file
    ner_model = joblib.load(args.ner_model_path)

    # Load REL model file
    rel_model = joblib.load(args.rel_model_path)

    # Load saywords LUT
    if args.say_lut:
        saylut_path = args.say_lut
    else:
        saylut_path = 'saywords.txt'

    with open(saylut_path) as f:
        saylut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load places LUT
    if args.place_lut:
        placelut_path = args.place_lut
    else:
        placelut_path = 'places.txt'

    with open(placelut_path) as f:
        placelut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load spatial indicators LUT
    if args.spatial_indicator_lut:
        spatial_indicator_lut_path = args.spatial_indicator_lut
    else:
        spatial_indicator_lut_path = 'spatial_indicators.txt'

    with open(spatial_indicator_lut_path) as f:
        spatial_indicator_lut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load character LUT
    if args.char_lut:
        charlut_path = args.char_lut
    else:
        charlut_path = 'characters.txt'

    with open(charlut_path) as f:
        charlist = [s for s in f.read().split('\n') if s.strip() != '']  # One character per line
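
    # Each line is expected to look like (illustrative):
    #   Little Red Riding Hood: Female, young
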
    character_lut = {}  # Stores character attributes indexed by name
    for l in charlist:
        name, attributes = l.split(':')
        name = name.strip()

        gender = None
        age = None

        for a in (attr.strip() for attr in attributes.split(',')):
            # 'male' is a substring of 'female', so the check below catches
            # both genders; normalize so downstream comparisons against
            # 'Male'/'Female' match
            if 'male' in a.lower():
                gender = a.capitalize()
            elif a.lower() in ['young', 'old']:
                age = a

        character_lut[name] = {}
        if gender:
            character_lut[name]['gender'] = gender
        if age:
            character_lut[name]['age'] = age

    corefres = not args.no_coreference_resolution
    mDoc, quotes = annotate(text, ner_model, rel_model, character_lut, saylut, spatial_indicator_lut, placelut, corefres)

    annotation_text = doc2brat(mDoc)

    to_save = {
        output_text_path: mDoc.text,
        output_quotes_path: json.dumps(quotes),
        output_annotation_path: annotation_text
    }

    for path in to_save:
        if not os.path.exists(path) or args.force:
            with open(path, 'w') as f:
                f.write(to_save[path])
        else:
            overwrite = input('Path {} exists, overwrite? (y/N) '.format(path))
            if overwrite.strip().lower().startswith('y'):
                with open(path, 'w') as f:
                    f.write(to_save[path])