annotate demo/annotation2model.py @ 13:16066f0a7127 tip

fixed the problem with brat
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Sat, 08 Dec 2018 11:02:40 +0000
parents 90155bdd5dd6
children
rev   line source
e@0 1 #!/usr/bin/env python3
e@0 2 # -*- coding: utf-8 -*-
e@0 3 """
e@0 4 Created on Mon Apr 30 14:28:49 2018
e@0 5
e@0 6 @author: Emmanouil Theofanis Chourdakis
e@0 7
e@0 8 Takes a .txt story and its .ann annotation and trains two models (NER and relation extraction).
e@0 9
e@0 10 @output:
e@0 11 ner .pkl model -- NER recognition model
e@0 12 rel .pkl model -- RELation extraction model
e@0 13
e@0 14 """
e@0 15
e@0 16 import os
e@0 17
e@0 18 import argparse
e@0 19 import logging
e@0 20 import spacy
e@0 21 import ner
e@0 22 import pypeg2 as pg
e@0 23 import sklearn_crfsuite as crf
e@0 24 import pickle
e@0 25
logging.basicConfig(level=logging.INFO)


def flatten(l):
    """Concatenate the sublists of *l* into a single flat list.

    Adapted from
    https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python

    Args:
        l: an iterable of iterables.

    Returns:
        A new list holding the items of every sublist, in order.

    >>> flatten([[1, 2], [3], []])
    [1, 2, 3]
    """
    return [item for sublist in l for item in sublist]
e@0 30
e@0 31 # Relation Model
e@0 32
e@0 33 from rel import *
e@0 34
def quotes2dict(text):
    """Replace each double-quoted span of *text* with a numbered placeholder.

    The text is scanned left to right and every ``"`` toggles between
    narrator text and quoted text. A single counter is incremented at every
    quote character, so narrator segments and quotes receive interleaved
    numbers (``<nline0>.``, ``<cline1>.``, ``<nline2>.``, ...).

    Args:
        text: the input string.

    Returns:
        A ``(new_text, quote_dict)`` tuple where

        - ``new_text`` is *text* with every ``"<quote>"`` replaced by its
          ``<clineN>.`` placeholder (narrator text is left in place);
        - ``quote_dict`` maps ``<nlineN>.`` to the narrator text preceding
          each opening quote, and ``<clineN>.`` to the content of each quote
          (without the quote characters).

    Note:
        Narrator text after the last closing quote, and the content of an
        unterminated final quote, are not recorded in ``quote_dict``.
    """
    new_text = text
    is_open = False

    quote_no = 0
    quote = []
    narrator = []
    quote_dict = {}

    for c in text:
        if c == '"':
            if not is_open:
                # Opening quote: record the narrator text seen since the
                # previous quote.
                quote_dict["<nline{}>.".format(quote_no)] = ''.join(narrator)
                narrator = []
            else:
                # Closing quote: record the quote and substitute its placeholder.
                quoted = ''.join(quote)
                placeholder = "<cline{}>.".format(quote_no)
                quote_dict[placeholder] = quoted
                # Replace only the first remaining occurrence. Earlier quotes
                # have already been substituted (and narrator text contains no
                # '"'), so that occurrence is exactly this quote. Replacing all
                # occurrences would also consume later identical quotes, whose
                # placeholders would then never appear in new_text.
                new_text = new_text.replace('"' + quoted + '"', placeholder, 1)
                quote = []
            is_open = not is_open
            quote_no += 1
        elif is_open:
            quote.append(c)
        else:
            narrator.append(c)

    return new_text, quote_dict
e@0 66
def _label_entities(mDoc, parsed):
    """First pass: label entity tokens in mDoc and map entity ids to tokens.

    Every ``ner.AnnotationTuple`` line labels the tokens covering its
    start/end offsets with the entity type.

    Args:
        mDoc: the ``ner.Document`` to label.
        parsed: the annotation parsed with ``ner.AnnotationFile``.

    Returns:
        A dict mapping each entity's ``variable`` to
        ``mDoc.find_tokens(start, end)`` for its span.
    """
    dictionary = {}
    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]
        if not isinstance(obj, ner.AnnotationTuple):
            continue

        # `idx` is split on whitespace into the start and end offsets
        start_s, end_s = obj.idx.split()
        start = int(start_s)
        end = int(end_s)

        # Relations refer to entities by variable, so remember its tokens
        dictionary[obj.variable] = mDoc.find_tokens(start, end)
        mDoc.assign_label_to_tokens(start, end, str(obj.type))
    return dictionary


def _add_relations(mDoc, parsed, dictionary):
    """Second pass: add the annotated relations to mDoc.

    For every ``ner.RelationTuple`` line, the 'Says' or 'Spatial_Signal'
    argument is the trigger (labelled 'Quote' or 'Spatial_Relation'). The
    'Trajector'/'WHO' arguments are first-argument candidates and the
    'Landmark'/'WHAT' arguments are second-argument candidates. One relation
    is added for every (arg1, arg2) pair, i.e. |arg1| * |arg2| relations per
    trigger. Relations without a trigger argument are skipped with a warning.

    Raises:
        KeyError: if a relation references an entity id that is not in
            ``dictionary``.
    """
    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]
        if not isinstance(obj, ner.RelationTuple):
            continue

        # Reset per relation so that a relation never inherits the trigger
        # or label of a previous one.
        trigger = None
        label = None
        arg1_candidates = []
        arg2_candidates = []

        for arg in obj.args:
            if arg.label == 'Says':
                trigger = dictionary[arg.target]
                label = 'Quote'
            elif arg.label == 'Spatial_Signal':
                trigger = dictionary[arg.target]
                label = 'Spatial_Relation'
            if arg.label in ('Trajector', 'WHO'):
                arg1_candidates.append(dictionary[arg.target])
            if arg.label in ('Landmark', 'WHAT'):
                arg2_candidates.append(dictionary[arg.target])

        # `label` is set exactly when a trigger argument was found
        if label is None:
            logging.warning('Skipping relation without a trigger argument: %s', obj)
            continue

        for arg1 in arg1_candidates:
            for arg2 in arg2_candidates:
                mDoc.add_relation(trigger, arg1, arg2, label)


def _train_ner_model(mDoc):
    """Fit and return a CRF NER model on mDoc's token features and labels."""
    logging.info('Creating NER CRF model')
    ner_model = crf.CRF(c1=0.1,
                        c2=0.1,
                        max_iterations=100,
                        all_possible_transitions=True)

    logging.info('Extracting features/labels from document')
    features, labels = mDoc.get_token_features_labels()

    logging.info('Fitting NER model')
    ner_model.fit(features, labels)
    return ner_model


def _train_rel_model(mDoc):
    """Fit and return a RelModel on mDoc's candidate-relation features and labels."""
    logging.info('Creating REL SVM model')
    rel_model = RelModel()

    logging.info('Extracting relations features/labels from document')
    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()

    logging.info('Fitting REL model')
    rel_model.fit(rel_features, rel_labels)
    return rel_model


def annotation2doc(text, annotation):
    """Build a labelled document from a text and its annotation, then train models.

    Args:
        text: the raw story text. The start/end offsets in the annotation are
            looked up in the document built from this text.
        annotation: the contents of the .ann file, parsed with
            ``ner.AnnotationFile``.

    Returns:
        A ``(mDoc, ner_model, rel_model)`` tuple: the ``ner.Document`` with
        entity labels and relations assigned, the fitted ``crf.CRF`` NER
        model, and the fitted ``RelModel`` relation model.

    Raises:
        KeyError: if a relation references an undefined entity.
    """
    # Load language engine
    logging.info('Loading language engine')
    nlp = spacy.load('en')

    # Convert to spacy document type
    logging.info('Parsing to spacy document')
    doc = nlp(text)

    # Convert to ner.Document
    logging.info('Converting to custom Document format')
    mDoc = ner.Document(doc)

    logging.info('Parsing annotation')
    parsed = pg.parse(annotation, ner.AnnotationFile)

    # Two passes over the parsed lines: entities first, then relations,
    # because a relation may refer to an entity defined later in the file.
    dictionary = _label_entities(mDoc, parsed)
    _add_relations(mDoc, parsed, dictionary)

    ner_model = _train_ner_model(mDoc)
    rel_model = _train_rel_model(mDoc)

    return mDoc, ner_model, rel_model
e@0 172
e@0 173
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description='Train NER and REL models from a .txt story and its .ann annotation.')
    argparser.add_argument('input_text_path',
                           help='.txt file of input')
    argparser.add_argument('input_annotation_path',
                           help='.ann file of annotation')
    argparser.add_argument('--output-dir',
                           help='directory to save model files (default `.`)')

    args = argparser.parse_args()

    # Decode explicitly instead of relying on the platform's locale default.
    # Assumes the corpus files are UTF-8 -- adjust if they are not.
    with open(args.input_text_path, encoding='utf-8') as f:
        text = f.read()

    with open(args.input_annotation_path, encoding='utf-8') as f:
        annotation = f.read()

    _, ner_model, rel_model = annotation2doc(text, annotation)

    # An absent (or empty) --output-dir falls back to the current directory.
    output_dir = args.output_dir or os.path.curdir
    # Create the directory if needed so the writes below do not fail.
    os.makedirs(output_dir, exist_ok=True)

    ner_model_path = os.path.join(output_dir, 'ner_model.pkl')
    rel_model_path = os.path.join(output_dir, 'rel_model.pkl')

    logging.info('Saving NER model to {}'.format(ner_model_path))
    with open(ner_model_path, 'wb') as f:
        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)

    logging.info('Saving REL model to {}'.format(rel_model_path))
    with open(rel_model_path, 'wb') as f:
        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)