Mercurial > hg > piper-cpp
diff vamp-client/CapnpClient.h @ 94:a660dca988f8
More renaming
author | Chris Cannam <c.cannam@qmul.ac.uk> |
---|---|
date | Thu, 13 Oct 2016 14:10:55 +0100 |
parents | |
children | b6ac26b72b59 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/vamp-client/CapnpClient.h Thu Oct 13 14:10:55 2016 +0100 @@ -0,0 +1,310 @@ + +#ifndef PIPER_CAPNP_CLIENT_H +#define PIPER_CAPNP_CLIENT_H + +#include "Loader.h" +#include "PluginClient.h" +#include "PluginStub.h" +#include "SynchronousTransport.h" + +#include "vamp-support/AssignedPluginHandleMapper.h" +#include "vamp-capnp/VampnProto.h" + +#include <capnp/serialize.h> + +namespace piper { +namespace vampclient { + +class CapnpClient : public PluginClient, + public Loader +{ + // unsigned to avoid undefined behaviour on possible wrap + typedef uint32_t ReqId; + + class CompletenessChecker : public MessageCompletenessChecker { + public: + bool isComplete(const std::vector<char> &message) const override { + auto karr = toKJArray(message); + size_t words = karr.size(); + size_t expected = capnp::expectedSizeInWordsFromPrefix(karr); + if (words > expected) { + std::cerr << "WARNING: obtained more data than expected (" + << words << " " << sizeof(capnp::word) + << "-byte words, expected " + << expected << ")" << std::endl; + } + return words >= expected; + } + }; + +public: + CapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr? + m_transport(transport), + m_completenessChecker(new CompletenessChecker) { + transport->setCompletenessChecker(m_completenessChecker); + } + + ~CapnpClient() { + delete m_completenessChecker; + } + + //!!! obviously, factor out all repetitive guff + + //!!! list and load are supposed to be called by application code, + //!!! but the rest are only supposed to be called by the plugin -- + //!!! sort out the api here + + Vamp::Plugin * + load(std::string key, float inputSampleRate, int adapterFlags) { + + if (!m_transport->isOK()) { + throw std::runtime_error("Piper server failed to start"); + } + + Vamp::HostExt::PluginStaticData psd; + Vamp::HostExt::PluginConfiguration defaultConfig; + PluginHandleMapper::Handle handle = + serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig); + + Vamp::Plugin *plugin = new PluginStub(this, + key, + inputSampleRate, + adapterFlags, + psd, + defaultConfig); + + m_mapper.addPlugin(handle, plugin); + + return plugin; + } + + PluginHandleMapper::Handle + serverLoad(std::string key, float inputSampleRate, int adapterFlags, + Vamp::HostExt::PluginStaticData &psd, + Vamp::HostExt::PluginConfiguration &defaultConfig) { + + Vamp::HostExt::LoadRequest request; + request.pluginKey = key; + request.inputSampleRate = inputSampleRate; + request.adapterFlags = adapterFlags; + + capnp::MallocMessageBuilder message; + RpcRequest::Builder builder = message.initRoot<RpcRequest>(); + + VampnProto::buildRpcRequest_Load(builder, request); + ReqId id = getId(); + builder.getId().setNumber(id); + + auto arr = capnp::messageToFlatArray(message); + + auto responseBuffer = m_transport->call(arr.asChars().begin(), + arr.asChars().size()); + + //!!! ... --> will also need some way to kill this process + //!!! (from another thread) + + auto karr = toKJArray(responseBuffer); + capnp::FlatArrayMessageReader responseMessage(karr); + RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); + + //!!! handle (explicit) error case + + checkResponseType(reader, RpcResponse::Response::Which::LOAD, id); + + const LoadResponse::Reader &lr = reader.getResponse().getLoad(); + VampnProto::readExtractorStaticData(psd, lr.getStaticData()); + VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration()); + return lr.getHandle(); + }; + +protected: + virtual + Vamp::Plugin::OutputList + configure(PluginStub *plugin, + Vamp::HostExt::PluginConfiguration config) override { + + if (!m_transport->isOK()) { + throw std::runtime_error("Piper server failed to start"); + } + + Vamp::HostExt::ConfigurationRequest request; + request.plugin = plugin; + request.configuration = config; + + capnp::MallocMessageBuilder message; + RpcRequest::Builder builder = message.initRoot<RpcRequest>(); + + VampnProto::buildRpcRequest_Configure(builder, request, m_mapper); + ReqId id = getId(); + builder.getId().setNumber(id); + + auto arr = capnp::messageToFlatArray(message); + auto responseBuffer = m_transport->call(arr.asChars().begin(), + arr.asChars().size()); + auto karr = toKJArray(responseBuffer); + capnp::FlatArrayMessageReader responseMessage(karr); + RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); + + //!!! handle (explicit) error case + + checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id); + + Vamp::HostExt::ConfigurationResponse cr; + VampnProto::readConfigurationResponse(cr, + reader.getResponse().getConfigure(), + m_mapper); + + return cr.outputs; + }; + + virtual + Vamp::Plugin::FeatureSet + process(PluginStub *plugin, + std::vector<std::vector<float> > inputBuffers, + Vamp::RealTime timestamp) override { + + if (!m_transport->isOK()) { + throw std::runtime_error("Piper server failed to start"); + } + + Vamp::HostExt::ProcessRequest request; + request.plugin = plugin; + request.inputBuffers = inputBuffers; + request.timestamp = timestamp; + + capnp::MallocMessageBuilder message; + RpcRequest::Builder builder = message.initRoot<RpcRequest>(); + + VampnProto::buildRpcRequest_Process(builder, request, m_mapper); + ReqId id = getId(); + builder.getId().setNumber(id); + + auto arr = capnp::messageToFlatArray(message); + auto responseBuffer = m_transport->call(arr.asChars().begin(), + arr.asChars().size()); + auto karr = toKJArray(responseBuffer); + capnp::FlatArrayMessageReader responseMessage(karr); + RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); + + //!!! handle (explicit) error case + + checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id); + + Vamp::HostExt::ProcessResponse pr; + VampnProto::readProcessResponse(pr, + reader.getResponse().getProcess(), + m_mapper); + + return pr.features; + } + + virtual Vamp::Plugin::FeatureSet + finish(PluginStub *plugin) override { + + if (!m_transport->isOK()) { + throw std::runtime_error("Piper server failed to start"); + } + + Vamp::HostExt::FinishRequest request; + request.plugin = plugin; + + capnp::MallocMessageBuilder message; + RpcRequest::Builder builder = message.initRoot<RpcRequest>(); + + VampnProto::buildRpcRequest_Finish(builder, request, m_mapper); + ReqId id = getId(); + builder.getId().setNumber(id); + + auto arr = capnp::messageToFlatArray(message); + auto responseBuffer = m_transport->call(arr.asChars().begin(), + arr.asChars().size()); + auto karr = toKJArray(responseBuffer); + capnp::FlatArrayMessageReader responseMessage(karr); + RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); + + //!!! handle (explicit) error case + + checkResponseType(reader, RpcResponse::Response::Which::FINISH, id); + + Vamp::HostExt::ProcessResponse pr; + VampnProto::readFinishResponse(pr, + reader.getResponse().getFinish(), + m_mapper); + + m_mapper.removePlugin(m_mapper.pluginToHandle(plugin)); + + // Don't delete the plugin. It's the plugin that is supposed + // to be calling us here + + return pr.features; + } + + virtual void + reset(PluginStub *plugin, + Vamp::HostExt::PluginConfiguration config) override { + + // Reload the plugin on the server side, and configure it as requested + + if (!m_transport->isOK()) { + throw std::runtime_error("Piper server failed to start"); + } + + if (m_mapper.havePlugin(plugin)) { + (void)finish(plugin); // server-side unload + } + + Vamp::HostExt::PluginStaticData psd; + Vamp::HostExt::PluginConfiguration defaultConfig; + PluginHandleMapper::Handle handle = + serverLoad(plugin->getPluginKey(), + plugin->getInputSampleRate(), + plugin->getAdapterFlags(), + psd, defaultConfig); + + m_mapper.addPlugin(handle, plugin); + + (void)configure(plugin, config); + } + +private: + AssignedPluginHandleMapper m_mapper; + ReqId getId() { + //!!! todo: mutex + static ReqId m_nextId = 0; + return m_nextId++; + } + + static + kj::Array<capnp::word> + toKJArray(const std::vector<char> &buffer) { + // We could do this whole thing with fewer copies, but let's + // see whether it matters first + size_t wordSize = sizeof(capnp::word); + size_t words = buffer.size() / wordSize; + kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words)); + memcpy(karr.begin(), buffer.data(), words * wordSize); + return karr; + } + + void + checkResponseType(const RpcResponse::Reader &r, + RpcResponse::Response::Which type, + ReqId id) { + + if (r.getResponse().which() != type) { + throw std::runtime_error("Wrong response type"); + } + if (ReqId(r.getId().getNumber()) != id) { + throw std::runtime_error("Wrong response id"); + } + } + +private: + SynchronousTransport *m_transport; //!!! I don't own this, but should I? + CompletenessChecker *m_completenessChecker; // I own this +}; + +} +} + +#endif