diff demo/text2annotation.py @ 0:4dad87badb0c

initial commit
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Wed, 16 May 2018 17:56:10 +0100
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/demo/text2annotation.py	Wed May 16 17:56:10 2018 +0100
@@ -0,0 +1,479 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Apr 28 14:17:15 2018
+
+@author: Emmanouil Theofanis Chourdakis
+
+Takes a .txt story and annotates it based on:
+
+    characters,
+    places,
+    saywords,
+    character_lines,
+    spatial_indicators
+
+@output:
+    .ann file with the same name
+    .json file with the extracted character lines
+
+"""
+
+import os
+import argparse
+from sklearn.externals import joblib
+import ner
+import spacy
+import re
+import logging
+import json
+from difflib import SequenceMatcher
+from neuralcoref import Coref
+# Relation model helpers; the classes in rel must be importable
+# so that joblib can unpickle the REL model below
+from rel import *
+
+def pronoun2gender(word):
+    """ Maps a gendered English pronoun to 'Male' or 'Female';
+        anything else maps to 'neutral'. """
+    gender_map = {
+        'he': 'Male',
+        'him': 'Male',
+        'his': 'Male',
+        'himself': 'Male',
+        'she': 'Female',
+        'her': 'Female',
+        'hers': 'Female',
+        'herself': 'Female',
+    }
+    return gender_map.get(word.lower(), 'neutral')
+
+
+logging.basicConfig(level=logging.INFO)
+
+# given an iterable of pairs return the key corresponding to the greatest value
+def argmax(pairs):
+    #https://stackoverflow.com/questions/5098580/implementing-argmax-in-python    
+    return max(pairs, key=lambda x: x[1])[0]
+
+# given an iterable of values return the index of the greatest value
+def argmax_index(values):
+    return argmax(enumerate(values))
+
+# given an iterable of keys and a function f, return the key with largest f(key)
+def argmax_f(keys, f):
+    return max(keys, key=f)
+
+def similar(a, b):
+    """ Returns string similarity between a and b """
+    # https://stackoverflow.com/questions/17388213/find-the-similarity-metric-between-two-strings
+    return SequenceMatcher(None, a, b).ratio()
+
+
+def get_resolved_clusters(coref):
+    """ Gets a coref object (from neural coref) and
+        returns the clusters as words """
+    
+    mentions = coref.get_mentions()
+    clusters = coref.get_clusters()[0]
+    result = []
+    for c in clusters:
+        result.append([mentions[r] for r in clusters[c]])
+    return result
+
+def cluster_word(word, clusters):
+    """ Given a word and a list of mention clusters, returns the index
+        of the cluster whose mentions best match the word by string
+        similarity, or -1 if no mention is similar enough. """
+
+    if not clusters:
+        return -1
+
+    similarities = []
+    for rc in clusters:
+        similarity = [similar(word.lower(), c.text.lower()) for c in rc]
+        similarities.append(similarity)
+    max_similarities = [max(s) for s in similarities]
+    if max(max_similarities) > 0.75:
+        return argmax_index(max_similarities)
+    else:
+        return -1
+
+def quotes2dict(text):
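+    """ Replaces each double-quoted span in `text` with a <clineN>. tag
+        and returns the new text together with a dict that maps
+        <clineN>./<nlineN>. tags to the quoted/narrator segments. """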
+    new_text = text
+    is_open = False
+    
+    quote_no = 0
+    quote = []
+    narrator = []
+    quote_dict = {}
+    
+    for n, c in enumerate(text):
+        if c == '"' and not is_open:
+            is_open = True
+            quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
+            narrator = []
+            quote_no += 1
+            continue
+        
+        elif c == '"' and is_open:
+            is_open = False
+            quote_dict["<cline{}>.".format(quote_no)] = ''.join(quote)
+            new_text = new_text.replace('"'+''.join(quote)+'"', "<cline{}>.".format(quote_no))
+            quote = []
+            quote_no += 1
+            continue
+        
+        if is_open:
+            quote.append(c)
+        elif not is_open:
+            narrator.append(c)
+            
+    # Flush any narrator text that follows the final quote
+    if narrator:
+        quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
+
+    return new_text, quote_dict
+
+def figure_gender(word, clusters, character_lut):
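+    """ Infers a character's gender: first from the character LUT,
+        otherwise from the pronouns in the coreference cluster that
+        best matches the word; falls back to 'neutral'. """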
+    # First try the LUT: a known character whose name appears in the
+    # word and has a known gender wins
+    for c in character_lut:
+        if c.lower() in word.lower() and \
+                character_lut[c].get('gender') in ['Male', 'Female']:
+            return character_lut[c]['gender']
+
+    # Otherwise use the pronouns of the best-matching coreference cluster
+    cluster_idx = cluster_word(word, clusters)
+    if cluster_idx == -1:
+        return 'neutral'
+    genders = [pronoun2gender(c.text) for c in clusters[cluster_idx]]
+    if 'Male' in genders and 'Female' not in genders:
+        return 'Male'
+    if 'Female' in genders and 'Male' not in genders:
+        return 'Female'
+    return 'neutral'
+
+def annotate(text, 
+             ner_model, 
+             rel_model,
+             character_lut, 
+             saywords_lut, 
+             spind_lut, 
+             places_lut,
+             do_coreference_resolution=True):
+    """
+        Function which annotates entities in text
+        using the model in "model",
+        
+        returns: A ner.Document object with tokens labelled via
+                 the LUTS provided and also the NER model in "model"
+    """
+
+    # Find and store character lines in a dictionary
+    logging.info('Swapping character lines for character line tags')
+    processed_text, quotes = quotes2dict(text)
+
+    # Create the spacy nlp engine used to parse the resulting text
+    logging.info("Loading 'en' spacy model")
+    nlp = spacy.load('en')
+    
+    # Loading coreference model
+    coref = Coref()
+    
+    
+    # Doing coreference resolution
+    if do_coreference_resolution:
+        logging.info("Doing one-shot coreference resolution (this might take some time)")
+        coref.one_shot_coref(processed_text)
+        resolved_clusters = get_resolved_clusters(coref)
+        processed_text = coref.get_resolved_utterances()[0]
+        
+    # Parse to spacy document
+    logging.info("Parsing document to spacy")
+    doc = nlp(processed_text)    
+    
+    # Parse to our custom Document object
+    logging.info("Parsing document to our object format for Named Entity Recognition")
+    mDoc = ner.Document(doc)
+    
+    # Label each <cline[0-9]+>. tag as a character line
+    logging.info("Labeling character lines")
+    spans = [r.span() for r in re.finditer(r'<cline[0-9]+>\.', mDoc.text)]
+    for span in spans:
+        mDoc.assign_label_to_tokens(span[0],span[1],'Character_Line')
+
+    # Parse using LUTs
+    
+    # *- Characters
+    
+    # Sort by number of words so that entries with more words override
+    # entries with fewer words during labelling. For example, if both
+    # `man' and `an old man' are characters, the labelled span will be
+    # `an old man' rather than the embedded `man'.
+    logging.info("Labeling characters from LUT")
+    cLUT = [c.lower() for c in sorted(character_lut, key=lambda x: len(x.split()))]
+    
+    # Find literals in the document that match a character in cLUT;
+    # escape LUT entries so regex metacharacters match literally, and
+    # ignore case since the LUT was lowercased
+    for c in cLUT:
+        spans = [r.span() for r in re.finditer(re.escape(c), mDoc.text, re.IGNORECASE)]
+        for span in spans:
+            mDoc.assign_label_to_tokens(span[0],span[1],'Character')
+             
+    # *- Saywords
+   
+    # Assign labels to saywords. Saywords consist of a single token here,
+    # and we match against each sayword's lemma rather than its surface form.
+    logging.info("Labeling saywords from LUT")
+    swLUT = [nlp(sw)[0].lemma_ for sw in saywords_lut]
+    for sw in swLUT:
+        mDoc.assign_label_to_tokens_by_matching_lemma(sw, 'Says')
+        
+    # *- Places
+    logging.info("Labeling places from LUT")
+    plLUT = [pl.lower() for pl in sorted(places_lut, key=lambda x: len(x.split()))]
+    
+    # Find literals in the document that match a place in plLUT
+    for pl in plLUT:
+        spans = [r.span() for r in re.finditer(re.escape(pl), mDoc.text, re.IGNORECASE)]
+        for span in spans:
+            mDoc.assign_label_to_tokens(span[0],span[1],'Place')
+            
+    # *- Spatial indicators
+    logging.info("Labeling spatial indicators from LUT")
+    spLUT = [sp.lower() for sp in sorted(spind_lut, key=lambda x: len(x.split()))]         
+    for sp in spLUT:
+        spans = [r.span() for r in re.finditer(re.escape(sp), mDoc.text, re.IGNORECASE)]
+        for span in spans:
+            mDoc.assign_label_to_tokens(span[0],span[1],'Spatial_Signal')
+        
+    logging.info("Extracting token features")
+    features, labels = mDoc.get_token_features_labels()
+    
+    logging.info("Predicting labels")
+    new_labels = ner_model.predict(features)
+    
+
+    logging.info("Assigning labels based on the NER model")
+    # If a label is not already assigned by a LUT, assign it using the model
+    
+    #logging.info("{} {}".format(len(mDoc.tokens), len(new_labels)))
+    for m, sent in enumerate(mDoc.token_sentences):
+        for n, token in enumerate(sent):
+            if token.label == 'O':
+                token.label = new_labels[m][n]
+                
+    # Assign gender attributes to characters using the coreference clusters
+    if do_coreference_resolution:
+        logging.info('Figuring out character genders')
+        character_tok_sent = mDoc.get_tokens_with_label('Character')
+        for sent in character_tok_sent:
+            for character in sent:
+                raw_string = " ".join([c.text for c in character])
+                gender = figure_gender(raw_string, resolved_clusters, character_lut)
+                for tok in character:
+                    if gender in ['Male', 'Female']:
+                        tok.set_attribute('gender', gender)
+                    
+    logging.info('Predicting the correct label for all possible relations in Document')
+    mDoc.predict_relations(rel_model)
+
+    return mDoc, quotes
+
+
+def doc2brat(mDoc):
+    """ Returns a brat .ann file str based on mDoc """
+    
+    # Dictionary that maps text span -> variable (to be used when 
+    # adding relations )
+    span2var = {}
+    
+    # Variable generator for entities (T in brat format)
+    tvar = ner.var_generator('T')
+    
+    # Variable generator for relations, written as brat events (E)
+    rvar = ner.var_generator('E')
+    
+    # Variable generator for attributes (A in brat format)
+    avar = ner.var_generator('A')
+    
+    ann_str = ""
+    # Extract entities as brat text-bound annotations, one per line:
+    # T1\tCharacter START END\tcharacter string
+    
+    labels = ['Character', 'Says', 'Place', 'Spatial_Signal', 'Character_Line']
+    
+    for label in labels:
+        token_sentences = mDoc.get_tokens_with_label(label)
+        for tlist in token_sentences:
+            if len(tlist) == 0:
+                continue
+            
+            for tokens in tlist:
+                start = tokens[0].start
+                end = tokens[-1].end
+                txt = mDoc.text[start:end]
+                var = next(tvar)
+                ann_str += "{}\t{} {} {}\t{}\n".format(var, label, start, end, txt)
+                if 'gender' in tokens[0].attributes:
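+                    # brat attribute line, e.g. "A1\tGender T1 Male"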
+                    ann_str += "{}\t{} {} {}\n".format(next(avar), 'Gender', var, tokens[0].attributes['gender'])
+                
+                span2var[(start, end)] = var
+                
+    # Map relations 
+    for r in mDoc.relations:
+        var = next(rvar)
+        trigger = r.trigger
+        trigger_label = trigger[0].label[2:]
+        trigger_start = trigger[0].start
+        trigger_end = trigger[-1].end
+        trigger_var = span2var[(trigger_start, trigger_end)]
+        
+        # If a trigger is Spatial_Signal then the 
+        # arguments are of form Trajector and Landmark
+        
+        if trigger_label == 'Spatial_Signal':
+            arg1_label = 'Trajector'
+            arg2_label = 'Landmark'
+
+        # If a trigger is Says then the
+        # arguments are WHO and WHAT
+        elif trigger_label == 'Says':
+            arg1_label = 'WHO'
+            arg2_label = 'WHAT'
+
+        # Skip relations with any other trigger type, so the argument
+        # labels are never left unset
+        else:
+            continue
+        
+        # Span for the first argument
+        arg1_start = r.arg1[0].start
+        arg1_end = r.arg1[-1].end
+        
+        # Variable for the first argument
+        arg1_var = span2var[(arg1_start, arg1_end)]      
+        
+        # Span for the second argument
+        arg2_start = r.arg2[0].start
+        arg2_end = r.arg2[-1].end
+        
+        # Variable for the second argument
+        arg2_var = span2var[(arg2_start, arg2_end)]        
+            
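+        # brat event line, e.g. "E1\tSays:T2 WHO:T1 WHAT:T3"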
+        annot_line = "{}\t{}:{} {}:{} {}:{}\n".format(var, 
+                                                      trigger_label, 
+                                                      trigger_var, 
+                                                      arg1_label,
+                                                      arg1_var,
+                                                      arg2_label,
+                                                      arg2_var)
+        
+        ann_str += annot_line
+
+    return ann_str
+
+if __name__ == "__main__":
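+    # Example invocation (hypothetical file names):
+    #   python3 text2annotation.py story.txt ner_model.pkl rel_model.pkl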
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument('input_path', help='.txt file to parse')
+    argparser.add_argument('ner_model_path', help='.pkl file containing NER model')
+    argparser.add_argument('rel_model_path', help='.pkl file containing relational model')
+    argparser.add_argument('--say-lut', help='.txt file with list of saywords')
+    argparser.add_argument('--char-lut', help='.txt file with known characters')
+    argparser.add_argument('--place-lut', help='.txt file with known places')
+    argparser.add_argument('--spatial-indicator-lut', help='.txt file with known spatial indicators')
+    argparser.add_argument('--force', action='store_true', help='force overwrite of existing output files')
+    argparser.add_argument('--no-coreference-resolution', action='store_true', help='omit coreference resolution step')
+    
+    args = argparser.parse_args()
+    
+    # Load text file
+    with open(args.input_path) as f:
+        text = " ".join(f.read().split())
+        
+    base_path = os.path.splitext(args.input_path)[0]
+    output_text_path = base_path + '_processed.txt'
+    output_quotes_path = base_path + '_quotes.json'
+    output_annotation_path = base_path + '_processed.ann'
+        
+    # Load NER model file
+    ner_model = joblib.load(args.ner_model_path)
+    
+    # Load REL model file
+    rel_model = joblib.load(args.rel_model_path)
+    
+    # Load saywords
+    if args.say_lut:
+        saylut_path = args.say_lut
+    else:
+        saylut_path = 'saywords.txt'
+        
+    with open(saylut_path) as f:
+        saylut = [s for s in f.read().split('\n') if s.strip() != '']
+        
+    # Load places LUT
+    if args.place_lut:
+        placelut_path = args.place_lut
+    else:
+        placelut_path = 'places.txt'
+        
+    with open(placelut_path) as f:
+        placelut = [s for s in f.read().split('\n') if s.strip() != '']
+        
+    # Load spatial indicators LUT
+    if args.spatial_indicator_lut:
+        spatial_indicator_lut_path = args.spatial_indicator_lut
+    else:
+        spatial_indicator_lut_path = 'spatial_indicators.txt'
+        
+    with open(spatial_indicator_lut_path) as f:
+        spatial_indicator_lut = [s for s in f.read().split('\n') if s.strip() != '']        
+        
+    # Load character LUT
+    if args.char_lut:
+        charlut_path = args.char_lut
+    else:
+        charlut_path = 'characters.txt'
+        
+    with open(charlut_path) as f:
+        
+        charlist = [s for s in f.read().split('\n') if s.strip() != ''] # One character per line
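+        # Each line has the form "Name:attribute1,attribute2",
+        # e.g. "Gretel:female,young" (hypothetical entry)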
+        
+        character_lut = {} # Stores character attributes indexed by name
+        for l in charlist:
+            name, attributes = l.split(':', 1)
+            name = name.strip()
+
+            gender = None
+            age = None
+
+            for a in attributes.split(','):
+                a = a.strip()
+                # Normalise gender so look-ups against 'Male'/'Female' work
+                if a.lower() in ['male', 'female']:
+                    gender = a.capitalize()
+                elif a.lower() in ['young', 'old']:
+                    age = a
+            
+            character_lut[name] = {}
+            if gender:
+                character_lut[name]['gender'] = gender
+            if age:
+                character_lut[name]['age'] = age
+        
+    corefres = not args.no_coreference_resolution
+    mDoc, quotes = annotate(text, ner_model, rel_model, character_lut, saylut, spatial_indicator_lut, placelut, corefres)
+    
+    annotation_text = doc2brat(mDoc)
+    
+    to_save = {
+            output_text_path: mDoc.text,
+            output_quotes_path: json.dumps(quotes),
+            output_annotation_path: annotation_text
+            }
+
+    for path in to_save:
+        if not os.path.exists(path) or args.force:
+            with open(path, 'w') as f:
+                f.write(to_save[path])
+        else:
+            overwrite = input('Path {} exists, overwrite? (y/N) '.format(path))
+            # Slice rather than index so an empty answer defaults to "no"
+            if overwrite[:1] in ['Y', 'y']:
+                with open(path, 'w') as f:
+                    f.write(to_save[path])