c@90
|
1
|
c@90
|
2 #ifndef PIPER_CAPNP_CLIENT_H
|
c@90
|
3 #define PIPER_CAPNP_CLIENT_H
|
c@90
|
4
|
c@90
|
#include "PiperClient.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
c@90
|
10
|
c@90
|
11 namespace piper { //!!! change
|
c@90
|
12
|
c@90
|
13 class PiperCapnpClient : public PiperStubPluginClientInterface
|
c@90
|
14 {
|
c@90
|
15 // unsigned to avoid undefined behaviour on possible wrap
|
c@90
|
16 typedef uint32_t ReqId;
|
c@90
|
17
|
c@90
|
18 public:
|
c@90
|
19 PiperCapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
|
c@90
|
20 m_transport(transport) {
|
c@90
|
21 }
|
c@90
|
22
|
c@90
|
23 ~PiperCapnpClient() {
|
c@90
|
24 }
|
c@90
|
25
|
c@90
|
26 //!!! obviously, factor out all repetitive guff
|
c@90
|
27
|
c@90
|
28 //!!! list and load are supposed to be called by application code,
|
c@90
|
29 //!!! but the rest are only supposed to be called by the plugin --
|
c@90
|
30 //!!! sort out the api here
|
c@90
|
31
|
c@90
|
32 Vamp::Plugin *
|
c@90
|
33 load(std::string key, float inputSampleRate, int adapterFlags) {
|
c@90
|
34
|
c@90
|
35 if (!m_transport->isOK()) {
|
c@90
|
36 throw std::runtime_error("Piper server failed to start");
|
c@90
|
37 }
|
c@90
|
38
|
c@91
|
39 Vamp::HostExt::PluginStaticData psd;
|
c@91
|
40 Vamp::HostExt::PluginConfiguration defaultConfig;
|
c@91
|
41 PluginHandleMapper::Handle handle =
|
c@91
|
42 serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig);
|
c@91
|
43
|
c@91
|
44 Vamp::Plugin *plugin = new PiperStubPlugin(this,
|
c@91
|
45 key,
|
c@91
|
46 inputSampleRate,
|
c@91
|
47 adapterFlags,
|
c@91
|
48 psd,
|
c@91
|
49 defaultConfig);
|
c@91
|
50
|
c@91
|
51 m_mapper.addPlugin(handle, plugin);
|
c@91
|
52
|
c@91
|
53 return plugin;
|
c@91
|
54 }
|
c@91
|
55
|
c@91
|
56 PluginHandleMapper::Handle
|
c@91
|
57 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
|
c@91
|
58 Vamp::HostExt::PluginStaticData &psd,
|
c@91
|
59 Vamp::HostExt::PluginConfiguration &defaultConfig) {
|
c@91
|
60
|
c@90
|
61 Vamp::HostExt::LoadRequest request;
|
c@90
|
62 request.pluginKey = key;
|
c@90
|
63 request.inputSampleRate = inputSampleRate;
|
c@90
|
64 request.adapterFlags = adapterFlags;
|
c@90
|
65
|
c@90
|
66 capnp::MallocMessageBuilder message;
|
c@90
|
67 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
68
|
c@90
|
69 VampnProto::buildRpcRequest_Load(builder, request);
|
c@90
|
70 ReqId id = getId();
|
c@90
|
71 builder.getId().setNumber(id);
|
c@90
|
72
|
c@91
|
73 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
74
|
c@90
|
75 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
76 arr.asChars().size());
|
c@90
|
77
|
c@90
|
78 //!!! ... --> will also need some way to kill this process
|
c@90
|
79 //!!! (from another thread)
|
c@90
|
80
|
c@90
|
81 auto karr = toKJArray(responseBuffer);
|
c@90
|
82 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
83 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
84
|
c@90
|
85 //!!! handle (explicit) error case
|
c@90
|
86
|
c@90
|
87 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
|
c@90
|
88
|
c@90
|
89 const LoadResponse::Reader &lr = reader.getResponse().getLoad();
|
c@90
|
90 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
|
c@90
|
91 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
|
c@91
|
92 return lr.getHandle();
|
c@90
|
93 };
|
c@90
|
94
|
c@90
|
95 protected:
|
c@90
|
96 virtual
|
c@90
|
97 Vamp::Plugin::OutputList
|
c@90
|
98 configure(PiperStubPlugin *plugin,
|
c@90
|
99 Vamp::HostExt::PluginConfiguration config) override {
|
c@90
|
100
|
c@90
|
101 if (!m_transport->isOK()) {
|
c@90
|
102 throw std::runtime_error("Piper server failed to start");
|
c@90
|
103 }
|
c@90
|
104
|
c@90
|
105 Vamp::HostExt::ConfigurationRequest request;
|
c@90
|
106 request.plugin = plugin;
|
c@90
|
107 request.configuration = config;
|
c@90
|
108
|
c@90
|
109 capnp::MallocMessageBuilder message;
|
c@90
|
110 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
111
|
c@90
|
112 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
|
c@90
|
113 ReqId id = getId();
|
c@90
|
114 builder.getId().setNumber(id);
|
c@90
|
115
|
c@91
|
116 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
117 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
118 arr.asChars().size());
|
c@90
|
119 auto karr = toKJArray(responseBuffer);
|
c@90
|
120 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
121 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
122
|
c@90
|
123 //!!! handle (explicit) error case
|
c@90
|
124
|
c@90
|
125 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
|
c@90
|
126
|
c@90
|
127 Vamp::HostExt::ConfigurationResponse cr;
|
c@90
|
128 VampnProto::readConfigurationResponse(cr,
|
c@90
|
129 reader.getResponse().getConfigure(),
|
c@90
|
130 m_mapper);
|
c@90
|
131
|
c@90
|
132 return cr.outputs;
|
c@90
|
133 };
|
c@90
|
134
|
c@90
|
135 virtual
|
c@90
|
136 Vamp::Plugin::FeatureSet
|
c@90
|
137 process(PiperStubPlugin *plugin,
|
c@90
|
138 std::vector<std::vector<float> > inputBuffers,
|
c@90
|
139 Vamp::RealTime timestamp) override {
|
c@90
|
140
|
c@90
|
141 if (!m_transport->isOK()) {
|
c@90
|
142 throw std::runtime_error("Piper server failed to start");
|
c@90
|
143 }
|
c@90
|
144
|
c@90
|
145 Vamp::HostExt::ProcessRequest request;
|
c@90
|
146 request.plugin = plugin;
|
c@90
|
147 request.inputBuffers = inputBuffers;
|
c@90
|
148 request.timestamp = timestamp;
|
c@90
|
149
|
c@90
|
150 capnp::MallocMessageBuilder message;
|
c@90
|
151 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
152
|
c@90
|
153 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
|
c@90
|
154 ReqId id = getId();
|
c@90
|
155 builder.getId().setNumber(id);
|
c@90
|
156
|
c@91
|
157 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
158 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
159 arr.asChars().size());
|
c@90
|
160 auto karr = toKJArray(responseBuffer);
|
c@90
|
161 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
162 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
163
|
c@90
|
164 //!!! handle (explicit) error case
|
c@90
|
165
|
c@90
|
166 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
|
c@90
|
167
|
c@90
|
168 Vamp::HostExt::ProcessResponse pr;
|
c@90
|
169 VampnProto::readProcessResponse(pr,
|
c@90
|
170 reader.getResponse().getProcess(),
|
c@90
|
171 m_mapper);
|
c@90
|
172
|
c@90
|
173 return pr.features;
|
c@90
|
174 }
|
c@90
|
175
|
c@90
|
176 virtual Vamp::Plugin::FeatureSet
|
c@90
|
177 finish(PiperStubPlugin *plugin) override {
|
c@90
|
178
|
c@90
|
179 if (!m_transport->isOK()) {
|
c@90
|
180 throw std::runtime_error("Piper server failed to start");
|
c@90
|
181 }
|
c@90
|
182
|
c@90
|
183 Vamp::HostExt::FinishRequest request;
|
c@90
|
184 request.plugin = plugin;
|
c@90
|
185
|
c@90
|
186 capnp::MallocMessageBuilder message;
|
c@90
|
187 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
|
c@90
|
188
|
c@90
|
189 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
|
c@90
|
190 ReqId id = getId();
|
c@90
|
191 builder.getId().setNumber(id);
|
c@90
|
192
|
c@91
|
193 auto arr = capnp::messageToFlatArray(message);
|
c@90
|
194 auto responseBuffer = m_transport->call(arr.asChars().begin(),
|
c@90
|
195 arr.asChars().size());
|
c@90
|
196 auto karr = toKJArray(responseBuffer);
|
c@90
|
197 capnp::FlatArrayMessageReader responseMessage(karr);
|
c@90
|
198 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
|
c@90
|
199
|
c@90
|
200 //!!! handle (explicit) error case
|
c@90
|
201
|
c@90
|
202 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
|
c@90
|
203
|
c@90
|
204 Vamp::HostExt::ProcessResponse pr;
|
c@90
|
205 VampnProto::readFinishResponse(pr,
|
c@90
|
206 reader.getResponse().getFinish(),
|
c@90
|
207 m_mapper);
|
c@90
|
208
|
c@90
|
209 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
|
c@90
|
210
|
c@90
|
211 // Don't delete the plugin. It's the plugin that is supposed
|
c@90
|
212 // to be calling us here
|
c@90
|
213
|
c@90
|
214 return pr.features;
|
c@90
|
215 }
|
c@90
|
216
|
c@91
|
217 virtual void
|
c@91
|
218 reset(PiperStubPlugin *plugin,
|
c@91
|
219 Vamp::HostExt::PluginConfiguration config) override {
|
c@91
|
220
|
c@91
|
221 // Reload the plugin on the server side, and configure it as requested
|
c@91
|
222
|
c@91
|
223 if (!m_transport->isOK()) {
|
c@91
|
224 throw std::runtime_error("Piper server failed to start");
|
c@91
|
225 }
|
c@91
|
226
|
c@91
|
227 if (m_mapper.havePlugin(plugin)) {
|
c@91
|
228 (void)finish(plugin); // server-side unload
|
c@91
|
229 }
|
c@91
|
230
|
c@91
|
231 Vamp::HostExt::PluginStaticData psd;
|
c@91
|
232 Vamp::HostExt::PluginConfiguration defaultConfig;
|
c@91
|
233 PluginHandleMapper::Handle handle =
|
c@91
|
234 serverLoad(plugin->getPluginKey(),
|
c@91
|
235 plugin->getInputSampleRate(),
|
c@91
|
236 plugin->getAdapterFlags(),
|
c@91
|
237 psd, defaultConfig);
|
c@91
|
238
|
c@91
|
239 m_mapper.addPlugin(handle, plugin);
|
c@91
|
240
|
c@91
|
241 (void)configure(plugin, config);
|
c@91
|
242 }
|
c@91
|
243
|
c@90
|
244 private:
|
c@90
|
245 AssignedPluginHandleMapper m_mapper;
|
c@90
|
246 ReqId getId() {
|
c@90
|
247 //!!! todo: mutex
|
c@90
|
248 static ReqId m_nextId = 0;
|
c@90
|
249 return m_nextId++;
|
c@90
|
250 }
|
c@90
|
251
|
c@90
|
252 kj::Array<capnp::word>
|
c@90
|
253 toKJArray(const std::vector<char> &buffer) {
|
c@90
|
254 // We could do this whole thing with fewer copies, but let's
|
c@90
|
255 // see whether it matters first
|
c@90
|
256 size_t wordSize = sizeof(capnp::word);
|
c@90
|
257 size_t words = buffer.size() / wordSize;
|
c@90
|
258 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
|
c@90
|
259 memcpy(karr.begin(), buffer.data(), words * wordSize);
|
c@90
|
260 return karr;
|
c@90
|
261 }
|
c@90
|
262
|
c@90
|
263 void
|
c@90
|
264 checkResponseType(const RpcResponse::Reader &r,
|
c@90
|
265 RpcResponse::Response::Which type,
|
c@90
|
266 ReqId id) {
|
c@90
|
267
|
c@90
|
268 if (r.getResponse().which() != type) {
|
c@90
|
269 throw std::runtime_error("Wrong response type");
|
c@90
|
270 }
|
c@90
|
271 if (ReqId(r.getId().getNumber()) != id) {
|
c@90
|
272 throw std::runtime_error("Wrong response id");
|
c@90
|
273 }
|
c@90
|
274 }
|
c@90
|
275
|
c@90
|
276 private:
|
c@90
|
277 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
|
c@90
|
278 };
|
c@90
|
279
|
c@90
|
280 }
|
c@90
|
281
|
c@90
|
282 #endif
|