Mercurial > hg > piper-cpp
comparison vamp-client/CapnpRRClient.h @ 96:215c9fb6b7a4
Rename to CapnpRRClient (request-response, as opposed to individual RPC calls)
author | Chris Cannam <c.cannam@qmul.ac.uk> |
---|---|
date | Thu, 13 Oct 2016 17:00:06 +0100 |
parents | |
children | 427c4c725085 |
comparison
equal
deleted
inserted
replaced
95:b6ac26b72b59 | 96:215c9fb6b7a4 |
---|---|
1 | |
2 #ifndef PIPER_CAPNP_CLIENT_H | |
3 #define PIPER_CAPNP_CLIENT_H | |
4 | |
#include "Loader.h"
#include "PluginClient.h"
#include "PluginStub.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <capnp/serialize.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
14 | |
15 namespace piper { | |
16 namespace vampclient { | |
17 | |
18 class CapnpRRClient : public PluginClient, | |
19 public Loader | |
20 { | |
21 // unsigned to avoid undefined behaviour on possible wrap | |
22 typedef uint32_t ReqId; | |
23 | |
24 class CompletenessChecker : public MessageCompletenessChecker { | |
25 public: | |
26 bool isComplete(const std::vector<char> &message) const override { | |
27 auto karr = toKJArray(message); | |
28 size_t words = karr.size(); | |
29 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr); | |
30 if (words > expected) { | |
31 std::cerr << "WARNING: obtained more data than expected (" | |
32 << words << " " << sizeof(capnp::word) | |
33 << "-byte words, expected " | |
34 << expected << ")" << std::endl; | |
35 } | |
36 return words >= expected; | |
37 } | |
38 }; | |
39 | |
40 public: | |
41 CapnpRRClient(SynchronousTransport *transport) : //!!! ownership? shared ptr? | |
42 m_transport(transport), | |
43 m_completenessChecker(new CompletenessChecker) { | |
44 transport->setCompletenessChecker(m_completenessChecker); | |
45 } | |
46 | |
47 ~CapnpRRClient() { | |
48 delete m_completenessChecker; | |
49 } | |
50 | |
51 //!!! obviously, factor out all repetitive guff | |
52 | |
53 //!!! list and load are supposed to be called by application code, | |
54 //!!! but the rest are only supposed to be called by the plugin -- | |
55 //!!! sort out the api here | |
56 | |
57 // Loader methods: | |
58 | |
59 Vamp::HostExt::ListResponse | |
60 listPluginData() override { | |
61 | |
62 if (!m_transport->isOK()) { | |
63 throw std::runtime_error("Piper server failed to start"); | |
64 } | |
65 | |
66 capnp::MallocMessageBuilder message; | |
67 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
68 VampnProto::buildRpcRequest_List(builder); | |
69 ReqId id = getId(); | |
70 builder.getId().setNumber(id); | |
71 | |
72 auto karr = call(message); | |
73 | |
74 capnp::FlatArrayMessageReader responseMessage(karr); | |
75 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
76 | |
77 checkResponseType(reader, RpcResponse::Response::Which::LIST, id); | |
78 | |
79 Vamp::HostExt::ListResponse lr; | |
80 VampnProto::readListResponse(lr, reader.getResponse().getList()); | |
81 return lr; | |
82 } | |
83 | |
84 Vamp::HostExt::LoadResponse | |
85 loadPlugin(const Vamp::HostExt::LoadRequest &req) override { | |
86 | |
87 if (!m_transport->isOK()) { | |
88 throw std::runtime_error("Piper server failed to start"); | |
89 } | |
90 | |
91 Vamp::HostExt::LoadResponse resp; | |
92 PluginHandleMapper::Handle handle = serverLoad(req.pluginKey, | |
93 req.inputSampleRate, | |
94 req.adapterFlags, | |
95 resp.staticData, | |
96 resp.defaultConfiguration); | |
97 | |
98 Vamp::Plugin *plugin = new PluginStub(this, | |
99 req.pluginKey, | |
100 req.inputSampleRate, | |
101 req.adapterFlags, | |
102 resp.staticData, | |
103 resp.defaultConfiguration); | |
104 | |
105 m_mapper.addPlugin(handle, plugin); | |
106 | |
107 resp.plugin = plugin; | |
108 return resp; | |
109 } | |
110 | |
111 // PluginClient methods: | |
112 | |
113 virtual | |
114 Vamp::Plugin::OutputList | |
115 configure(PluginStub *plugin, | |
116 Vamp::HostExt::PluginConfiguration config) override { | |
117 | |
118 if (!m_transport->isOK()) { | |
119 throw std::runtime_error("Piper server failed to start"); | |
120 } | |
121 | |
122 Vamp::HostExt::ConfigurationRequest request; | |
123 request.plugin = plugin; | |
124 request.configuration = config; | |
125 | |
126 capnp::MallocMessageBuilder message; | |
127 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
128 | |
129 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper); | |
130 ReqId id = getId(); | |
131 builder.getId().setNumber(id); | |
132 | |
133 auto karr = call(message); | |
134 | |
135 capnp::FlatArrayMessageReader responseMessage(karr); | |
136 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
137 | |
138 //!!! handle (explicit) error case | |
139 | |
140 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id); | |
141 | |
142 Vamp::HostExt::ConfigurationResponse cr; | |
143 VampnProto::readConfigurationResponse(cr, | |
144 reader.getResponse().getConfigure(), | |
145 m_mapper); | |
146 | |
147 return cr.outputs; | |
148 }; | |
149 | |
150 virtual | |
151 Vamp::Plugin::FeatureSet | |
152 process(PluginStub *plugin, | |
153 std::vector<std::vector<float> > inputBuffers, | |
154 Vamp::RealTime timestamp) override { | |
155 | |
156 if (!m_transport->isOK()) { | |
157 throw std::runtime_error("Piper server failed to start"); | |
158 } | |
159 | |
160 Vamp::HostExt::ProcessRequest request; | |
161 request.plugin = plugin; | |
162 request.inputBuffers = inputBuffers; | |
163 request.timestamp = timestamp; | |
164 | |
165 capnp::MallocMessageBuilder message; | |
166 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
167 VampnProto::buildRpcRequest_Process(builder, request, m_mapper); | |
168 ReqId id = getId(); | |
169 builder.getId().setNumber(id); | |
170 | |
171 auto karr = call(message); | |
172 | |
173 capnp::FlatArrayMessageReader responseMessage(karr); | |
174 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
175 | |
176 //!!! handle (explicit) error case | |
177 | |
178 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id); | |
179 | |
180 Vamp::HostExt::ProcessResponse pr; | |
181 VampnProto::readProcessResponse(pr, | |
182 reader.getResponse().getProcess(), | |
183 m_mapper); | |
184 | |
185 return pr.features; | |
186 } | |
187 | |
188 virtual Vamp::Plugin::FeatureSet | |
189 finish(PluginStub *plugin) override { | |
190 | |
191 if (!m_transport->isOK()) { | |
192 throw std::runtime_error("Piper server failed to start"); | |
193 } | |
194 | |
195 Vamp::HostExt::FinishRequest request; | |
196 request.plugin = plugin; | |
197 | |
198 capnp::MallocMessageBuilder message; | |
199 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
200 | |
201 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper); | |
202 ReqId id = getId(); | |
203 builder.getId().setNumber(id); | |
204 | |
205 auto karr = call(message); | |
206 | |
207 capnp::FlatArrayMessageReader responseMessage(karr); | |
208 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
209 | |
210 //!!! handle (explicit) error case | |
211 | |
212 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id); | |
213 | |
214 Vamp::HostExt::ProcessResponse pr; | |
215 VampnProto::readFinishResponse(pr, | |
216 reader.getResponse().getFinish(), | |
217 m_mapper); | |
218 | |
219 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin)); | |
220 | |
221 // Don't delete the plugin. It's the plugin that is supposed | |
222 // to be calling us here | |
223 | |
224 return pr.features; | |
225 } | |
226 | |
227 virtual void | |
228 reset(PluginStub *plugin, | |
229 Vamp::HostExt::PluginConfiguration config) override { | |
230 | |
231 // Reload the plugin on the server side, and configure it as requested | |
232 | |
233 if (!m_transport->isOK()) { | |
234 throw std::runtime_error("Piper server failed to start"); | |
235 } | |
236 | |
237 if (m_mapper.havePlugin(plugin)) { | |
238 (void)finish(plugin); // server-side unload | |
239 } | |
240 | |
241 Vamp::HostExt::PluginStaticData psd; | |
242 Vamp::HostExt::PluginConfiguration defaultConfig; | |
243 PluginHandleMapper::Handle handle = | |
244 serverLoad(plugin->getPluginKey(), | |
245 plugin->getInputSampleRate(), | |
246 plugin->getAdapterFlags(), | |
247 psd, defaultConfig); | |
248 | |
249 m_mapper.addPlugin(handle, plugin); | |
250 | |
251 (void)configure(plugin, config); | |
252 } | |
253 | |
254 private: | |
255 AssignedPluginHandleMapper m_mapper; | |
256 ReqId getId() { | |
257 //!!! todo: mutex | |
258 static ReqId m_nextId = 0; | |
259 return m_nextId++; | |
260 } | |
261 | |
262 static | |
263 kj::Array<capnp::word> | |
264 toKJArray(const std::vector<char> &buffer) { | |
265 // We could do this whole thing with fewer copies, but let's | |
266 // see whether it matters first | |
267 size_t wordSize = sizeof(capnp::word); | |
268 size_t words = buffer.size() / wordSize; | |
269 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words)); | |
270 memcpy(karr.begin(), buffer.data(), words * wordSize); | |
271 return karr; | |
272 } | |
273 | |
274 void | |
275 checkResponseType(const RpcResponse::Reader &r, | |
276 RpcResponse::Response::Which type, | |
277 ReqId id) { | |
278 | |
279 if (r.getResponse().which() != type) { | |
280 throw std::runtime_error("Wrong response type"); | |
281 } | |
282 if (ReqId(r.getId().getNumber()) != id) { | |
283 throw std::runtime_error("Wrong response id"); | |
284 } | |
285 } | |
286 | |
287 kj::Array<capnp::word> | |
288 call(capnp::MallocMessageBuilder &message) { | |
289 auto arr = capnp::messageToFlatArray(message); | |
290 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
291 arr.asChars().size()); | |
292 return toKJArray(responseBuffer); | |
293 } | |
294 | |
295 PluginHandleMapper::Handle | |
296 serverLoad(std::string key, float inputSampleRate, int adapterFlags, | |
297 Vamp::HostExt::PluginStaticData &psd, | |
298 Vamp::HostExt::PluginConfiguration &defaultConfig) { | |
299 | |
300 Vamp::HostExt::LoadRequest request; | |
301 request.pluginKey = key; | |
302 request.inputSampleRate = inputSampleRate; | |
303 request.adapterFlags = adapterFlags; | |
304 | |
305 capnp::MallocMessageBuilder message; | |
306 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
307 | |
308 VampnProto::buildRpcRequest_Load(builder, request); | |
309 ReqId id = getId(); | |
310 builder.getId().setNumber(id); | |
311 | |
312 auto karr = call(message); | |
313 | |
314 //!!! ... --> will also need some way to kill this process | |
315 //!!! (from another thread) | |
316 | |
317 capnp::FlatArrayMessageReader responseMessage(karr); | |
318 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
319 | |
320 //!!! handle (explicit) error case | |
321 | |
322 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id); | |
323 | |
324 const LoadResponse::Reader &lr = reader.getResponse().getLoad(); | |
325 VampnProto::readExtractorStaticData(psd, lr.getStaticData()); | |
326 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration()); | |
327 return lr.getHandle(); | |
328 }; | |
329 | |
330 private: | |
331 SynchronousTransport *m_transport; //!!! I don't own this, but should I? | |
332 CompletenessChecker *m_completenessChecker; // I own this | |
333 }; | |
334 | |
335 } | |
336 } | |
337 | |
338 #endif |