cannam@111
|
1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
|
c@96
|
2
|
c@96
|
3 #ifndef PIPER_CAPNP_CLIENT_H
|
c@96
|
4 #define PIPER_CAPNP_CLIENT_H
|
c@96
|
5
|
c@96
|
#include "Loader.h"
#include "PluginClient.h"
#include "PluginStub.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <capnp/serialize.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
c@96
|
15
|
c@97
|
16 namespace piper_vamp {
|
c@97
|
17 namespace client {
|
c@96
|
18
|
c@100
|
19 /**
|
c@100
|
20 * Client for a request-response Piper server, i.e. using the
|
c@100
|
21 * RpcRequest/RpcResponse structures with a single process call rather
|
c@100
|
22 * than having individual RPC methods, with a synchronous transport
|
c@100
|
23 * such as a subprocess pipe arrangement. Only one request can be
|
c@100
|
24 * handled at a time. This class is thread-safe if and only if it is
|
c@100
|
25 * constructed with a thread-safe SynchronousTransport implementation.
|
c@100
|
26 */
|
c@96
|
27 class CapnpRRClient : public PluginClient,
|
c@97
|
28 public Loader
|
c@96
|
29 {
|
c@96
|
30 // unsigned to avoid undefined behaviour on possible wrap
|
c@96
|
31 typedef uint32_t ReqId;
|
c@96
|
32
|
c@96
|
33 class CompletenessChecker : public MessageCompletenessChecker {
|
c@96
|
34 public:
|
c@96
|
35 bool isComplete(const std::vector<char> &message) const override {
|
c@96
|
36 auto karr = toKJArray(message);
|
c@96
|
37 size_t words = karr.size();
|
c@96
|
38 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr);
|
c@96
|
39 if (words > expected) {
|
c@96
|
40 std::cerr << "WARNING: obtained more data than expected ("
|
c@96
|
41 << words << " " << sizeof(capnp::word)
|
c@96
|
42 << "-byte words, expected "
|
c@96
|
43 << expected << ")" << std::endl;
|
c@96
|
44 }
|
c@96
|
45 return words >= expected;
|
c@96
|
46 }
|
c@96
|
47 };
|
c@96
|
48
|
c@96
|
49 public:
|
c@96
|
50 CapnpRRClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
|
c@96
|
51 m_transport(transport),
|
c@96
|
52 m_completenessChecker(new CompletenessChecker) {
|
c@96
|
53 transport->setCompletenessChecker(m_completenessChecker);
|
c@96
|
54 }
|
c@96
|
55
|
c@96
|
56 ~CapnpRRClient() {
|
c@96
|
57 delete m_completenessChecker;
|
c@96
|
58 }
|
c@96
|
59
|
c@96
|
60 //!!! obviously, factor out all repetitive guff
|
c@96
|
61
|
c@96
|
62 //!!! list and load are supposed to be called by application code,
|
c@96
|
63 //!!! but the rest are only supposed to be called by the plugin --
|
c@96
|
64 //!!! sort out the api here
|
c@96
|
65
|
c@96
|
66 // Loader methods:
|
c@96
|
67
|
c@97
|
68 ListResponse
|
c@96
|
69 listPluginData() override {
|
c@96
|
70
|
c@96
|
71 if (!m_transport->isOK()) {
|
c@96
|
72 throw std::runtime_error("Piper server failed to start");
|
c@96
|
73 }
|
c@96
|
74
|
c@96
|
75 capnp::MallocMessageBuilder message;
|
c@97
|
76 piper::RpcRequest::Builder builder = message.initRoot<piper::RpcRequest>();
|
c@96
|
77 VampnProto::buildRpcRequest_List(builder);
|
c@96
|
78 ReqId id = getId();
|
c@96
|
79 builder.getId().setNumber(id);
|
c@96
|
80
|
c@96
|
81 auto karr = call(message);
|
c@96
|
82
|
c@96
|
83 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@97
|
84 piper::RpcResponse::Reader reader = responseMessage.getRoot<piper::RpcResponse>();
|
c@96
|
85
|
c@97
|
86 checkResponseType(reader, piper::RpcResponse::Response::Which::LIST, id);
|
c@96
|
87
|
c@97
|
88 ListResponse lr;
|
c@96
|
89 VampnProto::readListResponse(lr, reader.getResponse().getList());
|
c@96
|
90 return lr;
|
c@96
|
91 }
|
c@96
|
92
|
c@97
|
93 LoadResponse
|
c@97
|
94 loadPlugin(const LoadRequest &req) override {
|
c@96
|
95
|
c@96
|
96 if (!m_transport->isOK()) {
|
c@96
|
97 throw std::runtime_error("Piper server failed to start");
|
c@96
|
98 }
|
c@96
|
99
|
c@97
|
100 LoadResponse resp;
|
c@96
|
101 PluginHandleMapper::Handle handle = serverLoad(req.pluginKey,
|
c@96
|
102 req.inputSampleRate,
|
c@96
|
103 req.adapterFlags,
|
c@96
|
104 resp.staticData,
|
c@96
|
105 resp.defaultConfiguration);
|
c@96
|
106
|
c@96
|
107 Vamp::Plugin *plugin = new PluginStub(this,
|
c@96
|
108 req.pluginKey,
|
c@96
|
109 req.inputSampleRate,
|
c@96
|
110 req.adapterFlags,
|
c@96
|
111 resp.staticData,
|
c@96
|
112 resp.defaultConfiguration);
|
c@96
|
113
|
c@96
|
114 m_mapper.addPlugin(handle, plugin);
|
c@96
|
115
|
c@96
|
116 resp.plugin = plugin;
|
c@96
|
117 return resp;
|
c@96
|
118 }
|
c@96
|
119
|
c@96
|
120 // PluginClient methods:
|
c@96
|
121
|
c@96
|
122 virtual
|
c@96
|
123 Vamp::Plugin::OutputList
|
c@96
|
124 configure(PluginStub *plugin,
|
c@97
|
125 PluginConfiguration config) override {
|
c@96
|
126
|
c@96
|
127 if (!m_transport->isOK()) {
|
c@96
|
128 throw std::runtime_error("Piper server failed to start");
|
c@96
|
129 }
|
c@96
|
130
|
c@97
|
131 ConfigurationRequest request;
|
c@96
|
132 request.plugin = plugin;
|
c@96
|
133 request.configuration = config;
|
c@96
|
134
|
c@96
|
135 capnp::MallocMessageBuilder message;
|
c@97
|
136 piper::RpcRequest::Builder builder = message.initRoot<piper::RpcRequest>();
|
c@96
|
137
|
c@96
|
138 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
|
c@96
|
139 ReqId id = getId();
|
c@96
|
140 builder.getId().setNumber(id);
|
c@96
|
141
|
c@96
|
142 auto karr = call(message);
|
c@96
|
143
|
c@96
|
144 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@97
|
145 piper::RpcResponse::Reader reader = responseMessage.getRoot<piper::RpcResponse>();
|
c@96
|
146
|
c@96
|
147 //!!! handle (explicit) error case
|
c@96
|
148
|
c@97
|
149 checkResponseType(reader, piper::RpcResponse::Response::Which::CONFIGURE, id);
|
c@96
|
150
|
c@97
|
151 ConfigurationResponse cr;
|
c@96
|
152 VampnProto::readConfigurationResponse(cr,
|
c@96
|
153 reader.getResponse().getConfigure(),
|
c@96
|
154 m_mapper);
|
c@96
|
155
|
c@96
|
156 return cr.outputs;
|
c@96
|
157 };
|
c@96
|
158
|
c@96
|
159 virtual
|
c@96
|
160 Vamp::Plugin::FeatureSet
|
c@96
|
161 process(PluginStub *plugin,
|
c@96
|
162 std::vector<std::vector<float> > inputBuffers,
|
c@96
|
163 Vamp::RealTime timestamp) override {
|
c@96
|
164
|
c@96
|
165 if (!m_transport->isOK()) {
|
c@96
|
166 throw std::runtime_error("Piper server failed to start");
|
c@96
|
167 }
|
c@96
|
168
|
c@97
|
169 ProcessRequest request;
|
c@96
|
170 request.plugin = plugin;
|
c@96
|
171 request.inputBuffers = inputBuffers;
|
c@96
|
172 request.timestamp = timestamp;
|
c@96
|
173
|
c@96
|
174 capnp::MallocMessageBuilder message;
|
c@97
|
175 piper::RpcRequest::Builder builder = message.initRoot<piper::RpcRequest>();
|
c@96
|
176 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
|
c@96
|
177 ReqId id = getId();
|
c@96
|
178 builder.getId().setNumber(id);
|
c@96
|
179
|
c@96
|
180 auto karr = call(message);
|
c@96
|
181
|
c@96
|
182 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@97
|
183 piper::RpcResponse::Reader reader = responseMessage.getRoot<piper::RpcResponse>();
|
c@96
|
184
|
c@96
|
185 //!!! handle (explicit) error case
|
c@96
|
186
|
c@97
|
187 checkResponseType(reader, piper::RpcResponse::Response::Which::PROCESS, id);
|
c@96
|
188
|
c@97
|
189 ProcessResponse pr;
|
c@96
|
190 VampnProto::readProcessResponse(pr,
|
c@96
|
191 reader.getResponse().getProcess(),
|
c@96
|
192 m_mapper);
|
c@96
|
193
|
c@96
|
194 return pr.features;
|
c@96
|
195 }
|
c@96
|
196
|
c@96
|
197 virtual Vamp::Plugin::FeatureSet
|
c@96
|
198 finish(PluginStub *plugin) override {
|
c@96
|
199
|
c@96
|
200 if (!m_transport->isOK()) {
|
c@96
|
201 throw std::runtime_error("Piper server failed to start");
|
c@96
|
202 }
|
c@96
|
203
|
c@97
|
204 FinishRequest request;
|
c@96
|
205 request.plugin = plugin;
|
c@96
|
206
|
c@96
|
207 capnp::MallocMessageBuilder message;
|
c@97
|
208 piper::RpcRequest::Builder builder = message.initRoot<piper::RpcRequest>();
|
c@96
|
209
|
c@96
|
210 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
|
c@96
|
211 ReqId id = getId();
|
c@96
|
212 builder.getId().setNumber(id);
|
c@96
|
213
|
c@96
|
214 auto karr = call(message);
|
c@96
|
215
|
c@96
|
216 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@97
|
217 piper::RpcResponse::Reader reader = responseMessage.getRoot<piper::RpcResponse>();
|
c@96
|
218
|
c@96
|
219 //!!! handle (explicit) error case
|
c@96
|
220
|
c@97
|
221 checkResponseType(reader, piper::RpcResponse::Response::Which::FINISH, id);
|
c@96
|
222
|
c@97
|
223 FinishResponse pr;
|
c@96
|
224 VampnProto::readFinishResponse(pr,
|
c@96
|
225 reader.getResponse().getFinish(),
|
c@96
|
226 m_mapper);
|
c@96
|
227
|
c@96
|
228 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
|
c@96
|
229
|
c@96
|
230 // Don't delete the plugin. It's the plugin that is supposed
|
c@96
|
231 // to be calling us here
|
c@96
|
232
|
c@96
|
233 return pr.features;
|
c@96
|
234 }
|
c@96
|
235
|
c@96
|
236 virtual void
|
c@96
|
237 reset(PluginStub *plugin,
|
c@97
|
238 PluginConfiguration config) override {
|
c@96
|
239
|
c@96
|
240 // Reload the plugin on the server side, and configure it as requested
|
c@96
|
241
|
c@96
|
242 if (!m_transport->isOK()) {
|
c@96
|
243 throw std::runtime_error("Piper server failed to start");
|
c@96
|
244 }
|
c@96
|
245
|
c@96
|
246 if (m_mapper.havePlugin(plugin)) {
|
c@96
|
247 (void)finish(plugin); // server-side unload
|
c@96
|
248 }
|
c@96
|
249
|
c@97
|
250 PluginStaticData psd;
|
c@97
|
251 PluginConfiguration defaultConfig;
|
c@96
|
252 PluginHandleMapper::Handle handle =
|
c@96
|
253 serverLoad(plugin->getPluginKey(),
|
c@96
|
254 plugin->getInputSampleRate(),
|
c@96
|
255 plugin->getAdapterFlags(),
|
c@96
|
256 psd, defaultConfig);
|
c@96
|
257
|
c@96
|
258 m_mapper.addPlugin(handle, plugin);
|
c@96
|
259
|
c@96
|
260 (void)configure(plugin, config);
|
c@96
|
261 }
|
c@96
|
262
|
c@96
|
263 private:
|
c@96
|
264 AssignedPluginHandleMapper m_mapper;
|
c@96
|
265 ReqId getId() {
|
c@96
|
266 //!!! todo: mutex
|
c@96
|
267 static ReqId m_nextId = 0;
|
c@96
|
268 return m_nextId++;
|
c@96
|
269 }
|
c@96
|
270
|
c@96
|
271 static
|
c@96
|
272 kj::Array<capnp::word>
|
c@96
|
273 toKJArray(const std::vector<char> &buffer) {
|
c@96
|
274 // We could do this whole thing with fewer copies, but let's
|
c@96
|
275 // see whether it matters first
|
c@96
|
276 size_t wordSize = sizeof(capnp::word);
|
c@96
|
277 size_t words = buffer.size() / wordSize;
|
c@96
|
278 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
|
c@96
|
279 memcpy(karr.begin(), buffer.data(), words * wordSize);
|
c@96
|
280 return karr;
|
c@96
|
281 }
|
c@96
|
282
|
c@96
|
283 void
|
c@97
|
284 checkResponseType(const piper::RpcResponse::Reader &r,
|
c@97
|
285 piper::RpcResponse::Response::Which type,
|
c@96
|
286 ReqId id) {
|
c@96
|
287
|
c@96
|
288 if (r.getResponse().which() != type) {
|
c@101
|
289 std::cerr << "checkResponseType: wrong response type (received "
|
cannam@111
|
290 << int(r.getResponse().which()) << ", expected "
|
cannam@111
|
291 << int(type) << ")"
|
c@101
|
292 << std::endl;
|
c@96
|
293 throw std::runtime_error("Wrong response type");
|
c@96
|
294 }
|
c@96
|
295 if (ReqId(r.getId().getNumber()) != id) {
|
c@101
|
296 std::cerr << "checkResponseType: wrong response id (received "
|
c@101
|
297 << r.getId().getNumber() << ", expected " << id << ")"
|
c@101
|
298 << std::endl;
|
c@96
|
299 throw std::runtime_error("Wrong response id");
|
c@96
|
300 }
|
c@96
|
301 }
|
c@96
|
302
|
c@96
|
303 kj::Array<capnp::word>
|
c@96
|
304 call(capnp::MallocMessageBuilder &message) {
|
c@96
|
305 auto arr = capnp::messageToFlatArray(message);
|
c@96
|
306 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@96
|
307 arr.asChars().size());
|
c@96
|
308 return toKJArray(responseBuffer);
|
c@96
|
309 }
|
c@96
|
310
|
c@96
|
311 PluginHandleMapper::Handle
|
c@96
|
312 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
|
c@97
|
313 PluginStaticData &psd,
|
c@97
|
314 PluginConfiguration &defaultConfig) {
|
c@96
|
315
|
c@97
|
316 LoadRequest request;
|
c@96
|
317 request.pluginKey = key;
|
c@96
|
318 request.inputSampleRate = inputSampleRate;
|
c@96
|
319 request.adapterFlags = adapterFlags;
|
c@96
|
320
|
c@96
|
321 capnp::MallocMessageBuilder message;
|
c@97
|
322 piper::RpcRequest::Builder builder = message.initRoot<piper::RpcRequest>();
|
c@96
|
323
|
c@96
|
324 VampnProto::buildRpcRequest_Load(builder, request);
|
c@96
|
325 ReqId id = getId();
|
c@96
|
326 builder.getId().setNumber(id);
|
c@96
|
327
|
c@96
|
328 auto karr = call(message);
|
c@96
|
329
|
c@96
|
330 //!!! ... --> will also need some way to kill this process
|
c@96
|
331 //!!! (from another thread)
|
c@96
|
332
|
c@96
|
333 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@97
|
334 piper::RpcResponse::Reader reader = responseMessage.getRoot<piper::RpcResponse>();
|
c@96
|
335
|
c@96
|
336 //!!! handle (explicit) error case
|
c@96
|
337
|
c@97
|
338 checkResponseType(reader, piper::RpcResponse::Response::Which::LOAD, id);
|
c@96
|
339
|
c@97
|
340 const piper::LoadResponse::Reader &lr = reader.getResponse().getLoad();
|
c@96
|
341 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
|
c@96
|
342 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
|
c@96
|
343 return lr.getHandle();
|
c@96
|
344 };
|
c@96
|
345
|
c@96
|
346 private:
|
c@96
|
347 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
|
c@96
|
348 CompletenessChecker *m_completenessChecker; // I own this
|
c@96
|
349 };
|
c@96
|
350
|
c@96
|
351 }
|
c@96
|
352 }
|
c@96
|
353
|
c@96
|
354 #endif
|