c@90: c@90: #ifndef PIPER_CAPNP_CLIENT_H c@90: #define PIPER_CAPNP_CLIENT_H c@90: c@90: #include "PiperClient.h" c@92: #include "PiperPluginStub.h" c@90: #include "SynchronousTransport.h" c@90: c@90: #include "vamp-support/AssignedPluginHandleMapper.h" c@90: #include "vamp-capnp/VampnProto.h" c@90: c@92: #include c@92: c@90: namespace piper { //!!! change c@90: c@92: class PiperCapnpClient : public PiperPluginClientInterface, c@92: public PiperLoaderInterface c@90: { c@90: // unsigned to avoid undefined behaviour on possible wrap c@90: typedef uint32_t ReqId; c@92: c@92: class CompletenessChecker : public MessageCompletenessChecker { c@92: public: c@92: bool isComplete(const std::vector &message) const override { c@92: auto karr = toKJArray(message); c@92: size_t words = karr.size(); c@92: size_t expected = capnp::expectedSizeInWordsFromPrefix(karr); c@92: if (words > expected) { c@92: std::cerr << "WARNING: obtained more data than expected (" c@92: << words << " " << sizeof(capnp::word) c@92: << "-byte words, expected " c@92: << expected << ")" << std::endl; c@92: } c@92: return words >= expected; c@92: } c@92: }; c@90: c@90: public: c@90: PiperCapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr? c@92: m_transport(transport), c@92: m_completenessChecker(new CompletenessChecker) { c@92: transport->setCompletenessChecker(m_completenessChecker); c@90: } c@90: c@90: ~PiperCapnpClient() { c@92: delete m_completenessChecker; c@90: } c@90: c@90: //!!! obviously, factor out all repetitive guff c@90: c@90: //!!! list and load are supposed to be called by application code, c@90: //!!! but the rest are only supposed to be called by the plugin -- c@90: //!!! sort out the api here c@90: c@90: Vamp::Plugin * c@90: load(std::string key, float inputSampleRate, int adapterFlags) { c@90: c@90: if (!m_transport->isOK()) { c@90: throw std::runtime_error("Piper server failed to start"); c@90: } c@90: c@91: Vamp::HostExt::PluginStaticData psd; c@91: Vamp::HostExt::PluginConfiguration defaultConfig; c@91: PluginHandleMapper::Handle handle = c@91: serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig); c@91: c@92: Vamp::Plugin *plugin = new PiperPluginStub(this, c@91: key, c@91: inputSampleRate, c@91: adapterFlags, c@91: psd, c@91: defaultConfig); c@91: c@91: m_mapper.addPlugin(handle, plugin); c@91: c@91: return plugin; c@91: } c@91: c@91: PluginHandleMapper::Handle c@91: serverLoad(std::string key, float inputSampleRate, int adapterFlags, c@91: Vamp::HostExt::PluginStaticData &psd, c@91: Vamp::HostExt::PluginConfiguration &defaultConfig) { c@91: c@90: Vamp::HostExt::LoadRequest request; c@90: request.pluginKey = key; c@90: request.inputSampleRate = inputSampleRate; c@90: request.adapterFlags = adapterFlags; c@90: c@90: capnp::MallocMessageBuilder message; c@90: RpcRequest::Builder builder = message.initRoot(); c@90: c@90: VampnProto::buildRpcRequest_Load(builder, request); c@90: ReqId id = getId(); c@90: builder.getId().setNumber(id); c@90: c@91: auto arr = capnp::messageToFlatArray(message); c@90: c@90: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@90: arr.asChars().size()); c@90: c@90: //!!! ... --> will also need some way to kill this process c@90: //!!! (from another thread) c@90: c@90: auto karr = toKJArray(responseBuffer); c@90: capnp::FlatArrayMessageReader responseMessage(karr); c@90: RpcResponse::Reader reader = responseMessage.getRoot(); c@90: c@90: //!!! handle (explicit) error case c@90: c@90: checkResponseType(reader, RpcResponse::Response::Which::LOAD, id); c@90: c@90: const LoadResponse::Reader &lr = reader.getResponse().getLoad(); c@90: VampnProto::readExtractorStaticData(psd, lr.getStaticData()); c@90: VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration()); c@91: return lr.getHandle(); c@90: }; c@90: c@90: protected: c@90: virtual c@90: Vamp::Plugin::OutputList c@92: configure(PiperPluginStub *plugin, c@90: Vamp::HostExt::PluginConfiguration config) override { c@90: c@90: if (!m_transport->isOK()) { c@90: throw std::runtime_error("Piper server failed to start"); c@90: } c@90: c@90: Vamp::HostExt::ConfigurationRequest request; c@90: request.plugin = plugin; c@90: request.configuration = config; c@90: c@90: capnp::MallocMessageBuilder message; c@90: RpcRequest::Builder builder = message.initRoot(); c@90: c@90: VampnProto::buildRpcRequest_Configure(builder, request, m_mapper); c@90: ReqId id = getId(); c@90: builder.getId().setNumber(id); c@90: c@91: auto arr = capnp::messageToFlatArray(message); c@90: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@90: arr.asChars().size()); c@90: auto karr = toKJArray(responseBuffer); c@90: capnp::FlatArrayMessageReader responseMessage(karr); c@90: RpcResponse::Reader reader = responseMessage.getRoot(); c@90: c@90: //!!! handle (explicit) error case c@90: c@90: checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id); c@90: c@90: Vamp::HostExt::ConfigurationResponse cr; c@90: VampnProto::readConfigurationResponse(cr, c@90: reader.getResponse().getConfigure(), c@90: m_mapper); c@90: c@90: return cr.outputs; c@90: }; c@90: c@90: virtual c@90: Vamp::Plugin::FeatureSet c@92: process(PiperPluginStub *plugin, c@90: std::vector > inputBuffers, c@90: Vamp::RealTime timestamp) override { c@90: c@90: if (!m_transport->isOK()) { c@90: throw std::runtime_error("Piper server failed to start"); c@90: } c@90: c@90: Vamp::HostExt::ProcessRequest request; c@90: request.plugin = plugin; c@90: request.inputBuffers = inputBuffers; c@90: request.timestamp = timestamp; c@90: c@90: capnp::MallocMessageBuilder message; c@90: RpcRequest::Builder builder = message.initRoot(); c@90: c@90: VampnProto::buildRpcRequest_Process(builder, request, m_mapper); c@90: ReqId id = getId(); c@90: builder.getId().setNumber(id); c@90: c@91: auto arr = capnp::messageToFlatArray(message); c@90: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@90: arr.asChars().size()); c@90: auto karr = toKJArray(responseBuffer); c@90: capnp::FlatArrayMessageReader responseMessage(karr); c@90: RpcResponse::Reader reader = responseMessage.getRoot(); c@90: c@90: //!!! handle (explicit) error case c@90: c@90: checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id); c@90: c@90: Vamp::HostExt::ProcessResponse pr; c@90: VampnProto::readProcessResponse(pr, c@90: reader.getResponse().getProcess(), c@90: m_mapper); c@90: c@90: return pr.features; c@90: } c@90: c@90: virtual Vamp::Plugin::FeatureSet c@92: finish(PiperPluginStub *plugin) override { c@90: c@90: if (!m_transport->isOK()) { c@90: throw std::runtime_error("Piper server failed to start"); c@90: } c@90: c@90: Vamp::HostExt::FinishRequest request; c@90: request.plugin = plugin; c@90: c@90: capnp::MallocMessageBuilder message; c@90: RpcRequest::Builder builder = message.initRoot(); c@90: c@90: VampnProto::buildRpcRequest_Finish(builder, request, m_mapper); c@90: ReqId id = getId(); c@90: builder.getId().setNumber(id); c@90: c@91: auto arr = capnp::messageToFlatArray(message); c@90: auto responseBuffer = m_transport->call(arr.asChars().begin(), c@90: arr.asChars().size()); c@90: auto karr = toKJArray(responseBuffer); c@90: capnp::FlatArrayMessageReader responseMessage(karr); c@90: RpcResponse::Reader reader = responseMessage.getRoot(); c@90: c@90: //!!! handle (explicit) error case c@90: c@90: checkResponseType(reader, RpcResponse::Response::Which::FINISH, id); c@90: c@90: Vamp::HostExt::ProcessResponse pr; c@90: VampnProto::readFinishResponse(pr, c@90: reader.getResponse().getFinish(), c@90: m_mapper); c@90: c@90: m_mapper.removePlugin(m_mapper.pluginToHandle(plugin)); c@90: c@90: // Don't delete the plugin. It's the plugin that is supposed c@90: // to be calling us here c@90: c@90: return pr.features; c@90: } c@90: c@91: virtual void c@92: reset(PiperPluginStub *plugin, c@91: Vamp::HostExt::PluginConfiguration config) override { c@91: c@91: // Reload the plugin on the server side, and configure it as requested c@91: c@91: if (!m_transport->isOK()) { c@91: throw std::runtime_error("Piper server failed to start"); c@91: } c@91: c@91: if (m_mapper.havePlugin(plugin)) { c@91: (void)finish(plugin); // server-side unload c@91: } c@91: c@91: Vamp::HostExt::PluginStaticData psd; c@91: Vamp::HostExt::PluginConfiguration defaultConfig; c@91: PluginHandleMapper::Handle handle = c@91: serverLoad(plugin->getPluginKey(), c@91: plugin->getInputSampleRate(), c@91: plugin->getAdapterFlags(), c@91: psd, defaultConfig); c@91: c@91: m_mapper.addPlugin(handle, plugin); c@91: c@91: (void)configure(plugin, config); c@91: } c@91: c@90: private: c@90: AssignedPluginHandleMapper m_mapper; c@90: ReqId getId() { c@90: //!!! todo: mutex c@90: static ReqId m_nextId = 0; c@90: return m_nextId++; c@90: } c@90: c@92: static c@90: kj::Array c@90: toKJArray(const std::vector &buffer) { c@90: // We could do this whole thing with fewer copies, but let's c@90: // see whether it matters first c@90: size_t wordSize = sizeof(capnp::word); c@90: size_t words = buffer.size() / wordSize; c@90: kj::Array karr(kj::heapArray(words)); c@90: memcpy(karr.begin(), buffer.data(), words * wordSize); c@90: return karr; c@90: } c@90: c@90: void c@90: checkResponseType(const RpcResponse::Reader &r, c@90: RpcResponse::Response::Which type, c@90: ReqId id) { c@90: c@90: if (r.getResponse().which() != type) { c@90: throw std::runtime_error("Wrong response type"); c@90: } c@90: if (ReqId(r.getId().getNumber()) != id) { c@90: throw std::runtime_error("Wrong response id"); c@90: } c@90: } c@90: c@90: private: c@90: SynchronousTransport *m_transport; //!!! I don't own this, but should I? c@92: CompletenessChecker *m_completenessChecker; // I own this c@90: }; c@90: c@90: } c@90: c@90: #endif