comparison vamp-client/CapnpRRClient.h @ 96:215c9fb6b7a4

Rename to CapnpRRClient (request-response, as opposed to individual RPC calls)
author Chris Cannam <c.cannam@qmul.ac.uk>
date Thu, 13 Oct 2016 17:00:06 +0100
parents
children 427c4c725085
comparison
equal deleted inserted replaced
95:b6ac26b72b59 96:215c9fb6b7a4
1
2 #ifndef PIPER_CAPNP_CLIENT_H
3 #define PIPER_CAPNP_CLIENT_H
4
#include "Loader.h"
#include "PluginClient.h"
#include "PluginStub.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <capnp/serialize.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
14
15 namespace piper {
16 namespace vampclient {
17
18 class CapnpRRClient : public PluginClient,
19 public Loader
20 {
21 // unsigned to avoid undefined behaviour on possible wrap
22 typedef uint32_t ReqId;
23
24 class CompletenessChecker : public MessageCompletenessChecker {
25 public:
26 bool isComplete(const std::vector<char> &message) const override {
27 auto karr = toKJArray(message);
28 size_t words = karr.size();
29 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr);
30 if (words > expected) {
31 std::cerr << "WARNING: obtained more data than expected ("
32 << words << " " << sizeof(capnp::word)
33 << "-byte words, expected "
34 << expected << ")" << std::endl;
35 }
36 return words >= expected;
37 }
38 };
39
40 public:
41 CapnpRRClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
42 m_transport(transport),
43 m_completenessChecker(new CompletenessChecker) {
44 transport->setCompletenessChecker(m_completenessChecker);
45 }
46
47 ~CapnpRRClient() {
48 delete m_completenessChecker;
49 }
50
51 //!!! obviously, factor out all repetitive guff
52
53 //!!! list and load are supposed to be called by application code,
54 //!!! but the rest are only supposed to be called by the plugin --
55 //!!! sort out the api here
56
57 // Loader methods:
58
59 Vamp::HostExt::ListResponse
60 listPluginData() override {
61
62 if (!m_transport->isOK()) {
63 throw std::runtime_error("Piper server failed to start");
64 }
65
66 capnp::MallocMessageBuilder message;
67 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
68 VampnProto::buildRpcRequest_List(builder);
69 ReqId id = getId();
70 builder.getId().setNumber(id);
71
72 auto karr = call(message);
73
74 capnp::FlatArrayMessageReader responseMessage(karr);
75 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
76
77 checkResponseType(reader, RpcResponse::Response::Which::LIST, id);
78
79 Vamp::HostExt::ListResponse lr;
80 VampnProto::readListResponse(lr, reader.getResponse().getList());
81 return lr;
82 }
83
84 Vamp::HostExt::LoadResponse
85 loadPlugin(const Vamp::HostExt::LoadRequest &req) override {
86
87 if (!m_transport->isOK()) {
88 throw std::runtime_error("Piper server failed to start");
89 }
90
91 Vamp::HostExt::LoadResponse resp;
92 PluginHandleMapper::Handle handle = serverLoad(req.pluginKey,
93 req.inputSampleRate,
94 req.adapterFlags,
95 resp.staticData,
96 resp.defaultConfiguration);
97
98 Vamp::Plugin *plugin = new PluginStub(this,
99 req.pluginKey,
100 req.inputSampleRate,
101 req.adapterFlags,
102 resp.staticData,
103 resp.defaultConfiguration);
104
105 m_mapper.addPlugin(handle, plugin);
106
107 resp.plugin = plugin;
108 return resp;
109 }
110
111 // PluginClient methods:
112
113 virtual
114 Vamp::Plugin::OutputList
115 configure(PluginStub *plugin,
116 Vamp::HostExt::PluginConfiguration config) override {
117
118 if (!m_transport->isOK()) {
119 throw std::runtime_error("Piper server failed to start");
120 }
121
122 Vamp::HostExt::ConfigurationRequest request;
123 request.plugin = plugin;
124 request.configuration = config;
125
126 capnp::MallocMessageBuilder message;
127 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
128
129 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
130 ReqId id = getId();
131 builder.getId().setNumber(id);
132
133 auto karr = call(message);
134
135 capnp::FlatArrayMessageReader responseMessage(karr);
136 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
137
138 //!!! handle (explicit) error case
139
140 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
141
142 Vamp::HostExt::ConfigurationResponse cr;
143 VampnProto::readConfigurationResponse(cr,
144 reader.getResponse().getConfigure(),
145 m_mapper);
146
147 return cr.outputs;
148 };
149
150 virtual
151 Vamp::Plugin::FeatureSet
152 process(PluginStub *plugin,
153 std::vector<std::vector<float> > inputBuffers,
154 Vamp::RealTime timestamp) override {
155
156 if (!m_transport->isOK()) {
157 throw std::runtime_error("Piper server failed to start");
158 }
159
160 Vamp::HostExt::ProcessRequest request;
161 request.plugin = plugin;
162 request.inputBuffers = inputBuffers;
163 request.timestamp = timestamp;
164
165 capnp::MallocMessageBuilder message;
166 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
167 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
168 ReqId id = getId();
169 builder.getId().setNumber(id);
170
171 auto karr = call(message);
172
173 capnp::FlatArrayMessageReader responseMessage(karr);
174 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
175
176 //!!! handle (explicit) error case
177
178 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
179
180 Vamp::HostExt::ProcessResponse pr;
181 VampnProto::readProcessResponse(pr,
182 reader.getResponse().getProcess(),
183 m_mapper);
184
185 return pr.features;
186 }
187
188 virtual Vamp::Plugin::FeatureSet
189 finish(PluginStub *plugin) override {
190
191 if (!m_transport->isOK()) {
192 throw std::runtime_error("Piper server failed to start");
193 }
194
195 Vamp::HostExt::FinishRequest request;
196 request.plugin = plugin;
197
198 capnp::MallocMessageBuilder message;
199 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
200
201 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
202 ReqId id = getId();
203 builder.getId().setNumber(id);
204
205 auto karr = call(message);
206
207 capnp::FlatArrayMessageReader responseMessage(karr);
208 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
209
210 //!!! handle (explicit) error case
211
212 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
213
214 Vamp::HostExt::ProcessResponse pr;
215 VampnProto::readFinishResponse(pr,
216 reader.getResponse().getFinish(),
217 m_mapper);
218
219 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
220
221 // Don't delete the plugin. It's the plugin that is supposed
222 // to be calling us here
223
224 return pr.features;
225 }
226
227 virtual void
228 reset(PluginStub *plugin,
229 Vamp::HostExt::PluginConfiguration config) override {
230
231 // Reload the plugin on the server side, and configure it as requested
232
233 if (!m_transport->isOK()) {
234 throw std::runtime_error("Piper server failed to start");
235 }
236
237 if (m_mapper.havePlugin(plugin)) {
238 (void)finish(plugin); // server-side unload
239 }
240
241 Vamp::HostExt::PluginStaticData psd;
242 Vamp::HostExt::PluginConfiguration defaultConfig;
243 PluginHandleMapper::Handle handle =
244 serverLoad(plugin->getPluginKey(),
245 plugin->getInputSampleRate(),
246 plugin->getAdapterFlags(),
247 psd, defaultConfig);
248
249 m_mapper.addPlugin(handle, plugin);
250
251 (void)configure(plugin, config);
252 }
253
254 private:
255 AssignedPluginHandleMapper m_mapper;
256 ReqId getId() {
257 //!!! todo: mutex
258 static ReqId m_nextId = 0;
259 return m_nextId++;
260 }
261
262 static
263 kj::Array<capnp::word>
264 toKJArray(const std::vector<char> &buffer) {
265 // We could do this whole thing with fewer copies, but let's
266 // see whether it matters first
267 size_t wordSize = sizeof(capnp::word);
268 size_t words = buffer.size() / wordSize;
269 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
270 memcpy(karr.begin(), buffer.data(), words * wordSize);
271 return karr;
272 }
273
274 void
275 checkResponseType(const RpcResponse::Reader &r,
276 RpcResponse::Response::Which type,
277 ReqId id) {
278
279 if (r.getResponse().which() != type) {
280 throw std::runtime_error("Wrong response type");
281 }
282 if (ReqId(r.getId().getNumber()) != id) {
283 throw std::runtime_error("Wrong response id");
284 }
285 }
286
287 kj::Array<capnp::word>
288 call(capnp::MallocMessageBuilder &message) {
289 auto arr = capnp::messageToFlatArray(message);
290 auto responseBuffer = m_transport->call(arr.asChars().begin(),
291 arr.asChars().size());
292 return toKJArray(responseBuffer);
293 }
294
295 PluginHandleMapper::Handle
296 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
297 Vamp::HostExt::PluginStaticData &psd,
298 Vamp::HostExt::PluginConfiguration &defaultConfig) {
299
300 Vamp::HostExt::LoadRequest request;
301 request.pluginKey = key;
302 request.inputSampleRate = inputSampleRate;
303 request.adapterFlags = adapterFlags;
304
305 capnp::MallocMessageBuilder message;
306 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
307
308 VampnProto::buildRpcRequest_Load(builder, request);
309 ReqId id = getId();
310 builder.getId().setNumber(id);
311
312 auto karr = call(message);
313
314 //!!! ... --> will also need some way to kill this process
315 //!!! (from another thread)
316
317 capnp::FlatArrayMessageReader responseMessage(karr);
318 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
319
320 //!!! handle (explicit) error case
321
322 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
323
324 const LoadResponse::Reader &lr = reader.getResponse().getLoad();
325 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
326 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
327 return lr.getHandle();
328 };
329
330 private:
331 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
332 CompletenessChecker *m_completenessChecker; // I own this
333 };
334
335 }
336 }
337
338 #endif