comparison vamp-client/CapnpClient.h @ 94:a660dca988f8

More renaming
author Chris Cannam <c.cannam@qmul.ac.uk>
date Thu, 13 Oct 2016 14:10:55 +0100
parents
children b6ac26b72b59
comparison
equal deleted inserted replaced
93:fbce91785d35 94:a660dca988f8
1
2 #ifndef PIPER_CAPNP_CLIENT_H
3 #define PIPER_CAPNP_CLIENT_H
4
5 #include "Loader.h"
6 #include "PluginClient.h"
7 #include "PluginStub.h"
8 #include "SynchronousTransport.h"
9
10 #include "vamp-support/AssignedPluginHandleMapper.h"
11 #include "vamp-capnp/VampnProto.h"
12
13 #include <capnp/serialize.h>
14
15 namespace piper {
16 namespace vampclient {
17
18 class CapnpClient : public PluginClient,
19 public Loader
20 {
21 // unsigned to avoid undefined behaviour on possible wrap
22 typedef uint32_t ReqId;
23
24 class CompletenessChecker : public MessageCompletenessChecker {
25 public:
26 bool isComplete(const std::vector<char> &message) const override {
27 auto karr = toKJArray(message);
28 size_t words = karr.size();
29 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr);
30 if (words > expected) {
31 std::cerr << "WARNING: obtained more data than expected ("
32 << words << " " << sizeof(capnp::word)
33 << "-byte words, expected "
34 << expected << ")" << std::endl;
35 }
36 return words >= expected;
37 }
38 };
39
40 public:
41 CapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
42 m_transport(transport),
43 m_completenessChecker(new CompletenessChecker) {
44 transport->setCompletenessChecker(m_completenessChecker);
45 }
46
47 ~CapnpClient() {
48 delete m_completenessChecker;
49 }
50
51 //!!! obviously, factor out all repetitive guff
52
53 //!!! list and load are supposed to be called by application code,
54 //!!! but the rest are only supposed to be called by the plugin --
55 //!!! sort out the api here
56
57 Vamp::Plugin *
58 load(std::string key, float inputSampleRate, int adapterFlags) {
59
60 if (!m_transport->isOK()) {
61 throw std::runtime_error("Piper server failed to start");
62 }
63
64 Vamp::HostExt::PluginStaticData psd;
65 Vamp::HostExt::PluginConfiguration defaultConfig;
66 PluginHandleMapper::Handle handle =
67 serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig);
68
69 Vamp::Plugin *plugin = new PluginStub(this,
70 key,
71 inputSampleRate,
72 adapterFlags,
73 psd,
74 defaultConfig);
75
76 m_mapper.addPlugin(handle, plugin);
77
78 return plugin;
79 }
80
81 PluginHandleMapper::Handle
82 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
83 Vamp::HostExt::PluginStaticData &psd,
84 Vamp::HostExt::PluginConfiguration &defaultConfig) {
85
86 Vamp::HostExt::LoadRequest request;
87 request.pluginKey = key;
88 request.inputSampleRate = inputSampleRate;
89 request.adapterFlags = adapterFlags;
90
91 capnp::MallocMessageBuilder message;
92 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
93
94 VampnProto::buildRpcRequest_Load(builder, request);
95 ReqId id = getId();
96 builder.getId().setNumber(id);
97
98 auto arr = capnp::messageToFlatArray(message);
99
100 auto responseBuffer = m_transport->call(arr.asChars().begin(),
101 arr.asChars().size());
102
103 //!!! ... --> will also need some way to kill this process
104 //!!! (from another thread)
105
106 auto karr = toKJArray(responseBuffer);
107 capnp::FlatArrayMessageReader responseMessage(karr);
108 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
109
110 //!!! handle (explicit) error case
111
112 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
113
114 const LoadResponse::Reader &lr = reader.getResponse().getLoad();
115 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
116 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
117 return lr.getHandle();
118 };
119
120 protected:
121 virtual
122 Vamp::Plugin::OutputList
123 configure(PluginStub *plugin,
124 Vamp::HostExt::PluginConfiguration config) override {
125
126 if (!m_transport->isOK()) {
127 throw std::runtime_error("Piper server failed to start");
128 }
129
130 Vamp::HostExt::ConfigurationRequest request;
131 request.plugin = plugin;
132 request.configuration = config;
133
134 capnp::MallocMessageBuilder message;
135 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
136
137 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
138 ReqId id = getId();
139 builder.getId().setNumber(id);
140
141 auto arr = capnp::messageToFlatArray(message);
142 auto responseBuffer = m_transport->call(arr.asChars().begin(),
143 arr.asChars().size());
144 auto karr = toKJArray(responseBuffer);
145 capnp::FlatArrayMessageReader responseMessage(karr);
146 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
147
148 //!!! handle (explicit) error case
149
150 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
151
152 Vamp::HostExt::ConfigurationResponse cr;
153 VampnProto::readConfigurationResponse(cr,
154 reader.getResponse().getConfigure(),
155 m_mapper);
156
157 return cr.outputs;
158 };
159
160 virtual
161 Vamp::Plugin::FeatureSet
162 process(PluginStub *plugin,
163 std::vector<std::vector<float> > inputBuffers,
164 Vamp::RealTime timestamp) override {
165
166 if (!m_transport->isOK()) {
167 throw std::runtime_error("Piper server failed to start");
168 }
169
170 Vamp::HostExt::ProcessRequest request;
171 request.plugin = plugin;
172 request.inputBuffers = inputBuffers;
173 request.timestamp = timestamp;
174
175 capnp::MallocMessageBuilder message;
176 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
177
178 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
179 ReqId id = getId();
180 builder.getId().setNumber(id);
181
182 auto arr = capnp::messageToFlatArray(message);
183 auto responseBuffer = m_transport->call(arr.asChars().begin(),
184 arr.asChars().size());
185 auto karr = toKJArray(responseBuffer);
186 capnp::FlatArrayMessageReader responseMessage(karr);
187 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
188
189 //!!! handle (explicit) error case
190
191 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
192
193 Vamp::HostExt::ProcessResponse pr;
194 VampnProto::readProcessResponse(pr,
195 reader.getResponse().getProcess(),
196 m_mapper);
197
198 return pr.features;
199 }
200
201 virtual Vamp::Plugin::FeatureSet
202 finish(PluginStub *plugin) override {
203
204 if (!m_transport->isOK()) {
205 throw std::runtime_error("Piper server failed to start");
206 }
207
208 Vamp::HostExt::FinishRequest request;
209 request.plugin = plugin;
210
211 capnp::MallocMessageBuilder message;
212 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
213
214 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
215 ReqId id = getId();
216 builder.getId().setNumber(id);
217
218 auto arr = capnp::messageToFlatArray(message);
219 auto responseBuffer = m_transport->call(arr.asChars().begin(),
220 arr.asChars().size());
221 auto karr = toKJArray(responseBuffer);
222 capnp::FlatArrayMessageReader responseMessage(karr);
223 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
224
225 //!!! handle (explicit) error case
226
227 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
228
229 Vamp::HostExt::ProcessResponse pr;
230 VampnProto::readFinishResponse(pr,
231 reader.getResponse().getFinish(),
232 m_mapper);
233
234 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
235
236 // Don't delete the plugin. It's the plugin that is supposed
237 // to be calling us here
238
239 return pr.features;
240 }
241
242 virtual void
243 reset(PluginStub *plugin,
244 Vamp::HostExt::PluginConfiguration config) override {
245
246 // Reload the plugin on the server side, and configure it as requested
247
248 if (!m_transport->isOK()) {
249 throw std::runtime_error("Piper server failed to start");
250 }
251
252 if (m_mapper.havePlugin(plugin)) {
253 (void)finish(plugin); // server-side unload
254 }
255
256 Vamp::HostExt::PluginStaticData psd;
257 Vamp::HostExt::PluginConfiguration defaultConfig;
258 PluginHandleMapper::Handle handle =
259 serverLoad(plugin->getPluginKey(),
260 plugin->getInputSampleRate(),
261 plugin->getAdapterFlags(),
262 psd, defaultConfig);
263
264 m_mapper.addPlugin(handle, plugin);
265
266 (void)configure(plugin, config);
267 }
268
269 private:
270 AssignedPluginHandleMapper m_mapper;
271 ReqId getId() {
272 //!!! todo: mutex
273 static ReqId m_nextId = 0;
274 return m_nextId++;
275 }
276
277 static
278 kj::Array<capnp::word>
279 toKJArray(const std::vector<char> &buffer) {
280 // We could do this whole thing with fewer copies, but let's
281 // see whether it matters first
282 size_t wordSize = sizeof(capnp::word);
283 size_t words = buffer.size() / wordSize;
284 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
285 memcpy(karr.begin(), buffer.data(), words * wordSize);
286 return karr;
287 }
288
289 void
290 checkResponseType(const RpcResponse::Reader &r,
291 RpcResponse::Response::Which type,
292 ReqId id) {
293
294 if (r.getResponse().which() != type) {
295 throw std::runtime_error("Wrong response type");
296 }
297 if (ReqId(r.getId().getNumber()) != id) {
298 throw std::runtime_error("Wrong response id");
299 }
300 }
301
302 private:
303 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
304 CompletenessChecker *m_completenessChecker; // I own this
305 };
306
307 }
308 }
309
310 #endif