diff demo/annotation2model.py @ 0:4dad87badb0c

initial commit
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Wed, 16 May 2018 17:56:10 +0100
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/demo/annotation2model.py	Wed May 16 17:56:10 2018 +0100
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Apr 30 14:28:49 2018
+
+@author: Emmanouil Theofanis Chourdakis
+
+Takes a .txt story and an .ann annotation and trains a model.
+
+@output:
+    ner .pkl model -- NER recognition model
+    rel .pkl model -- RELation extraction model
+
+"""
+
+import os
+
+import argparse
+import logging
+import spacy
+import ner
+import pypeg2 as pg
+import sklearn_crfsuite as crf
+import pickle
+
+logging.basicConfig(level=logging.INFO)
+
+# https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
+flatten = lambda l: [item for sublist in l for item in sublist]
+
+# Relation Model 
+
+from rel import *
+
+def quotes2dict(text):
+    new_text = text
+    is_open = False
+    
+    quote_no = 0
+    quote = []
+    narrator = []
+    quote_dict = {}
+    
+    for n, c in enumerate(text):
+        if c == '"' and not is_open:
+            is_open = True
+            quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
+            narrator = []
+            quote_no += 1
+            continue
+        
+        elif c == '"' and is_open:
+            is_open = False
+            quote_dict["<cline{}>.".format(quote_no)] = ''.join(quote)
+            new_text = new_text.replace('"'+''.join(quote)+'"', "<cline{}>.".format(quote_no))
+            quote = []
+            quote_no += 1
+            continue
+        
+        if is_open:
+            quote.append(c)
+        elif not is_open:
+            narrator.append(c)
+            
+    return new_text, quote_dict
+
+def annotation2doc(text, annotation):
+    
+    # Load language engine
+    logging.info('Loading language engine')
+    nlp = spacy.load('en')
+    
+    # Convert to spacy document type
+    logging.info('Parsing to spacy document')
+    doc = nlp(text)
+    
+    # Convert to ner.Document
+    logging.info('Converting to custom Document format')
+    mDoc = ner.Document(doc)
+    
+    # Parsing annotation
+    logging.info('Parsing annotation')
+    parsed = pg.parse(annotation, ner.AnnotationFile)
+    
+    # Store an entity and relations dictionary since relations
+    # point to such entities
+    
+    dictionary = {}
+    
+    # Visit all the parsed lines. Do it in two passes, first parse 
+    # entities and then relations. The reason for that is that some times
+    # a relation refers to an entity that has not been defined.
+    
+    for line in parsed:
+        # Every annotation line has a single object
+        obj = line[0]
+        
+        if isinstance(obj, ner.AnnotationTuple):
+
+            # If it is a tuple, find the start and end 
+            # borders, and assign them the appropriate label
+            
+            start_s, end_s = obj.idx.split()
+            start = int(start_s)
+            end   = int(end_s)
+            label = str(obj.type)
+            
+            # Store to dictionary the string relating 
+            # to the annotation 
+            
+            dictionary[obj.variable] = mDoc.find_tokens(start, end)            
+            
+            mDoc.assign_label_to_tokens(start, end, label)
+            
+    for line in parsed:
+        # Every annotation line has a single object
+        obj = line[0]            
+            
+        if isinstance(obj, ner.RelationTuple):
+            
+            # Relations have a trigger, a first argument `arg1' and a 
+            # second argument `arg2'. There are going to be 
+            # |arg1| * |arg2| relations constructed for each trigger
+            # where |arg1| is the number of candidates for argument 1
+            # and |arg2| the number of candidates for argument 2
+            
+            arg1_candidates = []
+            arg2_candidates = []
+            
+            # Check relation's arguments:            
+            for arg in obj.args:
+                if arg.label == 'Says':
+                    trigger = dictionary[arg.target]
+                    label = 'Quote'
+                elif arg.label == 'Spatial_Signal':
+                    trigger = dictionary[arg.target]
+                    label = 'Spatial_Relation'
+                if arg.label in ['Trajector', 'WHO']:
+                    arg1_candidates.append(dictionary[arg.target])
+                if arg.label in ['Landmark', 'WHAT']:
+                    arg2_candidates.append(dictionary[arg.target])
+                    
+            for arg1 in arg1_candidates:
+                for arg2 in arg2_candidates:
+                    mDoc.add_relation(trigger, arg1, arg2, label)
+                    
+    # Create NER model
+    logging.info('Creating NER CRF model')
+    
+    ner_model = crf.CRF(c1=0.1, 
+                        c2=0.1, 
+                        max_iterations=100,
+                        all_possible_transitions=True)
+    
+    logging.info('Extracting features/labels from document')
+    features, labels = mDoc.get_token_features_labels()
+    
+    logging.info('Fitting NER model')
+    ner_model.fit(features, labels)
+    
+    # Create Relational model
+    logging.info('Creating REL SVM model')
+    rel_model = RelModel()
+    
+    logging.info('Extracting relations features/labels from document')
+    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()
+    
+    logging.info('Fitting REL model')
+    rel_model.fit(rel_features, rel_labels)
+
+    return mDoc, ner_model, rel_model
+    
+
+if __name__ == "__main__":
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument('input_text_path', 
+                           help='.txt file of input')
+    argparser.add_argument('input_annotation_path', 
+                           help='.ann file of annotation')
+    argparser.add_argument('--output-dir', 
+                           help='directory to save model files (default `.`')
+    
+    args = argparser.parse_args()
+    
+    # Load text and annotation contents
+    with open(args.input_text_path) as f:
+        text = f.read()
+        
+    with open(args.input_annotation_path) as f:
+        annotation = f.read()
+        
+        
+    mDoc, ner_model, rel_model = annotation2doc(text, annotation)   
+    
+    if args.output_dir:
+        output_dir = args.output_dir
+    else:
+        output_dir = os.path.curdir
+        
+    ner_model_path = os.path.join(output_dir, 'ner_model.pkl')        
+    rel_model_path = os.path.join(output_dir, 'rel_model.pkl')        
+        
+    logging.info('Saving NER model to {}'.format(ner_model_path))
+    with open(ner_model_path, 'wb') as f:
+        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)
+    
+    logging.info('Saving REL model to {}'.format(rel_model_path))
+    with open(rel_model_path, 'wb') as f:
+        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)
+    
+    
+
+        
+    
+
+