c@90
|
1
|
c@90
|
2 #ifndef PIPER_CAPNP_CLIENT_H
|
c@90
|
3 #define PIPER_CAPNP_CLIENT_H
|
c@90
|
4
|
c@90
|
#include "PiperClient.h"
#include "PiperPluginStub.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <capnp/serialize.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
c@92
|
13
|
c@90
|
14 namespace piper { //!!! change
|
c@90
|
15
|
c@92
|
16 class PiperCapnpClient : public PiperPluginClientInterface,
|
c@92
|
17 public PiperLoaderInterface
|
c@90
|
18 {
|
c@90
|
19 // unsigned to avoid undefined behaviour on possible wrap
|
c@90
|
20 typedef uint32_t ReqId;
|
c@92
|
21
|
c@92
|
22 class CompletenessChecker : public MessageCompletenessChecker {
|
c@92
|
23 public:
|
c@92
|
24 bool isComplete(const std::vector<char> &message) const override {
|
c@92
|
25 auto karr = toKJArray(message);
|
c@92
|
26 size_t words = karr.size();
|
c@92
|
27 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr);
|
c@92
|
28 if (words > expected) {
|
c@92
|
29 std::cerr << "WARNING: obtained more data than expected ("
|
c@92
|
30 << words << " " << sizeof(capnp::word)
|
c@92
|
31 << "-byte words, expected "
|
c@92
|
32 << expected << ")" << std::endl;
|
c@92
|
33 }
|
c@92
|
34 return words >= expected;
|
c@92
|
35 }
|
c@92
|
36 };
|
c@90
|
37
|
c@90
|
38 public:
|
c@90
|
39 PiperCapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
|
c@92
|
40 m_transport(transport),
|
c@92
|
41 m_completenessChecker(new CompletenessChecker) {
|
c@92
|
42 transport->setCompletenessChecker(m_completenessChecker);
|
c@90
|
43 }
|
c@90
|
44
|
c@90
|
45 ~PiperCapnpClient() {
|
c@92
|
46 delete m_completenessChecker;
|
c@90
|
47 }
|
c@90
|
48
|
c@90
|
49 //!!! obviously, factor out all repetitive guff
|
c@90
|
50
|
c@90
|
51 //!!! list and load are supposed to be called by application code,
|
c@90
|
52 //!!! but the rest are only supposed to be called by the plugin --
|
c@90
|
53 //!!! sort out the api here
|
c@90
|
54
|
c@90
|
55 Vamp::Plugin *
|
c@90
|
56 load(std::string key, float inputSampleRate, int adapterFlags) {
|
c@90
|
57
|
c@90
|
58 if (!m_transport->isOK()) {
|
c@90
|
59 throw std::runtime_error("Piper server failed to start");
|
c@90
|
60 }
|
c@90
|
61
|
c@91
|
62 Vamp::HostExt::PluginStaticData psd;
|
c@91
|
63 Vamp::HostExt::PluginConfiguration defaultConfig;
|
c@91
|
64 PluginHandleMapper::Handle handle =
|
c@91
|
65 serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig);
|
c@91
|
66
|
c@92
|
67 Vamp::Plugin *plugin = new PiperPluginStub(this,
|
c@91
|
68 key,
|
c@91
|
69 inputSampleRate,
|
c@91
|
70 adapterFlags,
|
c@91
|
71 psd,
|
c@91
|
72 defaultConfig);
|
c@91
|
73
|
c@91
|
74 m_mapper.addPlugin(handle, plugin);
|
c@91
|
75
|
c@91
|
76 return plugin;
|
c@91
|
77 }
|
c@91
|
78
|
c@91
|
79 PluginHandleMapper::Handle
|
c@91
|
80 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
|
c@91
|
81 Vamp::HostExt::PluginStaticData &psd,
|
c@91
|
82 Vamp::HostExt::PluginConfiguration &defaultConfig) {
|
c@91
|
83
|
c@90
|
84 Vamp::HostExt::LoadRequest request;
|
c@90
|
85 request.pluginKey = key;
|
c@90
|
86 request.inputSampleRate = inputSampleRate;
|
c@90
|
87 request.adapterFlags = adapterFlags;
|
c@90
|
88
|
c@90
|
89 capnp::MallocMessageBuilder message;
|
c@90
|
90 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
91
|
c@90
|
92 VampnProto::buildRpcRequest_Load(builder, request);
|
c@90
|
93 ReqId id = getId();
|
c@90
|
94 builder.getId().setNumber(id);
|
c@90
|
95
|
c@91
|
96 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
97
|
c@90
|
98 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
99 arr.asChars().size());
|
c@90
|
100
|
c@90
|
101 //!!! ... --> will also need some way to kill this process
|
c@90
|
102 //!!! (from another thread)
|
c@90
|
103
|
c@90
|
104 auto karr = toKJArray(responseBuffer);
|
c@90
|
105 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
106 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
107
|
c@90
|
108 //!!! handle (explicit) error case
|
c@90
|
109
|
c@90
|
110 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
|
c@90
|
111
|
c@90
|
112 const LoadResponse::Reader &lr = reader.getResponse().getLoad();
|
c@90
|
113 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
|
c@90
|
114 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
|
c@91
|
115 return lr.getHandle();
|
c@90
|
116 };
|
c@90
|
117
|
c@90
|
118 protected:
|
c@90
|
119 virtual
|
c@90
|
120 Vamp::Plugin::OutputList
|
c@92
|
121 configure(PiperPluginStub *plugin,
|
c@90
|
122 Vamp::HostExt::PluginConfiguration config) override {
|
c@90
|
123
|
c@90
|
124 if (!m_transport->isOK()) {
|
c@90
|
125 throw std::runtime_error("Piper server failed to start");
|
c@90
|
126 }
|
c@90
|
127
|
c@90
|
128 Vamp::HostExt::ConfigurationRequest request;
|
c@90
|
129 request.plugin = plugin;
|
c@90
|
130 request.configuration = config;
|
c@90
|
131
|
c@90
|
132 capnp::MallocMessageBuilder message;
|
c@90
|
133 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
134
|
c@90
|
135 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
|
c@90
|
136 ReqId id = getId();
|
c@90
|
137 builder.getId().setNumber(id);
|
c@90
|
138
|
c@91
|
139 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
140 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
141 arr.asChars().size());
|
c@90
|
142 auto karr = toKJArray(responseBuffer);
|
c@90
|
143 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
144 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
145
|
c@90
|
146 //!!! handle (explicit) error case
|
c@90
|
147
|
c@90
|
148 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
|
c@90
|
149
|
c@90
|
150 Vamp::HostExt::ConfigurationResponse cr;
|
c@90
|
151 VampnProto::readConfigurationResponse(cr,
|
c@90
|
152 reader.getResponse().getConfigure(),
|
c@90
|
153 m_mapper);
|
c@90
|
154
|
c@90
|
155 return cr.outputs;
|
c@90
|
156 };
|
c@90
|
157
|
c@90
|
158 virtual
|
c@90
|
159 Vamp::Plugin::FeatureSet
|
c@92
|
160 process(PiperPluginStub *plugin,
|
c@90
|
161 std::vector<std::vector<float> > inputBuffers,
|
c@90
|
162 Vamp::RealTime timestamp) override {
|
c@90
|
163
|
c@90
|
164 if (!m_transport->isOK()) {
|
c@90
|
165 throw std::runtime_error("Piper server failed to start");
|
c@90
|
166 }
|
c@90
|
167
|
c@90
|
168 Vamp::HostExt::ProcessRequest request;
|
c@90
|
169 request.plugin = plugin;
|
c@90
|
170 request.inputBuffers = inputBuffers;
|
c@90
|
171 request.timestamp = timestamp;
|
c@90
|
172
|
c@90
|
173 capnp::MallocMessageBuilder message;
|
c@90
|
174 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
175
|
c@90
|
176 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
|
c@90
|
177 ReqId id = getId();
|
c@90
|
178 builder.getId().setNumber(id);
|
c@90
|
179
|
c@91
|
180 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
181 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
182 arr.asChars().size());
|
c@90
|
183 auto karr = toKJArray(responseBuffer);
|
c@90
|
184 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
185 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
186
|
c@90
|
187 //!!! handle (explicit) error case
|
c@90
|
188
|
c@90
|
189 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
|
c@90
|
190
|
c@90
|
191 Vamp::HostExt::ProcessResponse pr;
|
c@90
|
192 VampnProto::readProcessResponse(pr,
|
c@90
|
193 reader.getResponse().getProcess(),
|
c@90
|
194 m_mapper);
|
c@90
|
195
|
c@90
|
196 return pr.features;
|
c@90
|
197 }
|
c@90
|
198
|
c@90
|
199 virtual Vamp::Plugin::FeatureSet
|
c@92
|
200 finish(PiperPluginStub *plugin) override {
|
c@90
|
201
|
c@90
|
202 if (!m_transport->isOK()) {
|
c@90
|
203 throw std::runtime_error("Piper server failed to start");
|
c@90
|
204 }
|
c@90
|
205
|
c@90
|
206 Vamp::HostExt::FinishRequest request;
|
c@90
|
207 request.plugin = plugin;
|
c@90
|
208
|
c@90
|
209 capnp::MallocMessageBuilder message;
|
c@90
|
210 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
211
|
c@90
|
212 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
|
c@90
|
213 ReqId id = getId();
|
c@90
|
214 builder.getId().setNumber(id);
|
c@90
|
215
|
c@91
|
216 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
217 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
218 arr.asChars().size());
|
c@90
|
219 auto karr = toKJArray(responseBuffer);
|
c@90
|
220 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
221 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
222
|
c@90
|
223 //!!! handle (explicit) error case
|
c@90
|
224
|
c@90
|
225 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
|
c@90
|
226
|
c@90
|
227 Vamp::HostExt::ProcessResponse pr;
|
c@90
|
228 VampnProto::readFinishResponse(pr,
|
c@90
|
229 reader.getResponse().getFinish(),
|
c@90
|
230 m_mapper);
|
c@90
|
231
|
c@90
|
232 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
|
c@90
|
233
|
c@90
|
234 // Don't delete the plugin. It's the plugin that is supposed
|
c@90
|
235 // to be calling us here
|
c@90
|
236
|
c@90
|
237 return pr.features;
|
c@90
|
238 }
|
c@90
|
239
|
c@91
|
240 virtual void
|
c@92
|
241 reset(PiperPluginStub *plugin,
|
c@91
|
242 Vamp::HostExt::PluginConfiguration config) override {
|
c@91
|
243
|
c@91
|
244 // Reload the plugin on the server side, and configure it as requested
|
c@91
|
245
|
c@91
|
246 if (!m_transport->isOK()) {
|
c@91
|
247 throw std::runtime_error("Piper server failed to start");
|
c@91
|
248 }
|
c@91
|
249
|
c@91
|
250 if (m_mapper.havePlugin(plugin)) {
|
c@91
|
251 (void)finish(plugin); // server-side unload
|
c@91
|
252 }
|
c@91
|
253
|
c@91
|
254 Vamp::HostExt::PluginStaticData psd;
|
c@91
|
255 Vamp::HostExt::PluginConfiguration defaultConfig;
|
c@91
|
256 PluginHandleMapper::Handle handle =
|
c@91
|
257 serverLoad(plugin->getPluginKey(),
|
c@91
|
258 plugin->getInputSampleRate(),
|
c@91
|
259 plugin->getAdapterFlags(),
|
c@91
|
260 psd, defaultConfig);
|
c@91
|
261
|
c@91
|
262 m_mapper.addPlugin(handle, plugin);
|
c@91
|
263
|
c@91
|
264 (void)configure(plugin, config);
|
c@91
|
265 }
|
c@91
|
266
|
c@90
|
267 private:
|
c@90
|
268 AssignedPluginHandleMapper m_mapper;
|
c@90
|
269 ReqId getId() {
|
c@90
|
270 //!!! todo: mutex
|
c@90
|
271 static ReqId m_nextId = 0;
|
c@90
|
272 return m_nextId++;
|
c@90
|
273 }
|
c@90
|
274
|
c@92
|
275 static
|
c@90
|
276 kj::Array<capnp::word>
|
c@90
|
277 toKJArray(const std::vector<char> &buffer) {
|
c@90
|
278 // We could do this whole thing with fewer copies, but let's
|
c@90
|
279 // see whether it matters first
|
c@90
|
280 size_t wordSize = sizeof(capnp::word);
|
c@90
|
281 size_t words = buffer.size() / wordSize;
|
c@90
|
282 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
|
c@90
|
283 memcpy(karr.begin(), buffer.data(), words * wordSize);
|
c@90
|
284 return karr;
|
c@90
|
285 }
|
c@90
|
286
|
c@90
|
287 void
|
c@90
|
288 checkResponseType(const RpcResponse::Reader &r,
|
c@90
|
289 RpcResponse::Response::Which type,
|
c@90
|
290 ReqId id) {
|
c@90
|
291
|
c@90
|
292 if (r.getResponse().which() != type) {
|
c@90
|
293 throw std::runtime_error("Wrong response type");
|
c@90
|
294 }
|
c@90
|
295 if (ReqId(r.getId().getNumber()) != id) {
|
c@90
|
296 throw std::runtime_error("Wrong response id");
|
c@90
|
297 }
|
c@90
|
298 }
|
c@90
|
299
|
c@90
|
300 private:
|
c@90
|
301 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
|
c@92
|
302 CompletenessChecker *m_completenessChecker; // I own this
|
c@90
|
303 };
|
c@90
|
304
|
c@90
|
305 }
|
c@90
|
306
|
c@90
|
307 #endif
|