#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 30 14:28:49 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and a .ann annotation file and trains two models.

@output:
    ner_model.pkl -- named entity recognition (NER) model
    rel_model.pkl -- relation extraction (REL) model
"""

import os
import argparse
import logging
import pickle

import spacy
import ner
import pypeg2 as pg
import sklearn_crfsuite as crf

# Relation model
from rel import RelModel

logging.basicConfig(level=logging.INFO)

# https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
flatten = lambda l: [item for sublist in l for item in sublist]


def quotes2dict(text):
    """Replace every double-quoted span in `text` with a numbered
    placeholder and return the new text together with a dictionary
    mapping each placeholder to the original segment. Keys alternate
    between narration (stored when a quote opens) and quote content;
    narration after the last quote is kept only in the returned text.
    """
    new_text = text
    is_open = False

    quote_no = 0
    quote = []
    narrator = []
    quote_dict = {}

    for c in text:
        if c == '"' and not is_open:
            # A quote opens: store the narration seen so far.
            is_open = True
            quote_dict["{}".format(quote_no)] = ''.join(narrator)
            narrator = []
            quote_no += 1
            continue
        elif c == '"' and is_open:
            # A quote closes: store its content and replace it in the
            # text with its numbered placeholder.
            is_open = False
            quote_dict["{}".format(quote_no)] = ''.join(quote)
            new_text = new_text.replace('"' + ''.join(quote) + '"',
                                        "{}".format(quote_no))
            quote = []
            quote_no += 1
            continue

        if is_open:
            quote.append(c)
        else:
            narrator.append(c)

    return new_text, quote_dict
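# A quick illustration of quotes2dict (a sketch: the numeric placeholder
# scheme "{}".format(quote_no) is reconstructed from a garbled format
# string, so the exact keys/placeholders may differ in the original):
#
#     >>> quotes2dict('He said "hi" and left.')
#     ('He said 1 and left.', {'0': 'He said ', '1': 'hi'})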
def annotation2doc(text, annotation):

    # Load language engine ('en' is the spaCy 2.x model shortcut)
    logging.info('Loading language engine')
    nlp = spacy.load('en')

    # Convert to spacy document type
    logging.info('Parsing to spacy document')
    doc = nlp(text)

    # Convert to ner.Document
    logging.info('Converting to custom Document format')
    mDoc = ner.Document(doc)

    # Parse annotation
    logging.info('Parsing annotation')
    parsed = pg.parse(annotation, ner.AnnotationFile)

    # Keep a dictionary of entities, since relations point to them.
    dictionary = {}

    # Visit the parsed lines in two passes: first entities, then
    # relations. The reason is that a relation can refer to an entity
    # that has not been defined yet.

    for line in parsed:
        # Every annotation line holds a single object.
        obj = line[0]

        if isinstance(obj, ner.AnnotationTuple):
            # For an entity tuple, find the start and end offsets and
            # assign the annotated label to the covered tokens.
            start_s, end_s = obj.idx.split()
            start = int(start_s)
            end = int(end_s)
            label = str(obj.type)

            # Remember which tokens this annotation variable refers to.
            dictionary[obj.variable] = mDoc.find_tokens(start, end)

            mDoc.assign_label_to_tokens(start, end, label)

    for line in parsed:
        obj = line[0]

        if isinstance(obj, ner.RelationTuple):
            # Relations have a trigger, a first argument `arg1` and a
            # second argument `arg2`. For each trigger, |arg1| * |arg2|
            # relations are constructed, where |arg1| and |arg2| are the
            # numbers of candidates for the first and second argument.

            arg1_candidates = []
            arg2_candidates = []

            # Check the relation's arguments. This assumes every
            # relation carries a 'Says' or 'Spatial_Signal' trigger
            # argument, which binds `trigger` and `label`.
            for arg in obj.args:
                if arg.label == 'Says':
                    trigger = dictionary[arg.target]
                    label = 'Quote'
                elif arg.label == 'Spatial_Signal':
                    trigger = dictionary[arg.target]
                    label = 'Spatial_Relation'
                if arg.label in ['Trajector', 'WHO']:
                    arg1_candidates.append(dictionary[arg.target])
                if arg.label in ['Landmark', 'WHAT']:
                    arg2_candidates.append(dictionary[arg.target])

            for arg1 in arg1_candidates:
                for arg2 in arg2_candidates:
                    mDoc.add_relation(trigger, arg1, arg2, label)

    # Create NER model
    logging.info('Creating NER CRF model')
    ner_model = crf.CRF(c1=0.1,
                        c2=0.1,
                        max_iterations=100,
                        all_possible_transitions=True)

    logging.info('Extracting features/labels from document')
    features, labels = mDoc.get_token_features_labels()

    logging.info('Fitting NER model')
    ner_model.fit(features, labels)

    # Create relation extraction model
    logging.info('Creating REL SVM model')
    rel_model = RelModel()

    logging.info('Extracting relation features/labels from document')
    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()

    logging.info('Fitting REL model')
    rel_model.fit(rel_features, rel_labels)

    return mDoc, ner_model, rel_model


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_text_path',
                           help='.txt file of input')
    argparser.add_argument('input_annotation_path',
                           help='.ann file of annotation')
    argparser.add_argument('--output-dir',
                           help='directory to save model files (default `.`)')

    args = argparser.parse_args()

    # Load text and annotation contents
    with open(args.input_text_path) as f:
        text = f.read()

    with open(args.input_annotation_path) as f:
        annotation = f.read()

    mDoc, ner_model, rel_model = annotation2doc(text, annotation)

    if args.output_dir:
        output_dir = args.output_dir
    else:
        output_dir = os.path.curdir

    ner_model_path = os.path.join(output_dir, 'ner_model.pkl')
    rel_model_path = os.path.join(output_dir, 'rel_model.pkl')

    logging.info('Saving NER model to {}'.format(ner_model_path))
    with open(ner_model_path, 'wb') as f:
        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)

    logging.info('Saving REL model to {}'.format(rel_model_path))
    with open(rel_model_path, 'wb') as f:
        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)
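# Loading the pickled models back for inference (a minimal sketch; it
# assumes features are produced the same way ner.Document produces them
# above, and note that unpickling RelModel requires the `rel` module to
# be importable):
#
#     with open('ner_model.pkl', 'rb') as f:
#         ner_model = pickle.load(f)
#     with open('rel_model.pkl', 'rb') as f:
#         rel_model = pickle.load(f)
#
#     # X: list of token-feature sequences, as in get_token_features_labels()
#     predicted_labels = ner_model.predict(X)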