Mercurial > hg > piper-cpp
comparison vamp-client/CapnpClient.h @ 94:a660dca988f8
More renaming
author | Chris Cannam <c.cannam@qmul.ac.uk> |
---|---|
date | Thu, 13 Oct 2016 14:10:55 +0100 |
parents | |
children | b6ac26b72b59 |
comparison
equal
deleted
inserted
replaced
93:fbce91785d35 | 94:a660dca988f8 |
---|---|
1 | |
2 #ifndef PIPER_CAPNP_CLIENT_H | |
3 #define PIPER_CAPNP_CLIENT_H | |
4 | |
5 #include "Loader.h" | |
6 #include "PluginClient.h" | |
7 #include "PluginStub.h" | |
8 #include "SynchronousTransport.h" | |
9 | |
10 #include "vamp-support/AssignedPluginHandleMapper.h" | |
11 #include "vamp-capnp/VampnProto.h" | |
12 | |
13 #include <capnp/serialize.h> | |
14 | |
15 namespace piper { | |
16 namespace vampclient { | |
17 | |
18 class CapnpClient : public PluginClient, | |
19 public Loader | |
20 { | |
21 // unsigned to avoid undefined behaviour on possible wrap | |
22 typedef uint32_t ReqId; | |
23 | |
24 class CompletenessChecker : public MessageCompletenessChecker { | |
25 public: | |
26 bool isComplete(const std::vector<char> &message) const override { | |
27 auto karr = toKJArray(message); | |
28 size_t words = karr.size(); | |
29 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr); | |
30 if (words > expected) { | |
31 std::cerr << "WARNING: obtained more data than expected (" | |
32 << words << " " << sizeof(capnp::word) | |
33 << "-byte words, expected " | |
34 << expected << ")" << std::endl; | |
35 } | |
36 return words >= expected; | |
37 } | |
38 }; | |
39 | |
40 public: | |
41 CapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr? | |
42 m_transport(transport), | |
43 m_completenessChecker(new CompletenessChecker) { | |
44 transport->setCompletenessChecker(m_completenessChecker); | |
45 } | |
46 | |
47 ~CapnpClient() { | |
48 delete m_completenessChecker; | |
49 } | |
50 | |
51 //!!! obviously, factor out all repetitive guff | |
52 | |
53 //!!! list and load are supposed to be called by application code, | |
54 //!!! but the rest are only supposed to be called by the plugin -- | |
55 //!!! sort out the api here | |
56 | |
57 Vamp::Plugin * | |
58 load(std::string key, float inputSampleRate, int adapterFlags) { | |
59 | |
60 if (!m_transport->isOK()) { | |
61 throw std::runtime_error("Piper server failed to start"); | |
62 } | |
63 | |
64 Vamp::HostExt::PluginStaticData psd; | |
65 Vamp::HostExt::PluginConfiguration defaultConfig; | |
66 PluginHandleMapper::Handle handle = | |
67 serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig); | |
68 | |
69 Vamp::Plugin *plugin = new PluginStub(this, | |
70 key, | |
71 inputSampleRate, | |
72 adapterFlags, | |
73 psd, | |
74 defaultConfig); | |
75 | |
76 m_mapper.addPlugin(handle, plugin); | |
77 | |
78 return plugin; | |
79 } | |
80 | |
81 PluginHandleMapper::Handle | |
82 serverLoad(std::string key, float inputSampleRate, int adapterFlags, | |
83 Vamp::HostExt::PluginStaticData &psd, | |
84 Vamp::HostExt::PluginConfiguration &defaultConfig) { | |
85 | |
86 Vamp::HostExt::LoadRequest request; | |
87 request.pluginKey = key; | |
88 request.inputSampleRate = inputSampleRate; | |
89 request.adapterFlags = adapterFlags; | |
90 | |
91 capnp::MallocMessageBuilder message; | |
92 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
93 | |
94 VampnProto::buildRpcRequest_Load(builder, request); | |
95 ReqId id = getId(); | |
96 builder.getId().setNumber(id); | |
97 | |
98 auto arr = capnp::messageToFlatArray(message); | |
99 | |
100 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
101 arr.asChars().size()); | |
102 | |
103 //!!! ... --> will also need some way to kill this process | |
104 //!!! (from another thread) | |
105 | |
106 auto karr = toKJArray(responseBuffer); | |
107 capnp::FlatArrayMessageReader responseMessage(karr); | |
108 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
109 | |
110 //!!! handle (explicit) error case | |
111 | |
112 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id); | |
113 | |
114 const LoadResponse::Reader &lr = reader.getResponse().getLoad(); | |
115 VampnProto::readExtractorStaticData(psd, lr.getStaticData()); | |
116 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration()); | |
117 return lr.getHandle(); | |
118 }; | |
119 | |
120 protected: | |
121 virtual | |
122 Vamp::Plugin::OutputList | |
123 configure(PluginStub *plugin, | |
124 Vamp::HostExt::PluginConfiguration config) override { | |
125 | |
126 if (!m_transport->isOK()) { | |
127 throw std::runtime_error("Piper server failed to start"); | |
128 } | |
129 | |
130 Vamp::HostExt::ConfigurationRequest request; | |
131 request.plugin = plugin; | |
132 request.configuration = config; | |
133 | |
134 capnp::MallocMessageBuilder message; | |
135 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
136 | |
137 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper); | |
138 ReqId id = getId(); | |
139 builder.getId().setNumber(id); | |
140 | |
141 auto arr = capnp::messageToFlatArray(message); | |
142 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
143 arr.asChars().size()); | |
144 auto karr = toKJArray(responseBuffer); | |
145 capnp::FlatArrayMessageReader responseMessage(karr); | |
146 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
147 | |
148 //!!! handle (explicit) error case | |
149 | |
150 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id); | |
151 | |
152 Vamp::HostExt::ConfigurationResponse cr; | |
153 VampnProto::readConfigurationResponse(cr, | |
154 reader.getResponse().getConfigure(), | |
155 m_mapper); | |
156 | |
157 return cr.outputs; | |
158 }; | |
159 | |
160 virtual | |
161 Vamp::Plugin::FeatureSet | |
162 process(PluginStub *plugin, | |
163 std::vector<std::vector<float> > inputBuffers, | |
164 Vamp::RealTime timestamp) override { | |
165 | |
166 if (!m_transport->isOK()) { | |
167 throw std::runtime_error("Piper server failed to start"); | |
168 } | |
169 | |
170 Vamp::HostExt::ProcessRequest request; | |
171 request.plugin = plugin; | |
172 request.inputBuffers = inputBuffers; | |
173 request.timestamp = timestamp; | |
174 | |
175 capnp::MallocMessageBuilder message; | |
176 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
177 | |
178 VampnProto::buildRpcRequest_Process(builder, request, m_mapper); | |
179 ReqId id = getId(); | |
180 builder.getId().setNumber(id); | |
181 | |
182 auto arr = capnp::messageToFlatArray(message); | |
183 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
184 arr.asChars().size()); | |
185 auto karr = toKJArray(responseBuffer); | |
186 capnp::FlatArrayMessageReader responseMessage(karr); | |
187 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
188 | |
189 //!!! handle (explicit) error case | |
190 | |
191 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id); | |
192 | |
193 Vamp::HostExt::ProcessResponse pr; | |
194 VampnProto::readProcessResponse(pr, | |
195 reader.getResponse().getProcess(), | |
196 m_mapper); | |
197 | |
198 return pr.features; | |
199 } | |
200 | |
201 virtual Vamp::Plugin::FeatureSet | |
202 finish(PluginStub *plugin) override { | |
203 | |
204 if (!m_transport->isOK()) { | |
205 throw std::runtime_error("Piper server failed to start"); | |
206 } | |
207 | |
208 Vamp::HostExt::FinishRequest request; | |
209 request.plugin = plugin; | |
210 | |
211 capnp::MallocMessageBuilder message; | |
212 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
213 | |
214 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper); | |
215 ReqId id = getId(); | |
216 builder.getId().setNumber(id); | |
217 | |
218 auto arr = capnp::messageToFlatArray(message); | |
219 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
220 arr.asChars().size()); | |
221 auto karr = toKJArray(responseBuffer); | |
222 capnp::FlatArrayMessageReader responseMessage(karr); | |
223 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
224 | |
225 //!!! handle (explicit) error case | |
226 | |
227 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id); | |
228 | |
229 Vamp::HostExt::ProcessResponse pr; | |
230 VampnProto::readFinishResponse(pr, | |
231 reader.getResponse().getFinish(), | |
232 m_mapper); | |
233 | |
234 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin)); | |
235 | |
236 // Don't delete the plugin. It's the plugin that is supposed | |
237 // to be calling us here | |
238 | |
239 return pr.features; | |
240 } | |
241 | |
242 virtual void | |
243 reset(PluginStub *plugin, | |
244 Vamp::HostExt::PluginConfiguration config) override { | |
245 | |
246 // Reload the plugin on the server side, and configure it as requested | |
247 | |
248 if (!m_transport->isOK()) { | |
249 throw std::runtime_error("Piper server failed to start"); | |
250 } | |
251 | |
252 if (m_mapper.havePlugin(plugin)) { | |
253 (void)finish(plugin); // server-side unload | |
254 } | |
255 | |
256 Vamp::HostExt::PluginStaticData psd; | |
257 Vamp::HostExt::PluginConfiguration defaultConfig; | |
258 PluginHandleMapper::Handle handle = | |
259 serverLoad(plugin->getPluginKey(), | |
260 plugin->getInputSampleRate(), | |
261 plugin->getAdapterFlags(), | |
262 psd, defaultConfig); | |
263 | |
264 m_mapper.addPlugin(handle, plugin); | |
265 | |
266 (void)configure(plugin, config); | |
267 } | |
268 | |
269 private: | |
270 AssignedPluginHandleMapper m_mapper; | |
271 ReqId getId() { | |
272 //!!! todo: mutex | |
273 static ReqId m_nextId = 0; | |
274 return m_nextId++; | |
275 } | |
276 | |
277 static | |
278 kj::Array<capnp::word> | |
279 toKJArray(const std::vector<char> &buffer) { | |
280 // We could do this whole thing with fewer copies, but let's | |
281 // see whether it matters first | |
282 size_t wordSize = sizeof(capnp::word); | |
283 size_t words = buffer.size() / wordSize; | |
284 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words)); | |
285 memcpy(karr.begin(), buffer.data(), words * wordSize); | |
286 return karr; | |
287 } | |
288 | |
289 void | |
290 checkResponseType(const RpcResponse::Reader &r, | |
291 RpcResponse::Response::Which type, | |
292 ReqId id) { | |
293 | |
294 if (r.getResponse().which() != type) { | |
295 throw std::runtime_error("Wrong response type"); | |
296 } | |
297 if (ReqId(r.getId().getNumber()) != id) { | |
298 throw std::runtime_error("Wrong response id"); | |
299 } | |
300 } | |
301 | |
302 private: | |
303 SynchronousTransport *m_transport; //!!! I don't own this, but should I? | |
304 CompletenessChecker *m_completenessChecker; // I own this | |
305 }; | |
306 | |
307 } | |
308 } | |
309 | |
310 #endif |