view demo/annotation2model.py @ 0:90155bdd5dd6

first commit
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Wed, 16 May 2018 18:27:05 +0100
parents
children
line wrap: on
line source
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 30 14:28:49 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and its .ann annotation file and trains two models on them.

@output:
    ner_model.pkl -- named-entity recognition (NER) model
    rel_model.pkl -- relation extraction (REL) model

"""

import os

import argparse
import logging
import spacy
import ner
import pypeg2 as pg
import sklearn_crfsuite as crf
import pickle

logging.basicConfig(level=logging.INFO)


def flatten(l):
    """Flatten one level of nesting, e.g. ``[[1, 2], [3]] -> [1, 2, 3]``.

    Adapted from
    https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python

    Args:
        l: an iterable of iterables.

    Returns:
        A new list containing the items of every sub-iterable, in order.
    """
    return [item for sublist in l for item in sublist]

# Relation Model 

from rel import *

def quotes2dict(text):
    """Replace double-quoted passages in *text* with numbered placeholder tags.

    The text is scanned character by character. Every ``"`` toggles between
    narration mode and quotation mode. A single counter numbers both kinds
    of segment, so the tags alternate ``<nline0>.``, ``<cline1>.``,
    ``<nline2>.``, ``<cline3>.``, ...

    Args:
        text: the raw story text.

    Returns:
        A ``(new_text, quote_dict)`` tuple:

        * ``new_text`` is *text* with each closed quotation, including its
          surrounding ``"`` characters, replaced in place by its own
          ``<clineN>.`` tag. Narration is left untouched. An unclosed
          trailing quotation is copied through verbatim.
        * ``quote_dict`` maps each ``<nlineN>.`` tag to the narration that
          preceded an opening quote. It maps each ``<clineN>.`` tag to the
          text of a closed quotation, without the quote characters.

    NOTE(review): narration after the last quotation is not stored in
    ``quote_dict``, as in the original implementation. This is possibly
    unintended; confirm with callers.
    """
    pieces = []      # fragments that are joined into new_text
    quote = []       # characters of the currently open quotation
    narrator = []    # narration characters since the previous quotation
    quote_dict = {}
    quote_no = 0
    is_open = False

    for c in text:
        if c == '"':
            if not is_open:
                quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
                narrator = []
            else:
                tag = "<cline{}>.".format(quote_no)
                quote_dict[tag] = ''.join(quote)
                # Emit the tag at this quotation's own position. A global
                # str.replace() would also rewrite every later identical
                # quotation with this tag.
                pieces.append(tag)
                quote = []
            is_open = not is_open
            quote_no += 1
        elif is_open:
            quote.append(c)
        else:
            narrator.append(c)
            pieces.append(c)

    if is_open:
        # Unclosed quotation: keep the opening quote and its text unchanged.
        pieces.append('"')
        pieces.extend(quote)

    return ''.join(pieces), quote_dict

def annotation2doc(text, annotation):
    """Build a labelled document from *text* and *annotation*, then train models on it.

    Steps:
      1. Parse *text* with spaCy and wrap it in a ``ner.Document``.
      2. Parse *annotation* with the ``ner.AnnotationFile`` grammar.
      3. Label tokens from every ``ner.AnnotationTuple`` (entity) line.
      4. Add relations from every ``ner.RelationTuple`` line.
      5. Fit a CRF NER model and a ``RelModel`` relation model on the document.

    Args:
        text: the raw story text.
        annotation: contents of the .ann file for *text*. Its start/end
            offsets are assumed to index into *text*.

    Returns:
        ``(mDoc, ner_model, rel_model)``: the labelled ``ner.Document``, the
        fitted ``sklearn_crfsuite.CRF`` model and the fitted ``RelModel``.
    """
    # Load language engine.
    # NOTE(review): the 'en' shortcut only exists in spaCy < 3; newer spaCy
    # needs a package name such as 'en_core_web_sm'.
    logging.info('Loading language engine')
    nlp = spacy.load('en')

    # Convert to spacy document type
    logging.info('Parsing to spacy document')
    doc = nlp(text)

    # Convert to ner.Document
    logging.info('Converting to custom Document format')
    mDoc = ner.Document(doc)

    # Parsing annotation
    logging.info('Parsing annotation')
    parsed = pg.parse(annotation, ner.AnnotationFile)

    # Maps each entity's annotation variable to the tokens it covers.
    # Relations refer to entities through these variables.
    dictionary = {}

    # Two passes over the parsed lines: entities first, then relations. A
    # relation may refer to an entity that is only defined further down.
    for line in parsed:
        obj = line[0]  # every annotation line holds a single object
        if isinstance(obj, ner.AnnotationTuple):
            start_s, end_s = obj.idx.split()
            start, end = int(start_s), int(end_s)
            dictionary[obj.variable] = mDoc.find_tokens(start, end)
            mDoc.assign_label_to_tokens(start, end, str(obj.type))

    for line in parsed:
        obj = line[0]
        if not isinstance(obj, ner.RelationTuple):
            continue

        # A relation has a trigger, a first argument `arg1' and a second
        # argument `arg2'. |arg1| * |arg2| relations are added for the
        # trigger, one per pair of candidates.
        # trigger/rel_label are reset for every relation, so a relation
        # without a trigger argument can no longer reuse the previous
        # relation's trigger, or an entity label left from the first pass.
        trigger = None
        rel_label = None
        arg1_candidates = []
        arg2_candidates = []

        for arg in obj.args:
            if arg.label == 'Says':
                trigger = dictionary[arg.target]
                rel_label = 'Quote'
            elif arg.label == 'Spatial_Signal':
                trigger = dictionary[arg.target]
                rel_label = 'Spatial_Relation'
            if arg.label in ('Trajector', 'WHO'):
                arg1_candidates.append(dictionary[arg.target])
            if arg.label in ('Landmark', 'WHAT'):
                arg2_candidates.append(dictionary[arg.target])

        if rel_label is None:
            logging.warning('Skipping relation without a trigger argument: %s', obj)
            continue

        for arg1 in arg1_candidates:
            for arg2 in arg2_candidates:
                mDoc.add_relation(trigger, arg1, arg2, rel_label)

    # Create NER model
    logging.info('Creating NER CRF model')
    ner_model = crf.CRF(c1=0.1,
                        c2=0.1,
                        max_iterations=100,
                        all_possible_transitions=True)

    logging.info('Extracting features/labels from document')
    features, labels = mDoc.get_token_features_labels()

    logging.info('Fitting NER model')
    ner_model.fit(features, labels)

    # Create Relational model
    logging.info('Creating REL SVM model')
    rel_model = RelModel()

    logging.info('Extracting relations features/labels from document')
    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()

    logging.info('Fitting REL model')
    rel_model.fit(rel_features, rel_labels)

    return mDoc, ner_model, rel_model
    

if __name__ == "__main__":
    # Command-line entry point: train both models and pickle them to disk.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_text_path',
                           help='.txt file of input')
    argparser.add_argument('input_annotation_path',
                           help='.ann file of annotation')
    argparser.add_argument('--output-dir',
                           help='directory to save model files (default: `.`)')

    args = argparser.parse_args()

    # Load text and annotation contents. Decode explicitly as UTF-8 so the
    # result does not depend on the platform locale; the annotation offsets
    # must line up with the decoded text. NOTE(review): assumes UTF-8 input.
    with open(args.input_text_path, encoding='utf-8') as f:
        text = f.read()

    with open(args.input_annotation_path, encoding='utf-8') as f:
        annotation = f.read()

    _doc, ner_model, rel_model = annotation2doc(text, annotation)

    # An empty or missing --output-dir falls back to the current directory.
    output_dir = args.output_dir or os.path.curdir
    # Create the directory if needed; otherwise open(..., 'wb') below fails.
    os.makedirs(output_dir, exist_ok=True)

    ner_model_path = os.path.join(output_dir, 'ner_model.pkl')
    rel_model_path = os.path.join(output_dir, 'rel_model.pkl')

    # NOTE: pickle files execute code on load; only load these models from
    # trusted locations.
    logging.info('Saving NER model to %s', ner_model_path)
    with open(ner_model_path, 'wb') as f:
        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)

    logging.info('Saving REL model to %s', rel_model_path)
    with open(rel_model_path, 'wb') as f:
        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)