comparison demo/annotation2model.py @ 0:4dad87badb0c

initial commit
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Wed, 16 May 2018 17:56:10 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:4dad87badb0c
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 """
4 Created on Mon Apr 30 14:28:49 2018
5
6 @author: Emmanouil Theofanis Chourdakis
7
8 Takes a .txt story and an .ann annotation and trains a model.
9
10 @output:
11 ner .pkl model -- NER recognition model
12 rel .pkl model -- RELation extraction model
13
14 """
15
16 import os
17
18 import argparse
19 import logging
20 import spacy
21 import ner
22 import pypeg2 as pg
23 import sklearn_crfsuite as crf
24 import pickle
25
26 logging.basicConfig(level=logging.INFO)
27
# https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
def flatten(l):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c].

    Defined with ``def`` rather than a lambda assignment (PEP 8 E731)
    so tracebacks and repr show a useful name.
    """
    return [item for sublist in l for item in sublist]
30
31 # Relation Model
32
33 from rel import *
34
def quotes2dict(text):
    """Replace double-quoted spans in *text* with ``<clineN>.`` placeholders.

    Returns ``(new_text, quote_dict)`` where ``quote_dict`` maps
    ``"<nlineN>."`` keys to the narrator text preceding each quote and
    ``"<clineN>."`` keys to the quote contents.  The counter advances on
    every quote character, so narrator and quote entries interleave as
    nline0, cline1, nline2, cline3, ...

    NOTE(review): narrator text after the final closing quote is not
    recorded in the dictionary — this matches the original behaviour.
    """
    result_text = text
    inside_quote = False
    counter = 0
    quote_chars = []
    narrator_chars = []
    mapping = {}

    for ch in text:
        if ch == '"':
            if inside_quote:
                # Closing quote: record the contents and substitute the
                # placeholder everywhere the quoted span occurs.
                body = ''.join(quote_chars)
                mapping["<cline{}>.".format(counter)] = body
                result_text = result_text.replace(
                    '"' + body + '"', "<cline{}>.".format(counter))
                quote_chars = []
            else:
                # Opening quote: flush the narrator text seen so far.
                mapping["<nline{}>.".format(counter)] = ''.join(narrator_chars)
                narrator_chars = []
            inside_quote = not inside_quote
            counter += 1
            continue

        # Ordinary character: accumulate into whichever buffer is active.
        (quote_chars if inside_quote else narrator_chars).append(ch)

    return result_text, mapping
66
def annotation2doc(text, annotation):
    """Build a document and train NER/REL models from a story and its annotation.

    Parameters
    ----------
    text : str
        Raw story text (.txt contents).
    annotation : str
        Brat-style annotation (.ann contents) parsed with pypeg2 against
        ``ner.AnnotationFile``.

    Returns
    -------
    (mDoc, ner_model, rel_model)
        The populated ``ner.Document``, a fitted ``sklearn_crfsuite.CRF``
        NER model, and a fitted ``RelModel`` relation model.
    """
    # Load language engine
    logging.info('Loading language engine')
    nlp = spacy.load('en')

    # Convert to spacy document type
    logging.info('Parsing to spacy document')
    doc = nlp(text)

    # Convert to ner.Document
    logging.info('Converting to custom Document format')
    mDoc = ner.Document(doc)

    # Parsing annotation
    logging.info('Parsing annotation')
    parsed = pg.parse(annotation, ner.AnnotationFile)

    # Store an entity and relations dictionary since relations
    # point to such entities
    dictionary = {}

    # Visit all the parsed lines. Do it in two passes, first parse
    # entities and then relations. The reason for that is that some times
    # a relation refers to an entity that has not been defined.
    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]

        if isinstance(obj, ner.AnnotationTuple):

            # If it is a tuple, find the start and end
            # borders, and assign them the appropriate label
            # (obj.idx is a whitespace-separated "start end" pair of
            # character offsets).
            start_s, end_s = obj.idx.split()
            start = int(start_s)
            end = int(end_s)
            label = str(obj.type)

            # Store to dictionary the string relating
            # to the annotation, keyed by the annotation variable
            # (e.g. brat's T1, T2, ...) so relations can resolve it.
            dictionary[obj.variable] = mDoc.find_tokens(start, end)

            mDoc.assign_label_to_tokens(start, end, label)

    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]

        if isinstance(obj, ner.RelationTuple):

            # Relations have a trigger, a first argument `arg1' and a
            # second argument `arg2'. There are going to be
            # |arg1| * |arg2| relations constructed for each trigger
            # where |arg1| is the number of candidates for argument 1
            # and |arg2| the number of candidates for argument 2
            arg1_candidates = []
            arg2_candidates = []

            # Check relation's arguments:
            # NOTE(review): `trigger` and `label` are only (re)bound when a
            # 'Says' or 'Spatial_Signal' argument is present; otherwise the
            # values from a previous relation line leak through, and the
            # very first relation would raise NameError if it lacks a
            # trigger argument. Presumably every relation in valid input
            # carries one — confirm against the .ann format.
            for arg in obj.args:
                if arg.label == 'Says':
                    trigger = dictionary[arg.target]
                    label = 'Quote'
                elif arg.label == 'Spatial_Signal':
                    trigger = dictionary[arg.target]
                    label = 'Spatial_Relation'
                if arg.label in ['Trajector', 'WHO']:
                    arg1_candidates.append(dictionary[arg.target])
                if arg.label in ['Landmark', 'WHAT']:
                    arg2_candidates.append(dictionary[arg.target])

            # Cartesian product of candidates: one relation per pair.
            for arg1 in arg1_candidates:
                for arg2 in arg2_candidates:
                    mDoc.add_relation(trigger, arg1, arg2, label)

    # Create NER model
    logging.info('Creating NER CRF model')

    ner_model = crf.CRF(c1=0.1,
                        c2=0.1,
                        max_iterations=100,
                        all_possible_transitions=True)

    logging.info('Extracting features/labels from document')
    features, labels = mDoc.get_token_features_labels()

    logging.info('Fitting NER model')
    ner_model.fit(features, labels)

    # Create Relational model
    logging.info('Creating REL SVM model')
    rel_model = RelModel()

    logging.info('Extracting relations features/labels from document')
    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()

    logging.info('Fitting REL model')
    rel_model.fit(rel_features, rel_labels)

    return mDoc, ner_model, rel_model
172
173
if __name__ == "__main__":
    # CLI: train NER + REL models from one story (.txt) and its brat
    # annotation (.ann), then pickle both models into --output-dir.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_text_path',
                           help='.txt file of input')
    argparser.add_argument('input_annotation_path',
                           help='.ann file of annotation')
    argparser.add_argument('--output-dir',
                           # FIX: help string was missing its closing paren
                           help='directory to save model files (default `.`)')

    args = argparser.parse_args()

    # Load text and annotation contents
    with open(args.input_text_path) as f:
        text = f.read()

    with open(args.input_annotation_path) as f:
        annotation = f.read()

    mDoc, ner_model, rel_model = annotation2doc(text, annotation)

    # Fall back to the current directory when --output-dir is omitted.
    output_dir = args.output_dir if args.output_dir else os.path.curdir

    ner_model_path = os.path.join(output_dir, 'ner_model.pkl')
    rel_model_path = os.path.join(output_dir, 'rel_model.pkl')

    logging.info('Saving NER model to {}'.format(ner_model_path))
    with open(ner_model_path, 'wb') as f:
        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)

    logging.info('Saving REL model to {}'.format(rel_model_path))
    with open(rel_model_path, 'wb') as f:
        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)
210
211
212
213
214
215
216