view demo/text2annotation.py @ 0:90155bdd5dd6

first commit
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Wed, 16 May 2018 18:27:05 +0100
parents
children
line wrap: on
line source
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 28 14:17:15 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and annotates it based on:
    
    characters, 
    places,
    saywords,
    character_lines,
    spatial_indicators,
    
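@output (written next to the input; <name> is the input path without its
.txt extension):
    <name>_processed.ann  brat annotation file (entities, attributes, relations)
    <name>_processed.txt  the processed text, with quotes replaced by <clineN>. tags
    <name>_quotes.json    JSON file with the extracted character and narrator lines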
    
"""

import os
import argparse
from sklearn.externals import joblib
import ner
import spacy
import re
import logging
import json
from difflib import SequenceMatcher
from neuralcoref import Coref
from rel import *

def pronoun2gender(word):
    """Map an English third-person pronoun to a gender label.

    The lookup is case-insensitive, so sentence-initial mentions such as
    'He' or 'She' are recognised as well.

    Args:
        word: a single word, typically the text of a coreference mention.

    Returns:
        'Male' for he/him/his/himself, 'Female' for she/her/hers/herself,
        and 'neutral' for anything else.
    """
    genders = {
        'he': 'Male',
        'him': 'Male',
        'his': 'Male',
        'himself': 'Male',
        'she': 'Female',
        'her': 'Female',
        'hers': 'Female',
        'herself': 'Female',
    }
    return genders.get(word.lower(), 'neutral')


# Show INFO-level progress messages (such as those logged by annotate()).
logging.basicConfig(level=logging.INFO)

# given an iterable of (key, value) pairs, return the key whose value is largest
def argmax(pairs):
    """Return the key (element 0) of the pair with the greatest value (element 1).

    On ties the first such pair wins; an empty iterable raises ValueError.
    """
    # Adapted from
    # https://stackoverflow.com/questions/5098580/implementing-argmax-in-python
    winner = max(pairs, key=lambda kv: kv[1])
    return winner[0]

# given an iterable of values, return the position of the largest one
def argmax_index(values):
    """Return the 0-based index of the greatest element of ``values``.

    On ties the first index wins; an empty iterable raises ValueError.
    """
    indexed = enumerate(values)
    best = max(indexed, key=lambda pair: pair[1])
    return best[0]

# given an iterable of keys and a function f, return the key maximizing f(key)
def argmax_f(keys, f):
    """Return the element of ``keys`` for which ``f`` is largest (first on ties)."""
    winner = max(keys, key=f)
    return winner

def similar(a, b):
    """Return difflib's similarity ratio (0.0 to 1.0) between sequences a and b."""
    # https://stackoverflow.com/questions/17388213/find-the-similarity-metric-between-two-strings
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()


def get_resolved_clusters(coref):
    """Return neuralcoref's clusters as lists of mention objects.

    ``coref.get_clusters()[0]`` is iterated as a mapping from cluster key to
    a list of mention indices. Each index is resolved through
    ``coref.get_mentions()``, so the result holds one list of mentions per
    cluster, in the mapping's iteration order.
    """
    mention_list = coref.get_mentions()
    cluster_map = coref.get_clusters()[0]
    return [[mention_list[idx] for idx in cluster_map[key]] for key in cluster_map]

def cluster_word(word, clusters, threshold=0.75):
    """Return the index of the coreference cluster that best matches ``word``.

    Each cluster is scored by the highest string similarity (see ``similar``)
    between the lower-cased ``word`` and the lower-cased ``.text`` of any of
    its mentions. The first cluster with the highest score wins.

    Args:
        word: the string to look up.
        clusters: iterable of clusters, each an iterable of mention objects
            exposing a ``.text`` attribute.
        threshold: the best score must be strictly greater than this value
            for a match to be reported (default 0.75).

    Returns:
        The index of the best-matching cluster, or -1 when there are no
        clusters or no cluster scores above ``threshold``.
    """
    target = word.lower()
    best_index, best_score = -1, None
    for index, cluster in enumerate(clusters):
        # An empty cluster scores 0.0 instead of making max() raise.
        score = max((similar(target, mention.text.lower()) for mention in cluster),
                    default=0.0)
        # Strict '>' keeps the first cluster on ties.
        if best_score is None or score > best_score:
            best_index, best_score = index, score

    if best_score is not None and best_score > threshold:
        return best_index
    return -1

def quotes2dict(text):
    """Replace each double-quoted span in ``text`` with a numbered tag.

    The text is scanned for straight double quotes ("). A single counter N
    starts at 0 and increases by one at every quote mark.

    * At an opening quote, the narrator text since the previous quote is
      stored under ``<nlineN>.``.
    * At a closing quote, the quoted text (without the marks) is stored under
      ``<clineN>.``. That tag then replaces the quoted span, marks included,
      in the returned text.

    Narrator text after the last quote stays in the returned text but is not
    stored in the dict. An unterminated final quote is left verbatim in the
    text and not stored either.

    Returns:
        (new_text, quote_dict)
    """
    pieces = []      # fragments of the rewritten text
    narrator = []    # characters outside quotes since the last quote mark
    quote = []       # characters of the currently open quote
    quote_dict = {}
    quote_no = 0
    is_open = False

    for c in text:
        if c != '"':
            (quote if is_open else narrator).append(c)
            continue

        if not is_open:
            narration = ''.join(narrator)
            quote_dict["<nline{}>.".format(quote_no)] = narration
            pieces.append(narration)
            narrator = []
        else:
            # Emit the tag at this exact position. A global str.replace would
            # map every identical quote onto the first tag.
            tag = "<cline{}>.".format(quote_no)
            quote_dict[tag] = ''.join(quote)
            pieces.append(tag)
            quote = []

        is_open = not is_open
        quote_no += 1

    pieces.append(''.join(narrator))
    if is_open:
        # Unterminated quote: keep it unchanged in the output text.
        pieces.append('"' + ''.join(quote))
    return ''.join(pieces), quote_dict

def figure_gender(word, clusters, character_lut):
    """Guess the gender of the character mention ``word``.

    1. If a character name from ``character_lut`` with an explicit 'Male' or
       'Female' gender occurs in ``word`` as a whole word (case-insensitive),
       return that gender.
    2. Otherwise, find the coreference cluster most similar to ``word`` (see
       ``cluster_word``). Infer the gender from the pronouns among its
       mentions: 'Male' if only male pronouns occur, 'Female' if only female
       ones occur.

    Args:
        word: the mention string, e.g. 'the old man'.
        clusters: list of clusters, each a list of mention objects with ``.text``.
        character_lut: dict mapping character names to attribute dicts that
            may contain a 'gender' key.

    Returns:
        'Male', 'Female' or 'neutral'.
    """
    lowered = word.lower()
    for name, attributes in character_lut.items():
        gender = attributes.get('gender')  # entries may lack a gender
        if gender not in ('Male', 'Female') or not name.strip():
            continue
        # Whole-word match, so 'man' does not match inside 'woman'.
        pattern = r'(?<!\w){}(?!\w)'.format(re.escape(name.lower()))
        if re.search(pattern, lowered):
            return gender

    if not clusters:
        return 'neutral'
    cluster_idx = cluster_word(word, clusters)
    if cluster_idx == -1:
        return 'neutral'

    genders = {pronoun2gender(mention.text.lower()) for mention in clusters[cluster_idx]}
    if 'Male' in genders and 'Female' not in genders:
        return 'Male'
    if 'Female' in genders and 'Male' not in genders:
        return 'Female'
    return 'neutral'

def _label_lut_matches(mDoc, entries, label):
    """Assign ``label`` to every occurrence of each LUT entry in ``mDoc.text``.

    Matching is case-insensitive, because entries are lower-cased while the
    text is not. Regex metacharacters in entries are escaped. A match must not
    start or end inside a word, so 'man' does not match inside 'woman'.

    Entries are processed in ascending order of word count, so longer entries
    are labelled last. For example, with both `man' and `an old man' in the
    LUT, `an old man' is labelled after, and presumably over, the included
    `man'.
    """
    lowered = [entry.strip().lower() for entry in entries if entry.strip()]
    for entry in sorted(lowered, key=lambda e: len(e.split())):
        pattern = r'(?<!\w){}(?!\w)'.format(re.escape(entry))
        for match in re.finditer(pattern, mDoc.text, flags=re.IGNORECASE):
            mDoc.assign_label_to_tokens(match.start(), match.end(), label)


def annotate(text,
             ner_model,
             rel_model,
             character_lut,
             saywords_lut,
             spind_lut,
             places_lut,
             do_coreference_resolution=True):
    """Annotate entities and relations in ``text``.

    Labels come from the LUTs first. Tokens still labelled 'O' then get the
    label predicted by ``ner_model``. Relations are predicted by ``rel_model``.

    Args:
        text: the raw story text.
        ner_model: model whose ``predict(features)`` returns one label
            sequence per sentence of the document.
        rel_model: relation model passed to ``mDoc.predict_relations``.
        character_lut: dict mapping character names to attribute dicts
            (optionally with a 'gender' key).
        saywords_lut: iterable of say-words, matched by lemma.
        spind_lut: iterable of spatial indicator phrases.
        places_lut: iterable of place names.
        do_coreference_resolution: if True, run neuralcoref, continue with the
            resolved text, and attach genders to Character tokens.

    Returns:
        (mDoc, quotes): the labelled ``ner.Document``, and the dict from
        ``quotes2dict`` mapping <clineN>./<nlineN>. tags to their text.
    """
    # Find and store character lines in a dictionary
    logging.info('Swapping character lines for character line tags')
    processed_text, quotes = quotes2dict(text)

    logging.info("Loading 'en' spacy model")
    nlp = spacy.load('en')

    resolved_clusters = []
    if do_coreference_resolution:
        # The coreference model is only needed for this step, so only load it here.
        coref = Coref()
        logging.info("Doing one-shot coreference resolution (this might take some time)")
        coref.one_shot_coref(processed_text)
        resolved_clusters = get_resolved_clusters(coref)
        processed_text = coref.get_resolved_utterances()[0]

    logging.info("Parsing document to spacy")
    doc = nlp(processed_text)

    logging.info("Parsing document to our object format for Named Entity Recognition")
    mDoc = ner.Document(doc)

    # Label <clineN>. tags as character lines
    logging.info("Labeling character lines")
    for match in re.finditer(r'<cline[0-9]+>\.', mDoc.text):
        mDoc.assign_label_to_tokens(match.start(), match.end(), 'Character_Line')

    logging.info("Labeling characters from LUT")
    _label_lut_matches(mDoc, character_lut, 'Character')

    # Saywords are single tokens and are matched by lemma, not surface form.
    logging.info("Labeling saywords from LUT")
    swLUT = [nlp(sw)[0].lemma_ for sw in saywords_lut]
    for sw in swLUT:
        mDoc.assign_label_to_tokens_by_matching_lemma(sw, 'Says')

    logging.info("Labeling places from LUT")
    _label_lut_matches(mDoc, places_lut, 'Place')

    logging.info("Labeling spatial indicators from LUT")
    _label_lut_matches(mDoc, spind_lut, 'Spatial_Signal')

    logging.info("Extracting token features")
    features, _ = mDoc.get_token_features_labels()

    logging.info("Predicting labels")
    new_labels = ner_model.predict(features)

    # If a label is not already assigned by a LUT, assign it using the model
    logging.info("Assigning labels based on the NER model")
    for m, sent in enumerate(mDoc.token_sentences):
        for n, token in enumerate(sent):
            if token.label == 'O':
                token.label = new_labels[m][n]

    # Attach gender attributes to character tokens (needs coreference clusters)
    if do_coreference_resolution:
        logging.info('Figuring out character genders')
        for sent in mDoc.get_tokens_with_label('Character'):
            for character in sent:
                raw_string = " ".join(c.text for c in character)
                gender = figure_gender(raw_string, resolved_clusters, character_lut)
                if gender in ['Male', 'Female']:
                    for tok in character:
                        tok.set_attribute('gender', gender)

    logging.info('Predicting the correct label for all possible relations in Document')
    mDoc.predict_relations(rel_model)

    return mDoc, quotes


def doc2brat(mDoc):
    """Return the brat standoff (.ann) annotation of ``mDoc`` as a string.

    The output contains:

    * one text-bound line "<T-var><TAB><Label> <start> <end><TAB><text>" for
      every token group labelled Character, Says, Place, Spatial_Signal or
      Character_Line, grouped by label in that order;
    * right after its entity, an attribute line
      "<A-var><TAB>Gender <T-var> <gender>" for each group whose first token
      carries a 'gender' attribute;
    * one event line
      "<E-var><TAB><Trigger>:<T-var> <Role1>:<T-var> <Role2>:<T-var>" per
      relation in ``mDoc.relations``. Spatial_Signal triggers take
      Trajector/Landmark arguments, and Says triggers take WHO/WHAT.

    Relations with any other trigger label are skipped with a warning.
    Variable names come from ``ner.var_generator`` with prefixes T, A and E.
    """
    # Argument role names for each supported trigger label.
    arg_roles = {
        'Spatial_Signal': ('Trajector', 'Landmark'),
        'Says': ('WHO', 'WHAT'),
    }

    # Maps (start, end) text span -> entity variable, used for relations.
    span2var = {}

    tvar = ner.var_generator('T')  # entities (T in brat format)
    rvar = ner.var_generator('E')  # relations/events (E in brat format)
    avar = ner.var_generator('A')  # attributes (A in brat format)

    lines = []

    labels = ['Character', 'Says', 'Place', 'Spatial_Signal', 'Character_Line']
    for label in labels:
        for tlist in mDoc.get_tokens_with_label(label):
            for tokens in tlist:
                start = tokens[0].start
                end = tokens[-1].end
                var = next(tvar)
                lines.append("{}\t{} {} {}\t{}\n".format(var, label, start, end,
                                                         mDoc.text[start:end]))
                if 'gender' in tokens[0].attributes:
                    lines.append("{}\t{} {} {}\n".format(next(avar), 'Gender', var,
                                                         tokens[0].attributes['gender']))
                span2var[(start, end)] = var

    for r in mDoc.relations:
        trigger = r.trigger
        # presumably labels carry a 2-char BIO prefix such as 'B-' — TODO confirm
        trigger_label = trigger[0].label[2:]
        if trigger_label not in arg_roles:
            # Without known roles the event line cannot be written correctly.
            logging.warning('Skipping relation with unsupported trigger label %r',
                            trigger_label)
            continue
        arg1_label, arg2_label = arg_roles[trigger_label]

        trigger_var = span2var[(trigger[0].start, trigger[-1].end)]
        arg1_var = span2var[(r.arg1[0].start, r.arg1[-1].end)]
        arg2_var = span2var[(r.arg2[0].start, r.arg2[-1].end)]

        lines.append("{}\t{}:{} {}:{} {}:{}\n".format(next(rvar),
                                                      trigger_label,
                                                      trigger_var,
                                                      arg1_label,
                                                      arg1_var,
                                                      arg2_label,
                                                      arg2_var))

    return ''.join(lines)

if __name__=="__main__":
    def _read_lut(path):
        """Return the non-blank lines of the text file at ``path``."""
        with open(path) as f:
            return [s for s in f.read().split('\n') if s.strip() != '']

    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_path', help='.txt file to parse')
    argparser.add_argument('ner_model_path', help='.pkl file containing NER model')
    argparser.add_argument('rel_model_path', help='.pkl file containing relational model')
    argparser.add_argument('--say-lut', help='.txt file with list of saywords')
    argparser.add_argument('--char-lut', help='.txt file with known characters')
    argparser.add_argument('--place-lut', help='.txt file with known places')
    argparser.add_argument('--spatial-indicator-lut', help='.txt file with known spatial indicators')
    # nargs='?' with const=True lets `--force` work as a plain flag while still
    # accepting the older `--force <value>` form.
    argparser.add_argument('--force', nargs='?', const=True, default=False,
                           help='force overwrite when there is a file to be overwritten')
    argparser.add_argument('--no-coreference-resolution', action='store_true', help='omit coreference resolution step')

    args = argparser.parse_args()

    # Load the story, collapsing every whitespace run (including newlines) to one space.
    with open(args.input_path) as f:
        text = " ".join(f.read().split())

    # Outputs go next to the input: story.txt -> story_processed.txt, etc.
    base_path = os.path.splitext(args.input_path)[0]
    output_text_path = base_path + '_processed.txt'
    output_quotes_path = base_path + '_quotes.json'
    output_annotation_path = base_path + '_processed.ann'

    # NOTE(review): joblib.load unpickles its input; only load model files you trust.
    ner_model = joblib.load(args.ner_model_path)
    rel_model = joblib.load(args.rel_model_path)

    # LUTs: one entry per line, with defaults in the current directory.
    saylut = _read_lut(args.say_lut or 'saywords.txt')
    placelut = _read_lut(args.place_lut or 'places.txt')
    spatial_indicator_lut = _read_lut(args.spatial_indicator_lut or 'spatial_indicators.txt')

    # Character LUT: one character per line, as
    #   name:attribute1,attribute2,...
    # where attributes may include a gender (male/female) and an age (young/old).
    character_lut = {}  # Stores character attributes indexed by name
    for line in _read_lut(args.char_lut or 'characters.txt'):
        name, _, attributes = line.partition(':')
        name = name.strip()
        if not name:
            continue

        gender = None
        age = None
        for attribute in attributes.split(','):
            attribute = attribute.strip()
            lowered = attribute.lower()
            if 'male' in lowered:
                # Normalise to the 'Male'/'Female' spelling figure_gender expects.
                gender = 'Female' if 'female' in lowered else 'Male'
            elif lowered in ('young', 'old'):
                age = attribute

        character_lut[name] = {}
        if gender:
            character_lut[name]['gender'] = gender
        if age:
            character_lut[name]['age'] = age

    corefres = not args.no_coreference_resolution
    mDoc, quotes = annotate(text, ner_model, rel_model, character_lut, saylut, spatial_indicator_lut, placelut, corefres)

    annotation_text = doc2brat(mDoc)

    to_save = {
            output_text_path: mDoc.text,
            output_quotes_path: json.dumps(quotes),
            output_annotation_path: annotation_text
            }

    for path, content in to_save.items():
        if os.path.exists(path) and not args.force:
            answer = input('Path {} exists, overwrite? (y/N) '.format(path))
            # Only an answer starting with y/Y overwrites; empty input means No.
            if not answer.strip().lower().startswith('y'):
                continue
        with open(path, 'w') as f:
            f.write(content)