annotate vamp-client/CapnpClient.h @ 94:a660dca988f8

More renaming
author Chris Cannam <c.cannam@qmul.ac.uk>
date Thu, 13 Oct 2016 14:10:55 +0100
parents
children b6ac26b72b59
rev   line source
c@94 1
c@94 2 #ifndef PIPER_CAPNP_CLIENT_H
c@94 3 #define PIPER_CAPNP_CLIENT_H
c@94 4
#include "Loader.h"
#include "PluginClient.h"
#include "PluginStub.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <capnp/serialize.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
c@94 14
c@94 15 namespace piper {
c@94 16 namespace vampclient {
c@94 17
c@94 18 class CapnpClient : public PluginClient,
c@94 19 public Loader
c@94 20 {
c@94 21 // unsigned to avoid undefined behaviour on possible wrap
c@94 22 typedef uint32_t ReqId;
c@94 23
c@94 24 class CompletenessChecker : public MessageCompletenessChecker {
c@94 25 public:
c@94 26 bool isComplete(const std::vector<char> &message) const override {
c@94 27 auto karr = toKJArray(message);
c@94 28 size_t words = karr.size();
c@94 29 size_t expected = capnp::expectedSizeInWordsFromPrefix(karr);
c@94 30 if (words > expected) {
c@94 31 std::cerr << "WARNING: obtained more data than expected ("
c@94 32 << words << " " << sizeof(capnp::word)
c@94 33 << "-byte words, expected "
c@94 34 << expected << ")" << std::endl;
c@94 35 }
c@94 36 return words >= expected;
c@94 37 }
c@94 38 };
c@94 39
c@94 40 public:
c@94 41 CapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr?
c@94 42 m_transport(transport),
c@94 43 m_completenessChecker(new CompletenessChecker) {
c@94 44 transport->setCompletenessChecker(m_completenessChecker);
c@94 45 }
c@94 46
c@94 47 ~CapnpClient() {
c@94 48 delete m_completenessChecker;
c@94 49 }
c@94 50
c@94 51 //!!! obviously, factor out all repetitive guff
c@94 52
c@94 53 //!!! list and load are supposed to be called by application code,
c@94 54 //!!! but the rest are only supposed to be called by the plugin --
c@94 55 //!!! sort out the api here
c@94 56
c@94 57 Vamp::Plugin *
c@94 58 load(std::string key, float inputSampleRate, int adapterFlags) {
c@94 59
c@94 60 if (!m_transport->isOK()) {
c@94 61 throw std::runtime_error("Piper server failed to start");
c@94 62 }
c@94 63
c@94 64 Vamp::HostExt::PluginStaticData psd;
c@94 65 Vamp::HostExt::PluginConfiguration defaultConfig;
c@94 66 PluginHandleMapper::Handle handle =
c@94 67 serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig);
c@94 68
c@94 69 Vamp::Plugin *plugin = new PluginStub(this,
c@94 70 key,
c@94 71 inputSampleRate,
c@94 72 adapterFlags,
c@94 73 psd,
c@94 74 defaultConfig);
c@94 75
c@94 76 m_mapper.addPlugin(handle, plugin);
c@94 77
c@94 78 return plugin;
c@94 79 }
c@94 80
c@94 81 PluginHandleMapper::Handle
c@94 82 serverLoad(std::string key, float inputSampleRate, int adapterFlags,
c@94 83 Vamp::HostExt::PluginStaticData &psd,
c@94 84 Vamp::HostExt::PluginConfiguration &defaultConfig) {
c@94 85
c@94 86 Vamp::HostExt::LoadRequest request;
c@94 87 request.pluginKey = key;
c@94 88 request.inputSampleRate = inputSampleRate;
c@94 89 request.adapterFlags = adapterFlags;
c@94 90
c@94 91 capnp::MallocMessageBuilder message;
c@94 92 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@94 93
c@94 94 VampnProto::buildRpcRequest_Load(builder, request);
c@94 95 ReqId id = getId();
c@94 96 builder.getId().setNumber(id);
c@94 97
c@94 98 auto arr = capnp::messageToFlatArray(message);
c@94 99
c@94 100 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 101 arr.asChars().size());
c@94 102
c@94 103 //!!! ... --> will also need some way to kill this process
c@94 104 //!!! (from another thread)
c@94 105
c@94 106 auto karr = toKJArray(responseBuffer);
c@94 107 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 108 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 109
c@94 110 //!!! handle (explicit) error case
c@94 111
c@94 112 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
c@94 113
c@94 114 const LoadResponse::Reader &lr = reader.getResponse().getLoad();
c@94 115 VampnProto::readExtractorStaticData(psd, lr.getStaticData());
c@94 116 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
c@94 117 return lr.getHandle();
c@94 118 };
c@94 119
c@94 120 protected:
c@94 121 virtual
c@94 122 Vamp::Plugin::OutputList
c@94 123 configure(PluginStub *plugin,
c@94 124 Vamp::HostExt::PluginConfiguration config) override {
c@94 125
c@94 126 if (!m_transport->isOK()) {
c@94 127 throw std::runtime_error("Piper server failed to start");
c@94 128 }
c@94 129
c@94 130 Vamp::HostExt::ConfigurationRequest request;
c@94 131 request.plugin = plugin;
c@94 132 request.configuration = config;
c@94 133
c@94 134 capnp::MallocMessageBuilder message;
c@94 135 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@94 136
c@94 137 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
c@94 138 ReqId id = getId();
c@94 139 builder.getId().setNumber(id);
c@94 140
c@94 141 auto arr = capnp::messageToFlatArray(message);
c@94 142 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 143 arr.asChars().size());
c@94 144 auto karr = toKJArray(responseBuffer);
c@94 145 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 146 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 147
c@94 148 //!!! handle (explicit) error case
c@94 149
c@94 150 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
c@94 151
c@94 152 Vamp::HostExt::ConfigurationResponse cr;
c@94 153 VampnProto::readConfigurationResponse(cr,
c@94 154 reader.getResponse().getConfigure(),
c@94 155 m_mapper);
c@94 156
c@94 157 return cr.outputs;
c@94 158 };
c@94 159
c@94 160 virtual
c@94 161 Vamp::Plugin::FeatureSet
c@94 162 process(PluginStub *plugin,
c@94 163 std::vector<std::vector<float> > inputBuffers,
c@94 164 Vamp::RealTime timestamp) override {
c@94 165
c@94 166 if (!m_transport->isOK()) {
c@94 167 throw std::runtime_error("Piper server failed to start");
c@94 168 }
c@94 169
c@94 170 Vamp::HostExt::ProcessRequest request;
c@94 171 request.plugin = plugin;
c@94 172 request.inputBuffers = inputBuffers;
c@94 173 request.timestamp = timestamp;
c@94 174
c@94 175 capnp::MallocMessageBuilder message;
c@94 176 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@94 177
c@94 178 VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
c@94 179 ReqId id = getId();
c@94 180 builder.getId().setNumber(id);
c@94 181
c@94 182 auto arr = capnp::messageToFlatArray(message);
c@94 183 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 184 arr.asChars().size());
c@94 185 auto karr = toKJArray(responseBuffer);
c@94 186 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 187 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 188
c@94 189 //!!! handle (explicit) error case
c@94 190
c@94 191 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
c@94 192
c@94 193 Vamp::HostExt::ProcessResponse pr;
c@94 194 VampnProto::readProcessResponse(pr,
c@94 195 reader.getResponse().getProcess(),
c@94 196 m_mapper);
c@94 197
c@94 198 return pr.features;
c@94 199 }
c@94 200
c@94 201 virtual Vamp::Plugin::FeatureSet
c@94 202 finish(PluginStub *plugin) override {
c@94 203
c@94 204 if (!m_transport->isOK()) {
c@94 205 throw std::runtime_error("Piper server failed to start");
c@94 206 }
c@94 207
c@94 208 Vamp::HostExt::FinishRequest request;
c@94 209 request.plugin = plugin;
c@94 210
c@94 211 capnp::MallocMessageBuilder message;
c@94 212 RpcRequest::Builder builder = message.initRoot<RpcRequest>();
c@94 213
c@94 214 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
c@94 215 ReqId id = getId();
c@94 216 builder.getId().setNumber(id);
c@94 217
c@94 218 auto arr = capnp::messageToFlatArray(message);
c@94 219 auto responseBuffer = m_transport->call(arr.asChars().begin(),
c@94 220 arr.asChars().size());
c@94 221 auto karr = toKJArray(responseBuffer);
c@94 222 capnp::FlatArrayMessageReader responseMessage(karr);
c@94 223 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
c@94 224
c@94 225 //!!! handle (explicit) error case
c@94 226
c@94 227 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
c@94 228
c@94 229 Vamp::HostExt::ProcessResponse pr;
c@94 230 VampnProto::readFinishResponse(pr,
c@94 231 reader.getResponse().getFinish(),
c@94 232 m_mapper);
c@94 233
c@94 234 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
c@94 235
c@94 236 // Don't delete the plugin. It's the plugin that is supposed
c@94 237 // to be calling us here
c@94 238
c@94 239 return pr.features;
c@94 240 }
c@94 241
c@94 242 virtual void
c@94 243 reset(PluginStub *plugin,
c@94 244 Vamp::HostExt::PluginConfiguration config) override {
c@94 245
c@94 246 // Reload the plugin on the server side, and configure it as requested
c@94 247
c@94 248 if (!m_transport->isOK()) {
c@94 249 throw std::runtime_error("Piper server failed to start");
c@94 250 }
c@94 251
c@94 252 if (m_mapper.havePlugin(plugin)) {
c@94 253 (void)finish(plugin); // server-side unload
c@94 254 }
c@94 255
c@94 256 Vamp::HostExt::PluginStaticData psd;
c@94 257 Vamp::HostExt::PluginConfiguration defaultConfig;
c@94 258 PluginHandleMapper::Handle handle =
c@94 259 serverLoad(plugin->getPluginKey(),
c@94 260 plugin->getInputSampleRate(),
c@94 261 plugin->getAdapterFlags(),
c@94 262 psd, defaultConfig);
c@94 263
c@94 264 m_mapper.addPlugin(handle, plugin);
c@94 265
c@94 266 (void)configure(plugin, config);
c@94 267 }
c@94 268
c@94 269 private:
c@94 270 AssignedPluginHandleMapper m_mapper;
c@94 271 ReqId getId() {
c@94 272 //!!! todo: mutex
c@94 273 static ReqId m_nextId = 0;
c@94 274 return m_nextId++;
c@94 275 }
c@94 276
c@94 277 static
c@94 278 kj::Array<capnp::word>
c@94 279 toKJArray(const std::vector<char> &buffer) {
c@94 280 // We could do this whole thing with fewer copies, but let's
c@94 281 // see whether it matters first
c@94 282 size_t wordSize = sizeof(capnp::word);
c@94 283 size_t words = buffer.size() / wordSize;
c@94 284 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words));
c@94 285 memcpy(karr.begin(), buffer.data(), words * wordSize);
c@94 286 return karr;
c@94 287 }
c@94 288
c@94 289 void
c@94 290 checkResponseType(const RpcResponse::Reader &r,
c@94 291 RpcResponse::Response::Which type,
c@94 292 ReqId id) {
c@94 293
c@94 294 if (r.getResponse().which() != type) {
c@94 295 throw std::runtime_error("Wrong response type");
c@94 296 }
c@94 297 if (ReqId(r.getId().getNumber()) != id) {
c@94 298 throw std::runtime_error("Wrong response id");
c@94 299 }
c@94 300 }
c@94 301
c@94 302 private:
c@94 303 SynchronousTransport *m_transport; //!!! I don't own this, but should I?
c@94 304 CompletenessChecker *m_completenessChecker; // I own this
c@94 305 };
c@94 306
c@94 307 }
c@94 308 }
c@94 309
c@94 310 #endif