c@94: c@94: #ifndef PIPER_CAPNP_CLIENT_H c@94: #define PIPER_CAPNP_CLIENT_H c@94: c@94: #include "Loader.h" c@94: #include "PluginClient.h" c@94: #include "PluginStub.h" c@94: #include "SynchronousTransport.h" c@94: c@94: #include "vamp-support/AssignedPluginHandleMapper.h" c@94: #include "vamp-capnp/VampnProto.h" c@94: c@94: #include c@94: c@94: namespace piper { c@94: namespace vampclient { c@94: c@94: class CapnpClient : public PluginClient, c@94: public Loader c@94: { c@94: // unsigned to avoid undefined behaviour on possible wrap c@94: typedef uint32_t ReqId; c@94: c@94: class CompletenessChecker : public MessageCompletenessChecker { c@94: public: c@94: bool isComplete(const std::vector &message) const override { c@94: auto karr = toKJArray(message); c@94: size_t words = karr.size(); c@94: size_t expected = capnp::expectedSizeInWordsFromPrefix(karr); c@94: if (words > expected) { c@94: std::cerr << "WARNING: obtained more data than expected (" c@94: << words << " " << sizeof(capnp::word) c@94: << "-byte words, expected " c@94: << expected << ")" << std::endl; c@94: } c@94: return words >= expected; c@94: } c@94: }; c@94: c@94: public: c@94: CapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr? c@94: m_transport(transport), c@94: m_completenessChecker(new CompletenessChecker) { c@94: transport->setCompletenessChecker(m_completenessChecker); c@94: } c@94: c@94: ~CapnpClient() { c@94: delete m_completenessChecker; c@94: } c@94: c@94: //!!! obviously, factor out all repetitive guff c@94: c@94: //!!! list and load are supposed to be called by application code, c@94: //!!! but the rest are only supposed to be called by the plugin -- c@94: //!!! sort out the api here c@95: c@95: // Loader methods: c@95: c@95: Vamp::HostExt::ListResponse c@95: listPluginData() override { c@94: c@94: if (!m_transport->isOK()) { c@94: throw std::runtime_error("Piper server failed to start"); c@94: } c@94: c@94: capnp::MallocMessageBuilder message; c@94: RpcRequest::Builder builder = message.initRoot(); c@95: VampnProto::buildRpcRequest_List(builder); c@94: ReqId id = getId(); c@94: builder.getId().setNumber(id); c@94: c@95: //!!! pure boilerplate: c@94: auto arr = capnp::messageToFlatArray(message); c@94: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@94: arr.asChars().size()); c@94: auto karr = toKJArray(responseBuffer); c@94: capnp::FlatArrayMessageReader responseMessage(karr); c@94: RpcResponse::Reader reader = responseMessage.getRoot(); c@94: c@95: checkResponseType(reader, RpcResponse::Response::Which::LIST, id); c@94: c@95: Vamp::HostExt::ListResponse lr; c@95: VampnProto::readListResponse(lr, reader.getResponse().getList()); c@95: return lr; c@95: } c@95: c@95: Vamp::HostExt::LoadResponse c@95: loadPlugin(const Vamp::HostExt::LoadRequest &req) override { c@94: c@95: if (!m_transport->isOK()) { c@95: throw std::runtime_error("Piper server failed to start"); c@95: } c@95: c@95: Vamp::HostExt::LoadResponse resp; c@95: PluginHandleMapper::Handle handle = serverLoad(req.pluginKey, c@95: req.inputSampleRate, c@95: req.adapterFlags, c@95: resp.staticData, c@95: resp.defaultConfiguration); c@95: c@95: Vamp::Plugin *plugin = new PluginStub(this, c@95: req.pluginKey, c@95: req.inputSampleRate, c@95: req.adapterFlags, c@95: resp.staticData, c@95: resp.defaultConfiguration); c@95: c@95: m_mapper.addPlugin(handle, plugin); c@95: c@95: resp.plugin = plugin; c@95: return resp; c@95: } c@95: c@95: // PluginClient methods: c@95: c@94: virtual c@94: Vamp::Plugin::OutputList c@94: configure(PluginStub *plugin, c@94: Vamp::HostExt::PluginConfiguration config) override { c@94: c@94: if (!m_transport->isOK()) { c@94: throw std::runtime_error("Piper server failed to start"); c@94: } c@94: c@94: Vamp::HostExt::ConfigurationRequest request; c@94: request.plugin = plugin; c@94: request.configuration = config; c@94: c@94: capnp::MallocMessageBuilder message; c@94: RpcRequest::Builder builder = message.initRoot(); c@94: c@94: VampnProto::buildRpcRequest_Configure(builder, request, m_mapper); c@94: ReqId id = getId(); c@94: builder.getId().setNumber(id); c@94: c@94: auto arr = capnp::messageToFlatArray(message); c@94: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@94: arr.asChars().size()); c@94: auto karr = toKJArray(responseBuffer); c@94: capnp::FlatArrayMessageReader responseMessage(karr); c@94: RpcResponse::Reader reader = responseMessage.getRoot(); c@94: c@94: //!!! handle (explicit) error case c@94: c@94: checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id); c@94: c@94: Vamp::HostExt::ConfigurationResponse cr; c@94: VampnProto::readConfigurationResponse(cr, c@94: reader.getResponse().getConfigure(), c@94: m_mapper); c@94: c@94: return cr.outputs; c@94: }; c@94: c@94: virtual c@94: Vamp::Plugin::FeatureSet c@94: process(PluginStub *plugin, c@94: std::vector > inputBuffers, c@94: Vamp::RealTime timestamp) override { c@94: c@94: if (!m_transport->isOK()) { c@94: throw std::runtime_error("Piper server failed to start"); c@94: } c@94: c@94: Vamp::HostExt::ProcessRequest request; c@94: request.plugin = plugin; c@94: request.inputBuffers = inputBuffers; c@94: request.timestamp = timestamp; c@94: c@94: capnp::MallocMessageBuilder message; c@94: RpcRequest::Builder builder = message.initRoot(); c@94: c@94: VampnProto::buildRpcRequest_Process(builder, request, m_mapper); c@94: ReqId id = getId(); c@94: builder.getId().setNumber(id); c@94: c@94: auto arr = capnp::messageToFlatArray(message); c@94: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@94: arr.asChars().size()); c@94: auto karr = toKJArray(responseBuffer); c@94: capnp::FlatArrayMessageReader responseMessage(karr); c@94: RpcResponse::Reader reader = responseMessage.getRoot(); c@94: c@94: //!!! handle (explicit) error case c@94: c@94: checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id); c@94: c@94: Vamp::HostExt::ProcessResponse pr; c@94: VampnProto::readProcessResponse(pr, c@94: reader.getResponse().getProcess(), c@94: m_mapper); c@94: c@94: return pr.features; c@94: } c@94: c@94: virtual Vamp::Plugin::FeatureSet c@94: finish(PluginStub *plugin) override { c@94: c@94: if (!m_transport->isOK()) { c@94: throw std::runtime_error("Piper server failed to start"); c@94: } c@94: c@94: Vamp::HostExt::FinishRequest request; c@94: request.plugin = plugin; c@94: c@94: capnp::MallocMessageBuilder message; c@94: RpcRequest::Builder builder = message.initRoot(); c@94: c@94: VampnProto::buildRpcRequest_Finish(builder, request, m_mapper); c@94: ReqId id = getId(); c@94: builder.getId().setNumber(id); c@94: c@94: auto arr = capnp::messageToFlatArray(message); c@94: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@94: arr.asChars().size()); c@94: auto karr = toKJArray(responseBuffer); c@94: capnp::FlatArrayMessageReader responseMessage(karr); c@94: RpcResponse::Reader reader = responseMessage.getRoot(); c@94: c@94: //!!! handle (explicit) error case c@94: c@94: checkResponseType(reader, RpcResponse::Response::Which::FINISH, id); c@94: c@94: Vamp::HostExt::ProcessResponse pr; c@94: VampnProto::readFinishResponse(pr, c@94: reader.getResponse().getFinish(), c@94: m_mapper); c@94: c@94: m_mapper.removePlugin(m_mapper.pluginToHandle(plugin)); c@94: c@94: // Don't delete the plugin. It's the plugin that is supposed c@94: // to be calling us here c@94: c@94: return pr.features; c@94: } c@94: c@94: virtual void c@94: reset(PluginStub *plugin, c@94: Vamp::HostExt::PluginConfiguration config) override { c@94: c@94: // Reload the plugin on the server side, and configure it as requested c@94: c@94: if (!m_transport->isOK()) { c@94: throw std::runtime_error("Piper server failed to start"); c@94: } c@94: c@94: if (m_mapper.havePlugin(plugin)) { c@94: (void)finish(plugin); // server-side unload c@94: } c@94: c@94: Vamp::HostExt::PluginStaticData psd; c@94: Vamp::HostExt::PluginConfiguration defaultConfig; c@94: PluginHandleMapper::Handle handle = c@94: serverLoad(plugin->getPluginKey(), c@94: plugin->getInputSampleRate(), c@94: plugin->getAdapterFlags(), c@94: psd, defaultConfig); c@94: c@94: m_mapper.addPlugin(handle, plugin); c@94: c@94: (void)configure(plugin, config); c@94: } c@94: c@94: private: c@94: AssignedPluginHandleMapper m_mapper; c@94: ReqId getId() { c@94: //!!! todo: mutex c@94: static ReqId m_nextId = 0; c@94: return m_nextId++; c@94: } c@94: c@94: static c@94: kj::Array c@94: toKJArray(const std::vector &buffer) { c@94: // We could do this whole thing with fewer copies, but let's c@94: // see whether it matters first c@94: size_t wordSize = sizeof(capnp::word); c@94: size_t words = buffer.size() / wordSize; c@94: kj::Array karr(kj::heapArray(words)); c@94: memcpy(karr.begin(), buffer.data(), words * wordSize); c@94: return karr; c@94: } c@94: c@94: void c@94: checkResponseType(const RpcResponse::Reader &r, c@94: RpcResponse::Response::Which type, c@94: ReqId id) { c@94: c@94: if (r.getResponse().which() != type) { c@94: throw std::runtime_error("Wrong response type"); c@94: } c@94: if (ReqId(r.getId().getNumber()) != id) { c@94: throw std::runtime_error("Wrong response id"); c@94: } c@94: } c@95: c@95: PluginHandleMapper::Handle c@95: serverLoad(std::string key, float inputSampleRate, int adapterFlags, c@95: Vamp::HostExt::PluginStaticData &psd, c@95: Vamp::HostExt::PluginConfiguration &defaultConfig) { c@95: c@95: Vamp::HostExt::LoadRequest request; c@95: request.pluginKey = key; c@95: request.inputSampleRate = inputSampleRate; c@95: request.adapterFlags = adapterFlags; c@95: c@95: capnp::MallocMessageBuilder message; c@95: RpcRequest::Builder builder = message.initRoot(); c@95: c@95: VampnProto::buildRpcRequest_Load(builder, request); c@95: ReqId id = getId(); c@95: builder.getId().setNumber(id); c@95: c@95: auto arr = capnp::messageToFlatArray(message); c@95: c@95: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@95: arr.asChars().size()); c@95: c@95: //!!! ... --> will also need some way to kill this process c@95: //!!! (from another thread) c@95: c@95: auto karr = toKJArray(responseBuffer); c@95: capnp::FlatArrayMessageReader responseMessage(karr); c@95: RpcResponse::Reader reader = responseMessage.getRoot(); c@95: c@95: //!!! handle (explicit) error case c@95: c@95: checkResponseType(reader, RpcResponse::Response::Which::LOAD, id); c@95: c@95: const LoadResponse::Reader &lr = reader.getResponse().getLoad(); c@95: VampnProto::readExtractorStaticData(psd, lr.getStaticData()); c@95: VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration()); c@95: return lr.getHandle(); c@95: }; c@94: c@94: private: c@94: SynchronousTransport *m_transport; //!!! I don't own this, but should I? c@94: CompletenessChecker *m_completenessChecker; // I own this c@94: }; c@94: c@94: } c@94: } c@94: c@94: #endif