c@94
|
1
|
c@94
|
2 #ifndef PIPER_CAPNP_CLIENT_H
|
c@94
|
3 #define PIPER_CAPNP_CLIENT_H
|
c@94
|
4
|
c@94
|
#include "Loader.h"
#include "PluginClient.h"
#include "PluginStub.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <capnp/serialize.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
c@94
|
14
|
c@94
|
15 namespace piper {
|
c@94
|
16 namespace vampclient {
|
c@94
|
17
|
c@94
|
18 class CapnpClient : public PluginClient,
|
c@94
|
19 public Loader
|
c@94
|
20 {
|
c@94
|
21 // unsigned to avoid undefined behaviour on possible wrap
|
c@94
|
22 typedef uint32_t ReqId;
|
c@94
|
23
|
c@94
|
24 class CompletenessChecker : public MessageCompletenessChecker {
|
c@94
|
25 public:
|
c@94
|
26 bool isComplete(const std::vector<char> &message) const override {
|
c@94
|
27 auto karr = toKJArray(message);
|
c@94
|
28 size_t words = karr.size();
|
c@94
|
29 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr);
|
c@94
|
30 if (words > expected) {
|
c@94
|
31 std::cerr << "WARNING: obtained more data than expected ("
|
c@94
|
32 << words << " " << sizeof(capnp::word)
|
c@94
|
33 << "-byte words, expected "
|
c@94
|
34 << expected << ")" << std::endl;
|
c@94
|
35 }
|
c@94
|
36 return words >= expected;
|
c@94
|
37 }
|
c@94
|
38 };
|
c@94
|
39
|
c@94
|
40 public:
|
c@94
|
41 CapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
|
c@94
|
42 m_transport(transport),
|
c@94
|
43 m_completenessChecker(new CompletenessChecker) {
|
c@94
|
44 transport->setCompletenessChecker(m_completenessChecker);
|
c@94
|
45 }
|
c@94
|
46
|
c@94
|
47 ~CapnpClient() {
|
c@94
|
48 delete m_completenessChecker;
|
c@94
|
49 }
|
c@94
|
50
|
c@94
|
51 //!!! obviously, factor out all repetitive guff
|
c@94
|
52
|
c@94
|
53 //!!! list and load are supposed to be called by application code,
|
c@94
|
54 //!!! but the rest are only supposed to be called by the plugin --
|
c@94
|
55 //!!! sort out the api here
|
c@95
|
56
|
c@95
|
57 // Loader methods:
|
c@95
|
58
|
c@95
|
59 Vamp::HostExt::ListResponse
|
c@95
|
60 listPluginData() override {
|
c@94
|
61
|
c@94
|
62 if (!m_transport->isOK()) {
|
c@94
|
63 throw std::runtime_error("Piper server failed to start");
|
c@94
|
64 }
|
c@94
|
65
|
c@94
|
66 capnp::MallocMessageBuilder message;
|
c@94
|
67 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@95
|
68 VampnProto::buildRpcRequest_List(builder);
|
c@94
|
69 ReqId id = getId();
|
c@94
|
70 builder.getId().setNumber(id);
|
c@94
|
71
|
c@95
|
72 //!!! pure boilerplate:
|
c@94
|
73 auto arr = capnp::messageToFlatArray(message);
|
c@94
|
74 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@94
|
75 arr.asChars().size());
|
c@94
|
76 auto karr = toKJArray(responseBuffer);
|
c@94
|
77 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@94
|
78 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@94
|
79
|
c@95
|
80 checkResponseType(reader, RpcResponse::Response::Which::LIST, id);
|
c@94
|
81
|
c@95
|
82 Vamp::HostExt::ListResponse lr;
|
c@95
|
83 VampnProto::readListResponse(lr, reader.getResponse().getList());
|
c@95
|
84 return lr;
|
c@95
|
85 }
|
c@95
|
86
|
c@95
|
87 Vamp::HostExt::LoadResponse
|
c@95
|
88 loadPlugin(const Vamp::HostExt::LoadRequest &req) override {
|
c@94
|
89
|
c@95
|
90 if (!m_transport->isOK()) {
|
c@95
|
91 throw std::runtime_error("Piper server failed to start");
|
c@95
|
92 }
|
c@95
|
93
|
c@95
|
94 Vamp::HostExt::LoadResponse resp;
|
c@95
|
95 PluginHandleMapper::Handle handle = serverLoad(req.pluginKey,
|
c@95
|
96 req.inputSampleRate,
|
c@95
|
97 req.adapterFlags,
|
c@95
|
98 resp.staticData,
|
c@95
|
99 resp.defaultConfiguration);
|
c@95
|
100
|
c@95
|
101 Vamp::Plugin *plugin = new PluginStub(this,
|
c@95
|
102 req.pluginKey,
|
c@95
|
103 req.inputSampleRate,
|
c@95
|
104 req.adapterFlags,
|
c@95
|
105 resp.staticData,
|
c@95
|
106 resp.defaultConfiguration);
|
c@95
|
107
|
c@95
|
108 m_mapper.addPlugin(handle, plugin);
|
c@95
|
109
|
c@95
|
110 resp.plugin = plugin;
|
c@95
|
111 return resp;
|
c@95
|
112 }
|
c@95
|
113
|
c@95
|
114 // PluginClient methods:
|
c@95
|
115
|
c@94
|
116 virtual
|
c@94
|
117 Vamp::Plugin::OutputList
|
c@94
|
118 configure(PluginStub *plugin,
|
c@94
|
119 Vamp::HostExt::PluginConfiguration config) override {
|
c@94
|
120
|
c@94
|
121 if (!m_transport->isOK()) {
|
c@94
|
122 throw std::runtime_error("Piper server failed to start");
|
c@94
|
123 }
|
c@94
|
124
|
c@94
|
125 Vamp::HostExt::ConfigurationRequest request;
|
c@94
|
126 request.plugin = plugin;
|
c@94
|
127 request.configuration = config;
|
c@94
|
128
|
c@94
|
129 capnp::MallocMessageBuilder message;
|
c@94
|
130 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@94
|
131
|
c@94
|
132 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
|
c@94
|
133 ReqId id = getId();
|
c@94
|
134 builder.getId().setNumber(id);
|
c@94
|
135
|
c@94
|
136 auto arr = capnp::messageToFlatArray(message);
|
c@94
|
137 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@94
|
138 arr.asChars().size());
|
c@94
|
139 auto karr = toKJArray(responseBuffer);
|
c@94
|
140 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@94
|
141 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@94
|
142
|
c@94
|
143 //!!! handle (explicit) error case
|
c@94
|
144
|
c@94
|
145 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
|
c@94
|
146
|
c@94
|
147 Vamp::HostExt::ConfigurationResponse cr;
|
c@94
|
148 VampnProto::readConfigurationResponse(cr,
|
c@94
|
149 reader.getResponse().getConfigure(),
|
c@94
|
150 m_mapper);
|
c@94
|
151
|
c@94
|
152 return cr.outputs;
|
c@94
|
153 };
|
c@94
|
154
|
c@94
|
155 virtual
|
c@94
|
156 Vamp::Plugin::FeatureSet
|
c@94
|
157 process(PluginStub *plugin,
|
c@94
|
158 std::vector<std::vector<float> > inputBuffers,
|
c@94
|
159 Vamp::RealTime timestamp) override {
|
c@94
|
160
|
c@94
|
161 if (!m_transport->isOK()) {
|
c@94
|
162 throw std::runtime_error("Piper server failed to start");
|
c@94
|
163 }
|
c@94
|
164
|
c@94
|
165 Vamp::HostExt::ProcessRequest request;
|
c@94
|
166 request.plugin = plugin;
|
c@94
|
167 request.inputBuffers = inputBuffers;
|
c@94
|
168 request.timestamp = timestamp;
|
c@94
|
169
|
c@94
|
170 capnp::MallocMessageBuilder message;
|
c@94
|
171 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@94
|
172
|
c@94
|
173 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
|
c@94
|
174 ReqId id = getId();
|
c@94
|
175 builder.getId().setNumber(id);
|
c@94
|
176
|
c@94
|
177 auto arr = capnp::messageToFlatArray(message);
|
c@94
|
178 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@94
|
179 arr.asChars().size());
|
c@94
|
180 auto karr = toKJArray(responseBuffer);
|
c@94
|
181 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@94
|
182 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@94
|
183
|
c@94
|
184 //!!! handle (explicit) error case
|
c@94
|
185
|
c@94
|
186 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
|
c@94
|
187
|
c@94
|
188 Vamp::HostExt::ProcessResponse pr;
|
c@94
|
189 VampnProto::readProcessResponse(pr,
|
c@94
|
190 reader.getResponse().getProcess(),
|
c@94
|
191 m_mapper);
|
c@94
|
192
|
c@94
|
193 return pr.features;
|
c@94
|
194 }
|
c@94
|
195
|
c@94
|
196 virtual Vamp::Plugin::FeatureSet
|
c@94
|
197 finish(PluginStub *plugin) override {
|
c@94
|
198
|
c@94
|
199 if (!m_transport->isOK()) {
|
c@94
|
200 throw std::runtime_error("Piper server failed to start");
|
c@94
|
201 }
|
c@94
|
202
|
c@94
|
203 Vamp::HostExt::FinishRequest request;
|
c@94
|
204 request.plugin = plugin;
|
c@94
|
205
|
c@94
|
206 capnp::MallocMessageBuilder message;
|
c@94
|
207 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@94
|
208
|
c@94
|
209 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
|
c@94
|
210 ReqId id = getId();
|
c@94
|
211 builder.getId().setNumber(id);
|
c@94
|
212
|
c@94
|
213 auto arr = capnp::messageToFlatArray(message);
|
c@94
|
214 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@94
|
215 arr.asChars().size());
|
c@94
|
216 auto karr = toKJArray(responseBuffer);
|
c@94
|
217 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@94
|
218 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@94
|
219
|
c@94
|
220 //!!! handle (explicit) error case
|
c@94
|
221
|
c@94
|
222 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
|
c@94
|
223
|
c@94
|
224 Vamp::HostExt::ProcessResponse pr;
|
c@94
|
225 VampnProto::readFinishResponse(pr,
|
c@94
|
226 reader.getResponse().getFinish(),
|
c@94
|
227 m_mapper);
|
c@94
|
228
|
c@94
|
229 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
|
c@94
|
230
|
c@94
|
231 // Don't delete the plugin. It's the plugin that is supposed
|
c@94
|
232 // to be calling us here
|
c@94
|
233
|
c@94
|
234 return pr.features;
|
c@94
|
235 }
|
c@94
|
236
|
c@94
|
237 virtual void
|
c@94
|
238 reset(PluginStub *plugin,
|
c@94
|
239 Vamp::HostExt::PluginConfiguration config) override {
|
c@94
|
240
|
c@94
|
241 // Reload the plugin on the server side, and configure it as requested
|
c@94
|
242
|
c@94
|
243 if (!m_transport->isOK()) {
|
c@94
|
244 throw std::runtime_error("Piper server failed to start");
|
c@94
|
245 }
|
c@94
|
246
|
c@94
|
247 if (m_mapper.havePlugin(plugin)) {
|
c@94
|
248 (void)finish(plugin); // server-side unload
|
c@94
|
249 }
|
c@94
|
250
|
c@94
|
251 Vamp::HostExt::PluginStaticData psd;
|
c@94
|
252 Vamp::HostExt::PluginConfiguration defaultConfig;
|
c@94
|
253 PluginHandleMapper::Handle handle =
|
c@94
|
254 serverLoad(plugin->getPluginKey(),
|
c@94
|
255 plugin->getInputSampleRate(),
|
c@94
|
256 plugin->getAdapterFlags(),
|
c@94
|
257 psd, defaultConfig);
|
c@94
|
258
|
c@94
|
259 m_mapper.addPlugin(handle, plugin);
|
c@94
|
260
|
c@94
|
261 (void)configure(plugin, config);
|
c@94
|
262 }
|
c@94
|
263
|
c@94
|
264 private:
|
c@94
|
265 AssignedPluginHandleMapper m_mapper;
|
c@94
|
266 ReqId getId() {
|
c@94
|
267 //!!! todo: mutex
|
c@94
|
268 static ReqId m_nextId = 0;
|
c@94
|
269 return m_nextId++;
|
c@94
|
270 }
|
c@94
|
271
|
c@94
|
272 static
|
c@94
|
273 kj::Array<capnp::word>
|
c@94
|
274 toKJArray(const std::vector<char> &buffer) {
|
c@94
|
275 // We could do this whole thing with fewer copies, but let's
|
c@94
|
276 // see whether it matters first
|
c@94
|
277 size_t wordSize = sizeof(capnp::word);
|
c@94
|
278 size_t words = buffer.size() / wordSize;
|
c@94
|
279 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
|
c@94
|
280 memcpy(karr.begin(), buffer.data(), words * wordSize);
|
c@94
|
281 return karr;
|
c@94
|
282 }
|
c@94
|
283
|
c@94
|
284 void
|
c@94
|
285 checkResponseType(const RpcResponse::Reader &r,
|
c@94
|
286 RpcResponse::Response::Which type,
|
c@94
|
287 ReqId id) {
|
c@94
|
288
|
c@94
|
289 if (r.getResponse().which() != type) {
|
c@94
|
290 throw std::runtime_error("Wrong response type");
|
c@94
|
291 }
|
c@94
|
292 if (ReqId(r.getId().getNumber()) != id) {
|
c@94
|
293 throw std::runtime_error("Wrong response id");
|
c@94
|
294 }
|
c@94
|
295 }
|
c@95
|
296
|
c@95
|
297 PluginHandleMapper::Handle
|
c@95
|
298 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
|
c@95
|
299 Vamp::HostExt::PluginStaticData &psd,
|
c@95
|
300 Vamp::HostExt::PluginConfiguration &defaultConfig) {
|
c@95
|
301
|
c@95
|
302 Vamp::HostExt::LoadRequest request;
|
c@95
|
303 request.pluginKey = key;
|
c@95
|
304 request.inputSampleRate = inputSampleRate;
|
c@95
|
305 request.adapterFlags = adapterFlags;
|
c@95
|
306
|
c@95
|
307 capnp::MallocMessageBuilder message;
|
c@95
|
308 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@95
|
309
|
c@95
|
310 VampnProto::buildRpcRequest_Load(builder, request);
|
c@95
|
311 ReqId id = getId();
|
c@95
|
312 builder.getId().setNumber(id);
|
c@95
|
313
|
c@95
|
314 auto arr = capnp::messageToFlatArray(message);
|
c@95
|
315
|
c@95
|
316 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@95
|
317 arr.asChars().size());
|
c@95
|
318
|
c@95
|
319 //!!! ... --> will also need some way to kill this process
|
c@95
|
320 //!!! (from another thread)
|
c@95
|
321
|
c@95
|
322 auto karr = toKJArray(responseBuffer);
|
c@95
|
323 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@95
|
324 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@95
|
325
|
c@95
|
326 //!!! handle (explicit) error case
|
c@95
|
327
|
c@95
|
328 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
|
c@95
|
329
|
c@95
|
330 const LoadResponse::Reader &lr = reader.getResponse().getLoad();
|
c@95
|
331 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
|
c@95
|
332 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
|
c@95
|
333 return lr.getHandle();
|
c@95
|
334 };
|
c@94
|
335
|
c@94
|
336 private:
|
c@94
|
337 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
|
c@94
|
338 CompletenessChecker *m_completenessChecker; // I own this
|
c@94
|
339 };
|
c@94
|
340
|
c@94
|
341 }
|
c@94
|
342 }
|
c@94
|
343
|
c@94
|
344 #endif
|