#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 28 14:17:15 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and annotates it based on:

characters,
places,
saywords,
character_lines,
spatial_indicators,

@output:
.ann file with the same name
.json file with the extracted character lines

"""

import os
import argparse
from sklearn.externals import joblib
import ner
import spacy
import re
import logging
import json
from difflib import SequenceMatcher
from neuralcoref import Coref
from rel import *

def pronoun2gender(word):
    pronoun2gender = {
        'he': 'Male',
        'him': 'Male',
        'she': 'Female',
        'her': 'Female',
        'his': 'Male',
        'hers': 'Female',
        'himself': 'Male',
        'herself': 'Female',
    }

    if word in pronoun2gender:
        return pronoun2gender[word]
    else:
        return 'neutral'

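# A quick illustration of the mapping above (sketch):
#   pronoun2gender('she') -> 'Female'
#   pronoun2gender('it')  -> 'neutral' (anything not in the table)
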
logging.basicConfig(level=logging.INFO)


# given an iterable of pairs return the key corresponding to the greatest value
def argmax(pairs):
    # https://stackoverflow.com/questions/5098580/implementing-argmax-in-python
    return max(pairs, key=lambda x: x[1])[0]


# given an iterable of values return the index of the greatest value
def argmax_index(values):
    return argmax(enumerate(values))


# given an iterable of keys and a function f, return the key with largest f(key)
def argmax_f(keys, f):
    return max(keys, key=f)

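# Sketch of how these helpers behave:
#   argmax([('a', 0.2), ('b', 0.7)]) -> 'b'
#   argmax_index([0.1, 0.9, 0.3])    -> 1
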
def similar(a, b):
    """ Returns string similarity between a and b """
    # https://stackoverflow.com/questions/17388213/find-the-similarity-metric-between-two-strings
    return SequenceMatcher(None, a, b).ratio()

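# Sketch: similar('Anna', 'Ana') is roughly 0.86 while similar('Anna', 'Wolf')
# is 0.0, which is why 0.75 makes a workable match threshold further down.
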
def get_resolved_clusters(coref):
    """ Gets a coref object (from neural coref) and
        returns the clusters as words """

    mentions = coref.get_mentions()
    clusters = coref.get_clusters()[0]
    result = []
    for c in clusters:
        result.append([mentions[r] for r in clusters[c]])
    return result

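# Sketch (mentions shown by their text): for a story where "the wolf" is later
# referred to as "he", this might return [['the wolf', 'he'], ...]; the exact
# shape depends on the neuralcoref version, so treat this as illustrative.
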
def cluster_word(word, clusters):
    """ Gets a word and a list of clusters of mentions
        and figures out where the word matches most based on
        string similarity """

    similarities = []
    for rc in clusters:
        similarity = [similar(word.lower(), c.text.lower()) for c in rc]
        similarities.append(similarity)
    max_similarities = [max(s) for s in similarities]
    if max(max_similarities) > 0.75:
        return argmax_index(max_similarities)
    else:
        return -1

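# Sketch: with clusters whose mention texts are [['the wolf', 'he'],
# ['Grandma', 'she']], cluster_word('grandma', clusters) -> 1, while
# cluster_word('hunter', clusters) -> -1 (no mention is similar enough).
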
def quotes2dict(text):
    """ Replaces every double-quoted span in text with a <clineN>. tag
        and returns the new text together with a dict that maps tags to
        the original narrator and quote strings. """
    new_text = text
    is_open = False

    quote_no = 0
    quote = []
    narrator = []
    quote_dict = {}

    for n, c in enumerate(text):
        if c == '"' and not is_open:
            is_open = True
            quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
            narrator = []
            quote_no += 1
            continue

        elif c == '"' and is_open:
            is_open = False
            quote_dict["<cline{}>.".format(quote_no)] = ''.join(quote)
            new_text = new_text.replace('"' + ''.join(quote) + '"', "<cline{}>.".format(quote_no))
            quote = []
            quote_no += 1
            continue

        if is_open:
            quote.append(c)
        else:
            narrator.append(c)

    return new_text, quote_dict

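# Sketch of the transformation:
#   quotes2dict('He said "hello" and left.')
#   -> ('He said <cline1>. and left.',
#       {'<nline0>.': 'He said ', '<cline1>.': 'hello'})
# Narrator text after the final quote stays in the returned text but is not
# added to the dict.
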
def figure_gender(word, clusters, character_lut):
    for c in character_lut:
        gender = character_lut[c].get('gender')
        # Compare against the words of the character string, not its characters
        if c.lower() in [w.lower() for w in word.split()] and gender in ['Male', 'Female']:
            return gender

    cluster_idx = cluster_word(word, clusters)
    if cluster_idx == -1:
        return 'neutral'
    genders = [pronoun2gender(c.text) for c in clusters[cluster_idx]]
    # Check membership in the genders list, not in the literal string 'genders'
    if 'Male' in genders and 'Female' not in genders:
        return 'Male'
    if 'Female' in genders and 'Male' not in genders:
        return 'Female'
    return 'neutral'

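# Sketch: if 'wolf' sits in a cluster together with 'he' and 'him',
# figure_gender('wolf', clusters, {}) -> 'Male'.
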
def annotate(text,
             ner_model,
             rel_model,
             character_lut,
             saywords_lut,
             spind_lut,
             places_lut,
             do_coreference_resolution=True):
    """
    Function which annotates entities in text
    using the NER model in "ner_model",

    returns: A ner.Document object with tokens labelled via
    the LUTs provided and also the NER model in "ner_model"
    """

    # Find and store character lines in a dictionary
    logging.info('Swapping character lines for character line tags')
    processed_text, quotes = quotes2dict(text)

    # Create the nlp engine that will turn the resulting text
    # into a spacy document object
    logging.info("Loading 'en' spacy model")
    nlp = spacy.load('en')

    # Loading coreference model
    coref = Coref()

    # Doing coreference resolution
    if do_coreference_resolution:
        logging.info("Doing one-shot coreference resolution (this might take some time)")
        coref.one_shot_coref(processed_text)
        resolved_clusters = get_resolved_clusters(coref)
        processed_text = coref.get_resolved_utterances()[0]

    # Parse to spacy document
    logging.info("Parsing document to spacy")
    doc = nlp(processed_text)

    # Parse to our custom Document object
    logging.info("Parsing document to our object format for Named Entity Recognition")
    mDoc = ner.Document(doc)

    # Label <cline[0-9]+>. tags as character lines
    logging.info("Labeling character lines")
    spans = [r.span() for r in re.finditer(r'<cline[0-9]+>\.', mDoc.text)]
    for span in spans:
        mDoc.assign_label_to_tokens(span[0], span[1], 'Character_Line')

    # Parse using LUTs

    # *- Characters

    # Sort by number of words so that tokens with more words override
    # tokens with fewer words in labelling. For example, if you have
    # `man' and `an old man' as characters, the character labelled is
    # going to be `an old man' and not the included `man'.
    logging.info("Labeling characters from LUT")
    cLUT = [c.lower() for c in sorted(character_lut, key=lambda x: len(x.split()))]

    # Find literals in document that match a character in cLUT
    # (escape the entries and match case-insensitively, since they were lowercased)
    for c in cLUT:
        spans = [r.span() for r in re.finditer(re.escape(c), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Character')

    # *- Saywords

    # Assign labels to saywords. Here saywords contain only one token.
    # In addition we check against each sayword's lemma and not the
    # sayword itself.
    logging.info("Labeling saywords from LUT")
    swLUT = [nlp(sw)[0].lemma_ for sw in saywords_lut]
    for sw in swLUT:
        mDoc.assign_label_to_tokens_by_matching_lemma(sw, 'Says')

    # *- Places
    logging.info("Labeling places from LUT")
    plLUT = [pl.lower() for pl in sorted(places_lut, key=lambda x: len(x.split()))]

    # Find literals in document that match a place in plLUT
    for pl in plLUT:
        spans = [r.span() for r in re.finditer(re.escape(pl), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Place')

    # *- Spatial indicators
    logging.info("Labeling spatial indicators from LUT")
    spLUT = [sp.lower() for sp in sorted(spind_lut, key=lambda x: len(x.split()))]
    for sp in spLUT:
        spans = [r.span() for r in re.finditer(re.escape(sp), mDoc.text, re.IGNORECASE)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Spatial_Signal')

    logging.info("Extracting token features")
    features, labels = mDoc.get_token_features_labels()

    logging.info("Predicting labels")
    new_labels = ner_model.predict(features)

    logging.info("Assigning labels based on the NER model")
    # If a label is not already assigned by a LUT, assign it using the model
    for m, sent in enumerate(mDoc.token_sentences):
        for n, token in enumerate(sent):
            if token.label == 'O':
                token.label = new_labels[m][n]

    # Assign gender attributes to characters
    if do_coreference_resolution:
        logging.info('Figuring out character genders')
        character_tok_sent = mDoc.get_tokens_with_label('Character')
        for sent in character_tok_sent:
            for character in sent:
                raw_string = " ".join([c.text for c in character])
                gender = figure_gender(raw_string, resolved_clusters, character_lut)
                for tok in character:
                    if gender in ['Male', 'Female']:
                        tok.set_attribute('gender', gender)

    logging.info('Predicting the correct label for all possible relations in Document')
    mDoc.predict_relations(rel_model)

    return mDoc, quotes


def doc2brat(mDoc):
    """ Returns a brat .ann file str based on mDoc """

    # Dictionary that maps text span -> variable (to be used when
    # adding relations)
    span2var = {}

    # Variable generator for entities (T in brat format)
    tvar = ner.var_generator('T')

    # Variable generator for relations, written as events (E in brat format)
    rvar = ner.var_generator('E')

    # Variable generator for attributes (A in brat format)
    avar = ner.var_generator('A')

    ann_str = ""
    # Extract entities in the format
    # T1<TAB>Character START END<TAB>character string

    labels = ['Character', 'Says', 'Place', 'Spatial_Signal', 'Character_Line']

    for label in labels:
        token_sentences = mDoc.get_tokens_with_label(label)
        for tlist in token_sentences:
            if len(tlist) == 0:
                continue

            for tokens in tlist:
                start = tokens[0].start
                end = tokens[-1].end
                txt = mDoc.text[start:end]
                var = next(tvar)
                ann_str += "{}\t{} {} {}\t{}\n".format(var, label, start, end, txt)
                if 'gender' in tokens[0].attributes:
                    ann_str += "{}\t{} {} {}\n".format(next(avar), 'Gender', var, tokens[0].attributes['gender'])

                span2var[(start, end)] = var

    # Map relations
    for r in mDoc.relations:
        var = next(rvar)
        trigger = r.trigger
        trigger_label = trigger[0].label[2:]
        trigger_start = trigger[0].start
        trigger_end = trigger[-1].end
        trigger_var = span2var[(trigger_start, trigger_end)]

        # If a trigger is Spatial_Signal then the
        # arguments are of form Trajector and Landmark
        if trigger_label == 'Spatial_Signal':
            arg1_label = 'Trajector'
            arg2_label = 'Landmark'

        # If a trigger is Says then the
        # arguments are WHO and WHAT
        elif trigger_label == 'Says':
            arg1_label = 'WHO'
            arg2_label = 'WHAT'

        # Skip unexpected trigger types instead of failing with
        # undefined argument labels
        else:
            continue

        # Span for the first argument
        arg1_start = r.arg1[0].start
        arg1_end = r.arg1[-1].end

        # Variable for the first argument
        arg1_var = span2var[(arg1_start, arg1_end)]

        # Span for the second argument
        arg2_start = r.arg2[0].start
        arg2_end = r.arg2[-1].end

        # Variable for the second argument
        arg2_var = span2var[(arg2_start, arg2_end)]

        annot_line = "{}\t{}:{} {}:{} {}:{}\n".format(var,
                                                      trigger_label,
                                                      trigger_var,
                                                      arg1_label,
                                                      arg1_var,
                                                      arg2_label,
                                                      arg2_var)

        ann_str += annot_line

    return ann_str

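# Sketch of the resulting .ann content (offsets illustrative, <TAB> for tabs):
#   T1<TAB>Character 0 4<TAB>wolf
#   A1<TAB>Gender T1 Male
#   T2<TAB>Says 5 9<TAB>said
#   E1<TAB>Says:T2 WHO:T1 WHAT:T3
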

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_path', help='.txt file to parse')
    argparser.add_argument('ner_model_path', help='.pkl file containing NER model')
    argparser.add_argument('rel_model_path', help='.pkl file containing relational model')
    argparser.add_argument('--say-lut', help='.txt file with list of saywords')
    argparser.add_argument('--char-lut', help='.txt file with known characters')
    argparser.add_argument('--place-lut', help='.txt file with known places')
    argparser.add_argument('--spatial-indicator-lut', help='.txt file with known spatial indicators')
    # store_true so that --force acts as a flag rather than expecting a value
    argparser.add_argument('--force', action='store_true', help='force overwrite when there is a file to be overwritten')
    argparser.add_argument('--no-coreference-resolution', action='store_true', help='omit coreference resolution step')

    args = argparser.parse_args()
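
    # Example invocation (file names are illustrative):
    #   python3 text2annotation.py story.txt ner_model.pkl rel_model.pkl \
    #       --char-lut characters.txt --force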

    # Load text file, collapsing all whitespace runs (including newlines)
    # into single spaces
    with open(args.input_path) as f:
        text = " ".join(f.read().split())

    output_dir = os.path.dirname(args.input_path)
    output_text_path = args.input_path[:-4] + '_processed.txt'
    output_quotes_path = args.input_path[:-4] + '_quotes.json'
    output_annotation_path = args.input_path[:-4] + '_processed.ann'

    # Load NER model file
    ner_model = joblib.load(args.ner_model_path)

    # Load REL model file
    rel_model = joblib.load(args.rel_model_path)

    # Load saywords
    if args.say_lut:
        saylut_path = args.say_lut
    else:
        saylut_path = 'saywords.txt'

    with open(saylut_path) as f:
        saylut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load places LUT
    if args.place_lut:
        placelut_path = args.place_lut
    else:
        placelut_path = 'places.txt'

    with open(placelut_path) as f:
        placelut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load spatial indicators LUT
    if args.spatial_indicator_lut:
        spatial_indicator_lut_path = args.spatial_indicator_lut
    else:
        spatial_indicator_lut_path = 'spatial_indicators.txt'

    with open(spatial_indicator_lut_path) as f:
        spatial_indicator_lut = [s for s in f.read().split('\n') if s.strip() != '']

    # Load character LUT
    if args.char_lut:
        charlut_path = args.char_lut
    else:
        charlut_path = 'characters.txt'

    with open(charlut_path) as f:
        charlist = [s for s in f.read().split('\n') if s.strip() != '']  # One character per line
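
    # Expected line format (assumed from the parsing below):
    #   Little Red Riding Hood: female, young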

    character_lut = {}  # Stores character attributes indexed by name
    for l in charlist:
        name, attributes = l.split(':', 1)
        name = name.strip()

        gender = None
        age = None

        for a in attributes.split(','):
            a = a.strip().lower()
            # Match the whole attribute so that 'female' is not caught
            # by a substring test for 'male'
            if a in ['male', 'female']:
                gender = a.capitalize()
            elif a in ['young', 'old']:
                age = a

        character_lut[name] = {}
        if gender:
            character_lut[name]['gender'] = gender
        if age:
            character_lut[name]['age'] = age

    corefres = not args.no_coreference_resolution
    mDoc, quotes = annotate(text, ner_model, rel_model, character_lut, saylut, spatial_indicator_lut, placelut, corefres)

    annotation_text = doc2brat(mDoc)

    to_save = {
        output_text_path: mDoc.text,
        output_quotes_path: json.dumps(quotes),
        output_annotation_path: annotation_text
    }

    for path in to_save:
        if not os.path.exists(path) or args.force:
            with open(path, 'w') as f:
                f.write(to_save[path])
        else:
            overwrite = input('Path {} exists, overwrite? (y/N) '.format(path))
            # Guard against an empty reply before indexing it
            if overwrite and overwrite[0] in ['Y', 'y']:
                with open(path, 'w') as f:
                    f.write(to_save[path])