annotate vamp-client/CapnpClient.h @ 95:b6ac26b72b59

Implement list, use request-response classes in loader
author Chris Cannam <c.cannam@qmul.ac.uk>
date Thu, 13 Oct 2016 14:31:10 +0100
parents a660dca988f8
children
rev   line source
c@94 1
c@94 2 #ifndef PIPER_CAPNP_CLIENT_H
c@94 3 #define PIPER_CAPNP_CLIENT_H
c@94 4
#include "Loader.h"
#include "PluginClient.h"
#include "PluginStub.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <capnp/serialize.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
c@94 14
c@94 15 namespace piper {
c@94 16 namespace vampclient {
c@94 17
c@94 18 class CapnpClient : public PluginClient,
c@94 19 public Loader
c@94 20 {
c@94 21 // unsigned to avoid undefined behaviour on possible wrap
c@94 22 typedef uint32_t ReqId;
c@94 23
c@94 24 class CompletenessChecker : public MessageCompletenessChecker {
c@94 25 public:
c@94 26 bool isComplete(const std::vector<char> &message) const override {
c@94 27 auto karr = toKJArray(message);
c@94 28 size_t words = karr.size();
c@94 29 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr);
c@94 30 if (words > expected) {
c@94 31 std::cerr << "WARNING: obtained more data than expected ("
c@94 32 << words << " " << sizeof(capnp::word)
c@94 33 << "-byte words, expected "
c@94 34 << expected << ")" << std::endl;
c@94 35 }
c@94 36 return words >= expected;
c@94 37 }
c@94 38 };
c@94 39
c@94 40 public:
c@94 41 CapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
c@94 42 m_transport(transport),
c@94 43 m_completenessChecker(new CompletenessChecker) {
c@94 44 transport->setCompletenessChecker(m_completenessChecker);
c@94 45 }
c@94 46
c@94 47 ~CapnpClient() {
c@94 48 delete m_completenessChecker;
c@94 49 }
c@94 50
c@94 51 //!!! obviously, factor out all repetitive guff
c@94 52
c@94 53 //!!! list and load are supposed to be called by application code,
c@94 54 //!!! but the rest are only supposed to be called by the plugin --
c@94 55 //!!! sort out the api here
c@95 56
c@95 57 // Loader methods:
c@95 58
c@95 59 Vamp::HostExt::ListResponse
c@95 60 listPluginData() override {
c@94 61
c@94 62 if (!m_transport->isOK()) {
c@94 63 throw std::runtime_error("Piper server failed to start");
c@94 64 }
c@94 65
c@94 66 capnp::MallocMessageBuilder message;
c@94 67 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@95 68 VampnProto::buildRpcRequest_List(builder);
c@94 69 ReqId id = getId();
c@94 70 builder.getId().setNumber(id);
c@94 71
c@95 72 //!!! pure boilerplate:
c@94 73 auto arr = capnp::messageToFlatArray(message);
c@94 74 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 75 arr.asChars().size());
c@94 76 auto karr = toKJArray(responseBuffer);
c@94 77 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 78 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 79
c@95 80 checkResponseType(reader, RpcResponse::Response::Which::LIST, id);
c@94 81
c@95 82 Vamp::HostExt::ListResponse lr;
c@95 83 VampnProto::readListResponse(lr, reader.getResponse().getList());
c@95 84 return lr;
c@95 85 }
c@95 86
c@95 87 Vamp::HostExt::LoadResponse
c@95 88 loadPlugin(const Vamp::HostExt::LoadRequest &req) override {
c@94 89
c@95 90 if (!m_transport->isOK()) {
c@95 91 throw std::runtime_error("Piper server failed to start");
c@95 92 }
c@95 93
c@95 94 Vamp::HostExt::LoadResponse resp;
c@95 95 PluginHandleMapper::Handle handle = serverLoad(req.pluginKey,
c@95 96 req.inputSampleRate,
c@95 97 req.adapterFlags,
c@95 98 resp.staticData,
c@95 99 resp.defaultConfiguration);
c@95 100
c@95 101 Vamp::Plugin *plugin = new PluginStub(this,
c@95 102 req.pluginKey,
c@95 103 req.inputSampleRate,
c@95 104 req.adapterFlags,
c@95 105 resp.staticData,
c@95 106 resp.defaultConfiguration);
c@95 107
c@95 108 m_mapper.addPlugin(handle, plugin);
c@95 109
c@95 110 resp.plugin = plugin;
c@95 111 return resp;
c@95 112 }
c@95 113
c@95 114 // PluginClient methods:
c@95 115
c@94 116 virtual
c@94 117 Vamp::Plugin::OutputList
c@94 118 configure(PluginStub *plugin,
c@94 119 Vamp::HostExt::PluginConfiguration config) override {
c@94 120
c@94 121 if (!m_transport->isOK()) {
c@94 122 throw std::runtime_error("Piper server failed to start");
c@94 123 }
c@94 124
c@94 125 Vamp::HostExt::ConfigurationRequest request;
c@94 126 request.plugin = plugin;
c@94 127 request.configuration = config;
c@94 128
c@94 129 capnp::MallocMessageBuilder message;
c@94 130 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@94 131
c@94 132 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
c@94 133 ReqId id = getId();
c@94 134 builder.getId().setNumber(id);
c@94 135
c@94 136 auto arr = capnp::messageToFlatArray(message);
c@94 137 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 138 arr.asChars().size());
c@94 139 auto karr = toKJArray(responseBuffer);
c@94 140 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 141 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 142
c@94 143 //!!! handle (explicit) error case
c@94 144
c@94 145 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
c@94 146
c@94 147 Vamp::HostExt::ConfigurationResponse cr;
c@94 148 VampnProto::readConfigurationResponse(cr,
c@94 149 reader.getResponse().getConfigure(),
c@94 150 m_mapper);
c@94 151
c@94 152 return cr.outputs;
c@94 153 };
c@94 154
c@94 155 virtual
c@94 156 Vamp::Plugin::FeatureSet
c@94 157 process(PluginStub *plugin,
c@94 158 std::vector<std::vector<float> > inputBuffers,
c@94 159 Vamp::RealTime timestamp) override {
c@94 160
c@94 161 if (!m_transport->isOK()) {
c@94 162 throw std::runtime_error("Piper server failed to start");
c@94 163 }
c@94 164
c@94 165 Vamp::HostExt::ProcessRequest request;
c@94 166 request.plugin = plugin;
c@94 167 request.inputBuffers = inputBuffers;
c@94 168 request.timestamp = timestamp;
c@94 169
c@94 170 capnp::MallocMessageBuilder message;
c@94 171 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@94 172
c@94 173 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
c@94 174 ReqId id = getId();
c@94 175 builder.getId().setNumber(id);
c@94 176
c@94 177 auto arr = capnp::messageToFlatArray(message);
c@94 178 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 179 arr.asChars().size());
c@94 180 auto karr = toKJArray(responseBuffer);
c@94 181 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 182 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 183
c@94 184 //!!! handle (explicit) error case
c@94 185
c@94 186 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
c@94 187
c@94 188 Vamp::HostExt::ProcessResponse pr;
c@94 189 VampnProto::readProcessResponse(pr,
c@94 190 reader.getResponse().getProcess(),
c@94 191 m_mapper);
c@94 192
c@94 193 return pr.features;
c@94 194 }
c@94 195
c@94 196 virtual Vamp::Plugin::FeatureSet
c@94 197 finish(PluginStub *plugin) override {
c@94 198
c@94 199 if (!m_transport->isOK()) {
c@94 200 throw std::runtime_error("Piper server failed to start");
c@94 201 }
c@94 202
c@94 203 Vamp::HostExt::FinishRequest request;
c@94 204 request.plugin = plugin;
c@94 205
c@94 206 capnp::MallocMessageBuilder message;
c@94 207 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@94 208
c@94 209 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
c@94 210 ReqId id = getId();
c@94 211 builder.getId().setNumber(id);
c@94 212
c@94 213 auto arr = capnp::messageToFlatArray(message);
c@94 214 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 215 arr.asChars().size());
c@94 216 auto karr = toKJArray(responseBuffer);
c@94 217 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 218 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 219
c@94 220 //!!! handle (explicit) error case
c@94 221
c@94 222 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
c@94 223
c@94 224 Vamp::HostExt::ProcessResponse pr;
c@94 225 VampnProto::readFinishResponse(pr,
c@94 226 reader.getResponse().getFinish(),
c@94 227 m_mapper);
c@94 228
c@94 229 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
c@94 230
c@94 231 // Don't delete the plugin. It's the plugin that is supposed
c@94 232 // to be calling us here
c@94 233
c@94 234 return pr.features;
c@94 235 }
c@94 236
c@94 237 virtual void
c@94 238 reset(PluginStub *plugin,
c@94 239 Vamp::HostExt::PluginConfiguration config) override {
c@94 240
c@94 241 // Reload the plugin on the server side, and configure it as requested
c@94 242
c@94 243 if (!m_transport->isOK()) {
c@94 244 throw std::runtime_error("Piper server failed to start");
c@94 245 }
c@94 246
c@94 247 if (m_mapper.havePlugin(plugin)) {
c@94 248 (void)finish(plugin); // server-side unload
c@94 249 }
c@94 250
c@94 251 Vamp::HostExt::PluginStaticData psd;
c@94 252 Vamp::HostExt::PluginConfiguration defaultConfig;
c@94 253 PluginHandleMapper::Handle handle =
c@94 254 serverLoad(plugin->getPluginKey(),
c@94 255 plugin->getInputSampleRate(),
c@94 256 plugin->getAdapterFlags(),
c@94 257 psd, defaultConfig);
c@94 258
c@94 259 m_mapper.addPlugin(handle, plugin);
c@94 260
c@94 261 (void)configure(plugin, config);
c@94 262 }
c@94 263
c@94 264 private:
c@94 265 AssignedPluginHandleMapper m_mapper;
c@94 266 ReqId getId() {
c@94 267 //!!! todo: mutex
c@94 268 static ReqId m_nextId = 0;
c@94 269 return m_nextId++;
c@94 270 }
c@94 271
c@94 272 static
c@94 273 kj::Array<capnp::word>
c@94 274 toKJArray(const std::vector<char> &buffer) {
c@94 275 // We could do this whole thing with fewer copies, but let's
c@94 276 // see whether it matters first
c@94 277 size_t wordSize = sizeof(capnp::word);
c@94 278 size_t words = buffer.size() / wordSize;
c@94 279 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
c@94 280 memcpy(karr.begin(), buffer.data(), words * wordSize);
c@94 281 return karr;
c@94 282 }
c@94 283
c@94 284 void
c@94 285 checkResponseType(const RpcResponse::Reader &r,
c@94 286 RpcResponse::Response::Which type,
c@94 287 ReqId id) {
c@94 288
c@94 289 if (r.getResponse().which() != type) {
c@94 290 throw std::runtime_error("Wrong response type");
c@94 291 }
c@94 292 if (ReqId(r.getId().getNumber()) != id) {
c@94 293 throw std::runtime_error("Wrong response id");
c@94 294 }
c@94 295 }
c@95 296
c@95 297 PluginHandleMapper::Handle
c@95 298 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
c@95 299 Vamp::HostExt::PluginStaticData &psd,
c@95 300 Vamp::HostExt::PluginConfiguration &defaultConfig) {
c@95 301
c@95 302 Vamp::HostExt::LoadRequest request;
c@95 303 request.pluginKey = key;
c@95 304 request.inputSampleRate = inputSampleRate;
c@95 305 request.adapterFlags = adapterFlags;
c@95 306
c@95 307 capnp::MallocMessageBuilder message;
c@95 308 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@95 309
c@95 310 VampnProto::buildRpcRequest_Load(builder, request);
c@95 311 ReqId id = getId();
c@95 312 builder.getId().setNumber(id);
c@95 313
c@95 314 auto arr = capnp::messageToFlatArray(message);
c@95 315
c@95 316 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@95 317 arr.asChars().size());
c@95 318
c@95 319 //!!! ... --> will also need some way to kill this process
c@95 320 //!!! (from another thread)
c@95 321
c@95 322 auto karr = toKJArray(responseBuffer);
c@95 323 capnp::FlatArrayMessageReader responseMessage(karr);
c@95 324 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@95 325
c@95 326 //!!! handle (explicit) error case
c@95 327
c@95 328 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
c@95 329
c@95 330 const LoadResponse::Reader &lr = reader.getResponse().getLoad();
c@95 331 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
c@95 332 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
c@95 333 return lr.getHandle();
c@95 334 };
c@94 335
c@94 336 private:
c@94 337 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
c@94 338 CompletenessChecker *m_completenessChecker; // I own this
c@94 339 };
c@94 340
c@94 341 }
c@94 342 }
c@94 343
c@94 344 #endif