Mercurial > hg > from-my-pen-to-your-ears-supplementary-material
comparison demo/annotation2model.py @ 0:4dad87badb0c
initial commit
author | Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk> |
---|---|
date | Wed, 16 May 2018 17:56:10 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:4dad87badb0c |
---|---|
1 #!/usr/bin/env python3 | |
2 # -*- coding: utf-8 -*- | |
3 """ | |
4 Created on Mon Apr 30 14:28:49 2018 | |
5 | |
6 @author: Emmanouil Theofanis Chourdakis | |
7 | |
8 Takes a .txt story and an .ann annotation and trains a model. | |
9 | |
10 @output: | |
11 ner .pkl model -- NER recognition model | |
12 rel .pkl model -- RELation extraction model | |
13 | |
14 """ | |
15 | |
16 import os | |
17 | |
18 import argparse | |
19 import logging | |
20 import spacy | |
21 import ner | |
22 import pypeg2 as pg | |
23 import sklearn_crfsuite as crf | |
24 import pickle | |
25 | |
26 logging.basicConfig(level=logging.INFO) | |
27 | |
# https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
def flatten(l):
    """Flatten one level of nesting: [[1, 2], [3]] -> [1, 2, 3]."""
    # A proper def instead of a name-assigned lambda (PEP 8, E731);
    # same call signature and behavior as the original one-liner.
    return [item for sublist in l for item in sublist]
30 | |
31 # Relation Model | |
32 | |
33 from rel import * | |
34 | |
def quotes2dict(text):
    """Split *text* into narration spans and quoted-speech spans.

    Scans character by character, treating '"' as a toggle between
    narrator mode and quote mode.  A single running counter numbers the
    spans: narration collected before an opening quote is stored under a
    "<nlineK>." key, and each quoted span under a "<clineK>." key.  In
    the returned text every quoted span (quotes included) is replaced by
    its "<clineK>." placeholder; narration is left in place.

    Note: narration after the final quote stays in the returned text but
    is never added to the dictionary.

    Returns a tuple (new_text, quote_dict).
    """
    result_text = text
    inside_quote = False
    counter = 0
    buffers = {'quote': [], 'narrator': []}
    spans = {}

    for ch in text:
        if ch == '"':
            if inside_quote:
                # Closing quote: record the span, then swap the quoted
                # text (with its surrounding quote marks) for a tag.
                quoted = ''.join(buffers['quote'])
                tag = "<cline{}>.".format(counter)
                spans[tag] = quoted
                result_text = result_text.replace('"' + quoted + '"', tag)
                buffers['quote'] = []
            else:
                # Opening quote: flush the narration gathered so far.
                spans["<nline{}>.".format(counter)] = ''.join(buffers['narrator'])
                buffers['narrator'] = []
            inside_quote = not inside_quote
            counter += 1
            continue

        buffers['quote' if inside_quote else 'narrator'].append(ch)

    return result_text, spans
66 | |
def annotation2doc(text: str, annotation: str):
    """Build a labeled document from a story + annotation and train models.

    Parameters
    ----------
    text : str
        Raw story text.
    annotation : str
        Contents of an annotation file (brat-style .ann, presumably —
        TODO confirm), parsed with the pypeg2 grammar ner.AnnotationFile.

    Returns
    -------
    tuple
        (mDoc, ner_model, rel_model):
        mDoc      -- the ner.Document with token labels and relations attached
        ner_model -- a fitted sklearn_crfsuite.CRF entity tagger
        rel_model -- a fitted RelModel (from rel) relation classifier
    """

    # Load language engine
    logging.info('Loading language engine')
    nlp = spacy.load('en')  # 'en' model shortcut — requires the spaCy English model to be installed

    # Convert to spacy document type
    logging.info('Parsing to spacy document')
    doc = nlp(text)

    # Convert to ner.Document
    logging.info('Converting to custom Document format')
    mDoc = ner.Document(doc)

    # Parsing annotation
    logging.info('Parsing annotation')
    parsed = pg.parse(annotation, ner.AnnotationFile)

    # Store an entity and relations dictionary since relations
    # point to such entities

    dictionary = {}

    # Visit all the parsed lines. Do it in two passes, first parse
    # entities and then relations. The reason for that is that some times
    # a relation refers to an entity that has not been defined.

    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]

        if isinstance(obj, ner.AnnotationTuple):

            # If it is a tuple, find the start and end
            # borders, and assign them the appropriate label

            # idx is assumed to be a whitespace-separated "start end"
            # character-offset pair — TODO confirm against ner.AnnotationTuple
            start_s, end_s = obj.idx.split()
            start = int(start_s)
            end = int(end_s)
            label = str(obj.type)

            # Store to dictionary the string relating
            # to the annotation

            dictionary[obj.variable] = mDoc.find_tokens(start, end)

            mDoc.assign_label_to_tokens(start, end, label)

    for line in parsed:
        # Every annotation line has a single object
        obj = line[0]

        if isinstance(obj, ner.RelationTuple):

            # Relations have a trigger, a first argument `arg1' and a
            # second argument `arg2'. There are going to be
            # |arg1| * |arg2| relations constructed for each trigger
            # where |arg1| is the number of candidates for argument 1
            # and |arg2| the number of candidates for argument 2

            arg1_candidates = []
            arg2_candidates = []

            # Check relation's arguments:
            # NOTE(review): assumes every relation carries a 'Says' or
            # 'Spatial_Signal' argument; if neither is present, `trigger`
            # and `label` are unbound (first relation) or stale from the
            # previous relation — confirm against the annotation format.
            for arg in obj.args:
                if arg.label == 'Says':
                    trigger = dictionary[arg.target]
                    label = 'Quote'
                elif arg.label == 'Spatial_Signal':
                    trigger = dictionary[arg.target]
                    label = 'Spatial_Relation'
                if arg.label in ['Trajector', 'WHO']:
                    arg1_candidates.append(dictionary[arg.target])
                if arg.label in ['Landmark', 'WHAT']:
                    arg2_candidates.append(dictionary[arg.target])

            # Cross product: one candidate relation per (arg1, arg2) pair.
            for arg1 in arg1_candidates:
                for arg2 in arg2_candidates:
                    mDoc.add_relation(trigger, arg1, arg2, label)

    # Create NER model
    logging.info('Creating NER CRF model')

    # L1/L2 regularization 0.1 each, capped at 100 iterations.
    ner_model = crf.CRF(c1=0.1,
                        c2=0.1,
                        max_iterations=100,
                        all_possible_transitions=True)

    logging.info('Extracting features/labels from document')
    features, labels = mDoc.get_token_features_labels()

    logging.info('Fitting NER model')
    ner_model.fit(features, labels)

    # Create Relational model
    logging.info('Creating REL SVM model')
    rel_model = RelModel()

    logging.info('Extracting relations features/labels from document')
    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()

    logging.info('Fitting REL model')
    rel_model.fit(rel_features, rel_labels)

    return mDoc, ner_model, rel_model
172 | |
173 | |
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_text_path',
                           help='.txt file of input')
    argparser.add_argument('input_annotation_path',
                           help='.ann file of annotation')
    # Let argparse supply the default instead of a manual if/else below;
    # also fixes the unbalanced paren in the original help string.
    argparser.add_argument('--output-dir',
                           default=os.path.curdir,
                           help='directory to save model files (default `.`)')

    args = argparser.parse_args()

    # Load text and annotation contents.  Read explicitly as UTF-8 so the
    # character offsets in the annotation are not subject to the locale's
    # default encoding.
    with open(args.input_text_path, encoding='utf-8') as f:
        text = f.read()

    with open(args.input_annotation_path, encoding='utf-8') as f:
        annotation = f.read()

    # Train both models from the story + annotation pair.
    mDoc, ner_model, rel_model = annotation2doc(text, annotation)

    ner_model_path = os.path.join(args.output_dir, 'ner_model.pkl')
    rel_model_path = os.path.join(args.output_dir, 'rel_model.pkl')

    logging.info('Saving NER model to {}'.format(ner_model_path))
    with open(ner_model_path, 'wb') as f:
        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)

    logging.info('Saving REL model to {}'.format(rel_model_path))
    with open(rel_model_path, 'wb') as f:
        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)