Mercurial > hg > from-my-pen-to-your-ears-supplementary-material
diff demo/text2annotation.py @ 0:4dad87badb0c
initial commit
author | Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk> |
---|---|
date | Wed, 16 May 2018 17:56:10 +0100 |
parents | |
children |
line wrap: on
line diff
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 28 14:17:15 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and annotates it based on:

    characters,
    places,
    saywords,
    character_lines,
    spatial_indicators,

@output:
    .ann file with the same name
    .json file with the extracted character lines

"""

import os
import argparse
from sklearn.externals import joblib
import ner
import spacy
import re
import logging
import json
from difflib import SequenceMatcher
from neuralcoref import Coref
from rel import *

logging.basicConfig(level=logging.INFO)

# Gendered English pronouns used to guess a coreference cluster's gender.
_PRONOUN_GENDERS = {
    'he': 'Male',
    'him': 'Male',
    'she': 'Female',
    'her': 'Female',
    'his': 'Male',
    'hers': 'Female',
    'himself': 'Male',
    'herself': 'Female',
}


def pronoun2gender(word):
    """Return 'Male'/'Female' for a gendered English pronoun, else 'neutral'."""
    return _PRONOUN_GENDERS.get(word, 'neutral')


def argmax(pairs):
    """Given an iterable of (key, value) pairs, return the key of the greatest value."""
    # https://stackoverflow.com/questions/5098580/implementing-argmax-in-python
    return max(pairs, key=lambda x: x[1])[0]


def argmax_index(values):
    """Given an iterable of values, return the index of the greatest value."""
    return argmax(enumerate(values))


def argmax_f(keys, f):
    """Given an iterable of keys and a function f, return the key with largest f(key)."""
    return max(keys, key=f)


def similar(a, b):
    """Return the string similarity between a and b as a ratio in [0, 1]."""
    # https://stackoverflow.com/questions/17388213/find-the-similarity-metric-between-two-strings
    return SequenceMatcher(None, a, b).ratio()


def get_resolved_clusters(coref):
    """Given a neuralcoref Coref object, return its coreference clusters as
    lists of mention objects (one list per cluster)."""
    mentions = coref.get_mentions()
    clusters = coref.get_clusters()[0]
    result = []
    for c in clusters:
        result.append([mentions[r] for r in clusters[c]])
    return result


def cluster_word(word, clusters):
    """Given a word and a list of mention clusters, return the index of the
    cluster the word matches best by string similarity, or -1 when no
    mention is more than 0.75 similar."""
    similarities = []
    for rc in clusters:
        similarities.append([similar(word.lower(), c.text.lower()) for c in rc])
    max_similarities = [max(s) for s in similarities]
    if max(max_similarities) > 0.75:
        return argmax_index(max_similarities)
    return -1


def quotes2dict(text):
    """Replace double-quoted character lines in `text` with <cline{n}>. tags.

    Returns (new_text, quote_dict) where quote_dict maps <nline{n}>. tags to
    the narrator text preceding each quote and <cline{n}>. tags to the quoted
    text itself. Tag numbers alternate: narrator tags get even numbers,
    character-line tags odd numbers.
    """
    new_text = text
    is_open = False

    quote_no = 0
    quote = []
    narrator = []
    quote_dict = {}

    for c in text:
        if c == '"' and not is_open:
            # Opening quote: flush the accumulated narrator text.
            is_open = True
            quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
            narrator = []
            quote_no += 1
            continue
        elif c == '"' and is_open:
            # Closing quote: store the line and swap it for its tag.
            is_open = False
            quote_dict["<cline{}>.".format(quote_no)] = ''.join(quote)
            new_text = new_text.replace('"' + ''.join(quote) + '"', "<cline{}>.".format(quote_no))
            quote = []
            quote_no += 1
            continue

        if is_open:
            quote.append(c)
        else:
            narrator.append(c)

    return new_text, quote_dict


def figure_gender(word, clusters, character_lut):
    """Guess the gender of a character mention.

    First consults the character LUT; failing that, finds the coreference
    cluster the mention belongs to and looks for gendered pronouns in it.
    Returns 'Male', 'Female' or 'neutral'.
    """
    # NOTE(fix): the original iterated `word` character-by-character, so
    # multi-character LUT names could never match; split into words instead.
    word_tokens = [w.lower() for w in word.split()]
    for c in character_lut:
        # .get(): a character may have no stored gender (the original raised
        # KeyError in that case).
        if c.lower() in word_tokens and character_lut[c].get('gender') in ['Male', 'Female']:
            return character_lut[c]['gender']

    cluster_idx = cluster_word(word, clusters)
    if cluster_idx == -1:
        return 'neutral'
    genders = [pronoun2gender(c.text) for c in clusters[cluster_idx]]
    # NOTE(fix): the original compared against the literal string 'genders'
    # instead of the list, making the "not in" guard always true.
    if 'Male' in genders and 'Female' not in genders:
        return 'Male'
    if 'Female' in genders and 'Male' not in genders:
        return 'Female'
    return 'neutral'


def annotate(text,
             ner_model,
             rel_model,
             character_lut,
             saywords_lut,
             spind_lut,
             places_lut,
             do_coreference_resolution=True):
    """
    Annotate entities in `text` using the models and lookup tables given.

    Quoted character lines are swapped for <cline*> placeholder tags,
    entities are labelled first from the LUTs and then by the NER model,
    character genders are inferred via coreference resolution, and finally
    relations are predicted with `rel_model`.

    Returns:
        (mDoc, quotes): the labelled ner.Document and the dict mapping
        <nline*>/<cline*> placeholder tags to the original text.
    """
    # Find and store character lines in a dictionary
    logging.info('Swapping character lines for character line tags')
    processed_text, quotes = quotes2dict(text)

    # Create the nlp engine
    logging.info("Loading 'en' spacy model")
    nlp = spacy.load('en')

    # Loading coreference model
    coref = Coref()

    # Doing coreference resolution
    if do_coreference_resolution:
        logging.info("Doing one-shot coreference resolution (this might take some time)")
        coref.one_shot_coref(processed_text)
        resolved_clusters = get_resolved_clusters(coref)
        processed_text = coref.get_resolved_utterances()[0]

    # Parse to spacy document
    logging.info("Parsing document to spacy")
    doc = nlp(processed_text)

    # Parse to our custom Document object
    logging.info("Parsing document to our object format for Named Entity Recognition")
    mDoc = ner.Document(doc)

    # Label <cline[0-9]+>. placeholders as character lines
    logging.info("Labeling character lines")
    spans = [r.span() for r in re.finditer(r'<cline[0-9]+>\.', mDoc.text)]
    for span in spans:
        mDoc.assign_label_to_tokens(span[0], span[1], 'Character_Line')

    # *- Characters
    #
    # Sort by number of words so that entries with more words override
    # entries with fewer words in labelling: with `man' and `an old man'
    # both present, the character labelled is `an old man', not `man'.
    logging.info("Labeling characters from LUT")
    cLUT = [c.lower() for c in sorted(character_lut, key=lambda x: len(x.split()))]
    for c in cLUT:
        # NOTE(fix): escape the LUT entry so names containing regex
        # metacharacters are matched literally instead of misbehaving.
        spans = [r.span() for r in re.finditer(re.escape(c), mDoc.text)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Character')

    # *- Saywords: single-token entries, matched against the token lemma
    # rather than the surface form.
    logging.info("Labeling saywords from LUT")
    swLUT = [nlp(sw)[0].lemma_ for sw in saywords_lut]
    for sw in swLUT:
        mDoc.assign_label_to_tokens_by_matching_lemma(sw, 'Says')

    # *- Places (same longest-first strategy as characters)
    logging.info("Labeling places from LUT")
    plLUT = [pl.lower() for pl in sorted(places_lut, key=lambda x: len(x.split()))]
    for pl in plLUT:
        spans = [r.span() for r in re.finditer(re.escape(pl), mDoc.text)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Place')

    # *- Spatial indicators
    logging.info("Labeling spatial indicators from LUT")
    spLUT = [sp.lower() for sp in sorted(spind_lut, key=lambda x: len(x.split()))]
    for sp in spLUT:
        spans = [r.span() for r in re.finditer(re.escape(sp), mDoc.text)]
        for span in spans:
            mDoc.assign_label_to_tokens(span[0], span[1], 'Spatial_Signal')

    logging.info("Extracting token features")
    features, labels = mDoc.get_token_features_labels()

    logging.info("Predicting labels")
    new_labels = ner_model.predict(features)

    # If a label is not already assigned by a LUT ('O' = outside any
    # entity), assign it using the model.
    logging.info("Assigning labels based on the NER model")
    for m, sent in enumerate(mDoc.token_sentences):
        for n, token in enumerate(sent):
            if token.label == 'O':
                token.label = new_labels[m][n]

    # Attach a gender attribute to character tokens when one can be inferred.
    if do_coreference_resolution:
        logging.info('Figuring out character genders')
        character_tok_sent = mDoc.get_tokens_with_label('Character')
        for sent in character_tok_sent:
            for character in sent:
                raw_string = " ".join([c.text for c in character])
                gender = figure_gender(raw_string, resolved_clusters, character_lut)
                if gender in ['Male', 'Female']:
                    for tok in character:
                        tok.set_attribute('gender', gender)

    logging.info('Predicting the correct label for all possible relations in Document')
    mDoc.predict_relations(rel_model)

    return mDoc, quotes


def doc2brat(mDoc):
    """Return a brat .ann file string for the entities, gender attributes and
    relations in mDoc."""

    # Maps an entity's (start, end) text span to its brat variable so that
    # relations can refer back to it.
    span2var = {}

    # Variable generators for entities (T), relations (E) and attributes (A)
    # in brat standoff format.
    tvar = ner.var_generator('T')
    rvar = ner.var_generator('E')
    avar = ner.var_generator('A')

    ann_str = ""

    # Entities, one line each: "T1\tCharacter START END\ttext"
    labels = ['Character', 'Says', 'Place', 'Spatial_Signal', 'Character_Line']
    for label in labels:
        token_sentences = mDoc.get_tokens_with_label(label)
        for tlist in token_sentences:
            if len(tlist) == 0:
                continue

            for tokens in tlist:
                start = tokens[0].start
                end = tokens[-1].end
                txt = mDoc.text[start:end]
                var = next(tvar)
                ann_str += "{}\t{} {} {}\t{}\n".format(var, label, start, end, txt)
                if 'gender' in tokens[0].attributes:
                    ann_str += "{}\t{} {} {}\n".format(next(avar), 'Gender', var, tokens[0].attributes['gender'])

                span2var[(start, end)] = var

    # Relations
    for r in mDoc.relations:
        trigger = r.trigger
        trigger_label = trigger[0].label[2:]  # strip the "B-"/"I-" prefix
        trigger_start = trigger[0].start
        trigger_end = trigger[-1].end
        trigger_var = span2var[(trigger_start, trigger_end)]

        # A Spatial_Signal trigger relates a Trajector to a Landmark;
        # a Says trigger relates WHO (the speaker) to WHAT (the line).
        if trigger_label == 'Spatial_Signal':
            arg1_label = 'Trajector'
            arg2_label = 'Landmark'
        elif trigger_label == 'Says':
            arg1_label = 'WHO'
            arg2_label = 'WHAT'
        else:
            # NOTE(fix): the original fell through with arg1_label unbound
            # (NameError) for any other trigger label; skip such relations.
            logging.warning("Skipping relation with unknown trigger label %s", trigger_label)
            continue

        var = next(rvar)

        # Span and variable for the first argument
        arg1_start = r.arg1[0].start
        arg1_end = r.arg1[-1].end
        arg1_var = span2var[(arg1_start, arg1_end)]

        # Span and variable for the second argument
        arg2_start = r.arg2[0].start
        arg2_end = r.arg2[-1].end
        arg2_var = span2var[(arg2_start, arg2_end)]

        annot_line = "{}\t{}:{} {}:{} {}:{}\n".format(var,
                                                      trigger_label,
                                                      trigger_var,
                                                      arg1_label,
                                                      arg1_var,
                                                      arg2_label,
                                                      arg2_var)

        ann_str += annot_line

    return ann_str


def _load_lines(path):
    """Read a text file and return its non-blank lines as a list."""
    with open(path) as f:
        return [s for s in f.read().split('\n') if s.strip() != '']


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_path', help='.txt file to parse')
    argparser.add_argument('ner_model_path', help='.pkl file containing NER model')
    argparser.add_argument('rel_model_path', help='.pkl file containing relational model')
    argparser.add_argument('--say-lut', help='.txt file with list of saywords')
    argparser.add_argument('--char-lut', help='.txt file with known characters')
    argparser.add_argument('--place-lut', help='.txt file with known places')
    argparser.add_argument('--spatial-indicator-lut', help='.txt file with known spatial indicators')
    # NOTE(fix): --force is used as a boolean flag below but was declared
    # without an action, so it used to require a dummy argument.
    argparser.add_argument('--force', action='store_true',
                           help='force overwrite when there is a file to be overwritten')
    argparser.add_argument('--no-coreference-resolution', action='store_true',
                           help='omit coreference resolution step')

    args = argparser.parse_args()

    # Load text file, collapsing all whitespace runs to single spaces.
    with open(args.input_path) as f:
        text = " ".join(f.read().split())

    # Derive output paths next to the input file.
    base, _ = os.path.splitext(args.input_path)
    output_text_path = base + '_processed.txt'
    output_quotes_path = base + '_quotes.json'
    output_annotation_path = base + '_processed.ann'

    # Load NER model file
    ner_model = joblib.load(args.ner_model_path)

    # Load REL model file
    rel_model = joblib.load(args.rel_model_path)

    # Load saywords
    saylut = _load_lines(args.say_lut if args.say_lut else 'saywords.txt')

    # Load places LUT
    placelut = _load_lines(args.place_lut if args.place_lut else 'places.txt')

    # Load spatial indicators LUT
    spatial_indicator_lut = _load_lines(
        args.spatial_indicator_lut if args.spatial_indicator_lut else 'spatial_indicators.txt')

    # Load character LUT: one "name: attr, attr, ..." entry per line.
    charlut_path = args.char_lut if args.char_lut else 'characters.txt'
    charlist = _load_lines(charlut_path)

    character_lut = {}  # Stores character attributes indexed by name
    for l in charlist:
        if ':' not in l:
            # NOTE(fix): the original crashed (ValueError) on lines without
            # a colon; skip malformed entries instead.
            logging.warning("Skipping malformed character entry: %s", l)
            continue
        name, attributes = l.split(':', 1)
        name = name.strip()

        gender = None
        age = None

        for a in attributes.split(','):
            a = a.strip()
            # NOTE(fix): the substring test is now case-insensitive; the
            # original's 'male' in a missed capitalised 'Male'.
            if 'male' in a.lower():
                gender = a
            elif a.lower() in ['young', 'old']:
                age = a

        character_lut[name] = {}
        if gender:
            character_lut[name]['gender'] = gender
        if age:
            character_lut[name]['age'] = age

    corefres = not args.no_coreference_resolution
    mDoc, quotes = annotate(text, ner_model, rel_model, character_lut, saylut,
                            spatial_indicator_lut, placelut, corefres)

    annotation_text = doc2brat(mDoc)

    to_save = {
        output_text_path: mDoc.text,
        output_quotes_path: json.dumps(quotes),
        output_annotation_path: annotation_text
    }

    for path in to_save:
        if not os.path.exists(path) or args.force:
            with open(path, 'w') as f:
                f.write(to_save[path])
        else:
            overwrite = input('Path {} exists, overwrite? (y/N) '.format(path))
            # NOTE(fix): the original indexed overwrite[0], which raised
            # IndexError when the user just pressed Enter.
            if overwrite.strip().lower().startswith('y'):
                with open(path, 'w') as f:
                    f.write(to_save[path])