e@0
|
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 30 14:28:49 2018

@author: Emmanouil Theofanis Chourdakis

Takes a .txt story and an .ann annotation and trains a model.

@output:
    ner .pkl model -- NER recognition model
    rel .pkl model -- RELation extraction model

"""

# Standard library
import os
import argparse
import logging
import pickle

# Third-party / project
import spacy
import ner
import pypeg2 as pg
import sklearn_crfsuite as crf

logging.basicConfig(level=logging.INFO)


# https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
# BUGFIX(style): was a lambda assigned to a name (PEP 8 E731); a `def`
# keeps the same callable under the same name but is debuggable/picklable.
def flatten(l):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return [item for sublist in l for item in sublist]


# Relation Model (RelModel and friends).
# NOTE(review): kept after logging.basicConfig so any logging done at
# rel-import time behaves exactly as before.
from rel import *
|
e@0
|
34
|
e@0
|
def quotes2dict(text):
    """Replace double-quoted spans in *text* with ``<clineN>.`` placeholders.

    Returns ``(new_text, quote_dict)`` where ``quote_dict`` maps each
    placeholder key to the original span: ``<nlineN>.`` keys hold the
    narrator text accumulated before an opening quote, ``<clineN>.`` keys
    hold the quoted text itself.  Only the quoted spans are substituted
    inside ``new_text``; narrator text stays in place.

    NOTE(review): narrator text after the final quote never receives an
    ``<nline>`` key, and identical quoted strings share one placeholder
    via str.replace -- both behaviors are preserved as-is.
    """
    result = text
    inside_quote = False
    counter = 0
    quote_chars = []
    narrator_chars = []
    mapping = {}

    for ch in text:
        if ch == '"':
            if inside_quote:
                # Closing quote: record the span and substitute a
                # placeholder for it in the returned text.
                quoted = ''.join(quote_chars)
                key = "<cline{}>.".format(counter)
                mapping[key] = quoted
                result = result.replace('"' + quoted + '"', key)
                quote_chars = []
            else:
                # Opening quote: stash the narrator text seen so far.
                mapping["<nline{}>.".format(counter)] = ''.join(narrator_chars)
                narrator_chars = []
            inside_quote = not inside_quote
            counter += 1
            continue

        (quote_chars if inside_quote else narrator_chars).append(ch)

    return result, mapping
|
e@0
|
66
|
e@0
|
def annotation2doc(text, annotation):
    """Parse *text* with its brat-style *annotation* and train both models.

    Parameters
    ----------
    text : str
        The raw story text (.txt contents).
    annotation : str
        The annotation file contents (.ann), parsed with pypeg2 against
        ner.AnnotationFile.

    Returns
    -------
    (mDoc, ner_model, rel_model)
        The annotated ner.Document, a fitted sklearn-crfsuite CRF for
        entity tagging, and a fitted RelModel for relation extraction.
    """

    # Load language engine.
    # NOTE(review): spacy.load('en') is the spaCy 1.x/2.x shortcut; newer
    # spaCy releases require an explicit model name (e.g. 'en_core_web_sm').
    logging.info('Loading language engine')
    nlp = spacy.load('en')

    # Convert to spacy document type.
    logging.info('Parsing to spacy document')
    doc = nlp(text)

    # Convert to the project's custom Document wrapper.
    logging.info('Converting to custom Document format')
    mDoc = ner.Document(doc)

    # Parse the annotation file.
    logging.info('Parsing annotation')
    parsed = pg.parse(annotation, ner.AnnotationFile)

    # Entity dictionary: annotation variable -> tokens. Relations below
    # point into this.
    dictionary = {}

    # Visit all parsed lines in two passes -- entities first, then
    # relations -- because a relation may refer to an entity that is
    # defined later in the annotation file.
    for line in parsed:
        # Every annotation line has a single object.
        obj = line[0]

        if isinstance(obj, ner.AnnotationTuple):
            # Character offsets of the annotated span, e.g. "12 18".
            start_s, end_s = obj.idx.split()
            start = int(start_s)
            end = int(end_s)
            label = str(obj.type)

            # Remember which tokens this annotation variable covers, and
            # label them for NER training.
            dictionary[obj.variable] = mDoc.find_tokens(start, end)
            mDoc.assign_label_to_tokens(start, end, label)

    for line in parsed:
        # Every annotation line has a single object.
        obj = line[0]

        if isinstance(obj, ner.RelationTuple):
            # Relations have a trigger, a first argument `arg1' and a
            # second argument `arg2'. |arg1| * |arg2| relations are
            # constructed for each trigger, where |arg1| / |arg2| are the
            # numbers of candidates for each argument slot.

            # BUGFIX: trigger/label were previously left over from the
            # preceding relation (or raised NameError on the first one)
            # whenever a relation had no 'Says'/'Spatial_Signal' argument.
            trigger = None
            label = None
            arg1_candidates = []
            arg2_candidates = []

            # Check the relation's arguments.
            for arg in obj.args:
                if arg.label == 'Says':
                    trigger = dictionary[arg.target]
                    label = 'Quote'
                elif arg.label == 'Spatial_Signal':
                    trigger = dictionary[arg.target]
                    label = 'Spatial_Relation'
                if arg.label in ['Trajector', 'WHO']:
                    arg1_candidates.append(dictionary[arg.target])
                if arg.label in ['Landmark', 'WHAT']:
                    arg2_candidates.append(dictionary[arg.target])

            if trigger is None:
                # No recognised trigger argument: skip instead of adding
                # relations with a stale/undefined trigger.
                logging.warning('Relation without a recognised trigger; skipping')
                continue

            for arg1 in arg1_candidates:
                for arg2 in arg2_candidates:
                    mDoc.add_relation(trigger, arg1, arg2, label)

    # Create NER model.
    logging.info('Creating NER CRF model')
    ner_model = crf.CRF(c1=0.1,
                        c2=0.1,
                        max_iterations=100,
                        all_possible_transitions=True)

    logging.info('Extracting features/labels from document')
    features, labels = mDoc.get_token_features_labels()

    logging.info('Fitting NER model')
    ner_model.fit(features, labels)

    # Create relation-extraction model (RelModel comes from `rel`).
    logging.info('Creating REL SVM model')
    rel_model = RelModel()

    logging.info('Extracting relations features/labels from document')
    rel_features, rel_labels = mDoc.get_candidate_relation_feature_labels()

    logging.info('Fitting REL model')
    rel_model.fit(rel_features, rel_labels)

    return mDoc, ner_model, rel_model
|
e@0
|
172
|
e@0
|
173
|
e@0
|
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('input_text_path',
                           help='.txt file of input')
    argparser.add_argument('input_annotation_path',
                           help='.ann file of annotation')
    # BUGFIX: help text was missing its closing paren; the default now
    # lives in argparse instead of a manual if/else fallback below.
    argparser.add_argument('--output-dir',
                           default=os.path.curdir,
                           help='directory to save model files (default `.`)')

    args = argparser.parse_args()

    # Load text and annotation contents. Explicit UTF-8 so results do not
    # depend on the platform's locale encoding.
    with open(args.input_text_path, encoding='utf-8') as f:
        text = f.read()

    with open(args.input_annotation_path, encoding='utf-8') as f:
        annotation = f.read()

    # Build the annotated document and train both models.
    mDoc, ner_model, rel_model = annotation2doc(text, annotation)

    output_dir = args.output_dir

    ner_model_path = os.path.join(output_dir, 'ner_model.pkl')
    rel_model_path = os.path.join(output_dir, 'rel_model.pkl')

    logging.info('Saving NER model to {}'.format(ner_model_path))
    with open(ner_model_path, 'wb') as f:
        pickle.dump(ner_model, f, pickle.HIGHEST_PROTOCOL)

    logging.info('Saving REL model to {}'.format(rel_model_path))
    with open(rel_model_path, 'wb') as f:
        pickle.dump(rel_model, f, pickle.HIGHEST_PROTOCOL)
|
e@0
|
210
|
e@0
|
211
|
e@0
|
212
|
e@0
|
213
|
e@0
|
214
|
e@0
|
215
|
e@0
|
216
|