Mercurial > hg > from-my-pen-to-your-ears-supplementary-material
diff demo/annotation2model.py @ 0:4dad87badb0c
initial commit
author:   Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date:     Wed, 16 May 2018 17:56:10 +0100
parents:  (none)
children: (none)
line wrap: on
line diff
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 30 14:28:49 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and an .ann annotation and trains a model.

@output:
    ner .pkl model -- NER recognition model
    rel .pkl model -- RELation extraction model

"""

import os

import argparse
import logging
import spacy
import ner
import pypeg2 as pg
import sklearn_crfsuite as crf
import pickle

logging.basicConfig(level=logging.INFO)


# https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
def flatten(l):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return [item for sublist in l for item in sublist]


# Relation Model
from rel import *


def quotes2dict(text):
    """Replace double-quoted spans in *text* with placeholder tokens.

    Scans *text* character by character.  Each time a quote opens, the
    narration accumulated so far is recorded under a ``<nlineK>.`` key;
    each time a quote closes, the quoted span is recorded under a
    ``<clineK>.`` key and the quoted text (including the surrounding
    ``"`` characters) is replaced in the returned text by that
    placeholder.  ``K`` is a running counter shared by both key kinds.

    NOTE(review): ``str.replace`` substitutes *every* occurrence of the
    quoted span, so two identical quotes collapse onto the first
    placeholder — presumably acceptable for this corpus; verify.

    Returns:
        (new_text, quote_dict): the text with quotes replaced, and the
        mapping from placeholder keys to the original narration/quote
        strings.
    """
    new_text = text
    is_open = False

    quote_no = 0
    quote = []
    narrator = []
    quote_dict = {}

    for n, c in enumerate(text):
        if c == '"' and not is_open:
            # Opening quote: flush accumulated narration first.
            is_open = True
            quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
            narrator = []
            quote_no += 1
            continue

        elif c == '"' and is_open:
            # Closing quote: record the quote and substitute its
            # placeholder into the output text.
            is_open = False
            quote_dict["<cline{}>.".format(quote_no)] = ''.join(quote)
            new_text = new_text.replace('"' + ''.join(quote) + '"',
                                        "<cline{}>.".format(quote_no))
            quote = []
            quote_no += 1
            continue

        if is_open:
            quote.append(c)
        else:
            narrator.append(c)

    return new_text, quote_dict


def annotation2doc(text, annotation):
    """Train NER and relation-extraction models from an annotated story.

    Parameters:
        text: raw story text (contents of the .txt file).
        annotation: brat-style annotation (contents of the .ann file),
            parsed with the ``ner.AnnotationFile`` pypeg2 grammar.

    Returns:
        (mDoc, ner_model, rel_model): the labelled ``ner.Document``, the
        fitted ``sklearn_crfsuite.CRF`` NER model, and the fitted
        ``RelModel`` relation model.
    """
    # Load language engine
    logging.info('Loading language engine')
    nlp = spacy.load('en')

    # Convert to spacy document type
    logging.info('Parsing to spacy document')
    doc = nlp(text)

    # Convert to ner.Document
    logging.info('Converting to custom Document format')
    mDoc = ner.Document(doc)

    # Parsing annotation
    logging.info('Parsing annotation')
    parsed = pg.parse(annotation, ner.AnnotationFile)

    # Store an entity and relations dictionary since relations
    # point to such entities
    dictionary = {}

    # Visit all the parsed lines.  Do it in two passes, first parse
    # entities and then relations.  The reason for that is that some times
    # a relation refers to an entity that has not been defined.

    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]

        if isinstance(obj, ner.AnnotationTuple):
            # If it is a tuple, find the start and end
            # borders, and assign them the appropriate label
            start_s, end_s = obj.idx.split()
            start = int(start_s)
            end = int(end_s)
            label = str(obj.type)

            # Store to dictionary the string relating
            # to the annotation
            dictionary[obj.variable] = mDoc.find_tokens(start, end)

            mDoc.assign_label_to_tokens(start, end, label)

    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]

        if isinstance(obj, ner.RelationTuple):
            # Relations have a trigger, a first argument `arg1' and a
            # second argument `arg2'.  There are going to be
            # |arg1| * |arg2| relations constructed for each trigger
            # where |arg1| is the number of candidates for argument 1
            # and |arg2| the number of candidates for argument 2

            # BUGFIX: trigger/label were previously unbound here, so a
            # relation without a Says/Spatial_Signal argument either
            # raised NameError or silently reused the trigger of the
            # previous relation.  Initialize per relation and skip
            # relations that have no recognized trigger.
            trigger = None
            label = None
            arg1_candidates = []
            arg2_candidates = []

            # Check relation's arguments:
            for arg in obj.args:
                if arg.label == 'Says':
                    trigger = dictionary[arg.target]
                    label = 'Quote'
                elif arg.label == 'Spatial_Signal':
                    trigger = dictionary[arg.target]
                    label = 'Spatial_Relation'
                if arg.label in ['Trajector', 'WHO']:
                    arg1_candidates.append(dictionary[arg.target])
                if arg.label in ['Landmark', 'WHAT']:
                    arg2_candidates.append(dictionary[arg.target])

            if trigger is None:
                logging.warning(
                    'Skipping relation with no Says/Spatial_Signal trigger')
                continue

            for arg1 in arg1_candidates:
                for arg2 in arg2_candidates:
                    mDoc.add_relation(trigger, arg1, arg2, label)

    # Create NER model
    logging.info('Creating NER CRF model')
    ner_model = crf.CRF(c1=0.1,
                        c2=0.1,
                        max_iterations=100,
                        all_possible_transitions=True)

    logging.info('Extracting features/labels from document')
    features, labels = mDoc.get_token_features_labels()

    logging.info('Fitting NER model')
    ner_model.fit(features, labels)

    # Create Relational model
    logging.info('Creating REL SVM model')
    rel_model = RelModel()

    logging.info('Extracting relations features/labels from document')
    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()

    logging.info('Fitting REL model')
    rel_model.fit(rel_features, rel_labels)

    return mDoc, ner_model, rel_model


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_text_path',
                           help='.txt file of input')
    argparser.add_argument('input_annotation_path',
                           help='.ann file of annotation')
    argparser.add_argument('--output-dir',
                           help='directory to save model files (default `.`)')

    args = argparser.parse_args()

    # Load text and annotation contents
    with open(args.input_text_path) as f:
        text = f.read()

    with open(args.input_annotation_path) as f:
        annotation = f.read()

    mDoc, ner_model, rel_model = annotation2doc(text, annotation)

    if args.output_dir:
        output_dir = args.output_dir
    else:
        output_dir = os.path.curdir

    ner_model_path = os.path.join(output_dir, 'ner_model.pkl')
    rel_model_path = os.path.join(output_dir, 'rel_model.pkl')

    logging.info('Saving NER model to {}'.format(ner_model_path))
    with open(ner_model_path, 'wb') as f:
        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)

    logging.info('Saving REL model to {}'.format(rel_model_path))
    with open(rel_model_path, 'wb') as f:
        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)