diff vamp-client/client.cpp @ 84:db9a6ab618bc

Client builds; does not run
author Chris Cannam <c.cannam@qmul.ac.uk>
date Wed, 12 Oct 2016 11:59:57 +0100
parents 154e94ea84d4
children 1b7c11bc5a88
line wrap: on
line diff
--- a/vamp-client/client.cpp	Tue Oct 11 17:08:31 2016 +0100
+++ b/vamp-client/client.cpp	Wed Oct 12 11:59:57 2016 +0100
@@ -59,7 +59,7 @@
     }
 
     //!!! obviously, factor out all repetitive guff
-    
+
     Vamp::Plugin *
     load(std::string key, float inputSampleRate, int adapterFlags) {
 
@@ -72,7 +72,7 @@
         request.inputSampleRate = inputSampleRate;
         request.adapterFlags = adapterFlags;
 
-        ::capnp::MallocMessageBuilder message;
+        capnp::MallocMessageBuilder message;
         RpcRequest::Builder builder = message.initRoot<RpcRequest>();
 
         VampnProto::buildRpcRequest_Load(builder, request);
@@ -82,7 +82,32 @@
         auto arr = messageToFlatArray(message);
         m_process->write(arr.asChars().begin(), arr.asChars().size());
 
-        ///.... read...
+        //!!! ... --> will also need some way to kill this process
+        //!!! (from another thread)
+
+        QByteArray buffer = readResponseBuffer();
+        capnp::FlatArrayMessageReader responseMessage(toArrayPtr(buffer));
+        RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
+
+        //!!! handle (explicit) error case
+
+        checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
+        
+        const LoadResponse::Reader &lr = reader.getResponse().getLoad();
+
+        Vamp::HostExt::PluginStaticData psd;
+        Vamp::HostExt::PluginConfiguration defaultConfig;
+        VampnProto::readExtractorStaticData(psd, lr.getStaticData());
+        VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
+        
+        Vamp::Plugin *plugin = new PiperStubPlugin(this,
+                                                   inputSampleRate,
+                                                   psd,
+                                                   defaultConfig);
+
+        m_mapper.addPlugin(lr.getHandle(), plugin);
+
+        return plugin;
     };     
     
     virtual
@@ -98,35 +123,202 @@
         request.plugin = plugin;
         request.configuration = config;
 
-        ::capnp::MallocMessageBuilder message;
+        capnp::MallocMessageBuilder message;
         RpcRequest::Builder builder = message.initRoot<RpcRequest>();
 
         VampnProto::buildRpcRequest_Configure(builder, request, m_mapper);
         ReqId id = getId();
         builder.getId().setNumber(id);
+        
+        auto arr = messageToFlatArray(message);
+        m_process->write(arr.asChars().begin(), arr.asChars().size());
+        
+        QByteArray buffer = readResponseBuffer();
+        capnp::FlatArrayMessageReader responseMessage(toArrayPtr(buffer));
+        RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
 
-        //!!! now what?
+        //!!! handle (explicit) error case
+
+        checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id);
+
+        Vamp::HostExt::ConfigurationResponse cr;
+        VampnProto::readConfigurationResponse(cr,
+                                              reader.getResponse().getConfigure(),
+                                              m_mapper);
+
+        return cr.outputs;
     };
     
-    
     virtual
     Vamp::Plugin::FeatureSet
     process(PiperStubPlugin *plugin,
-            const float *const *inputBuffers,
-            Vamp::RealTime timestamp) = 0;
+            std::vector<std::vector<float> > inputBuffers,
+            Vamp::RealTime timestamp) {
+
+        if (!m_process) {
+            throw std::runtime_error("Piper server failed to start");
+        }
+
+        Vamp::HostExt::ProcessRequest request;
+        request.plugin = plugin;
+        request.inputBuffers = inputBuffers;
+        request.timestamp = timestamp;
+        
+        capnp::MallocMessageBuilder message;
+        RpcRequest::Builder builder = message.initRoot<RpcRequest>();
+
+        VampnProto::buildRpcRequest_Process(builder, request, m_mapper);
+        ReqId id = getId();
+        builder.getId().setNumber(id);
+        
+        auto arr = messageToFlatArray(message);
+        m_process->write(arr.asChars().begin(), arr.asChars().size());
+        
+        QByteArray buffer = readResponseBuffer();
+        capnp::FlatArrayMessageReader responseMessage(toArrayPtr(buffer));
+        RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
+
+        //!!! handle (explicit) error case
+
+        checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id);
+
+        Vamp::HostExt::ProcessResponse pr;
+        VampnProto::readProcessResponse(pr,
+                                        reader.getResponse().getProcess(),
+                                        m_mapper);
+
+        return pr.features;
+    }
 
     virtual Vamp::Plugin::FeatureSet
-    finish(PiperStubPlugin *plugin) = 0;
+    finish(PiperStubPlugin *plugin) {
+
+        if (!m_process) {
+            throw std::runtime_error("Piper server failed to start");
+        }
+
+        Vamp::HostExt::FinishRequest request;
+        request.plugin = plugin;
+        
+        capnp::MallocMessageBuilder message;
+        RpcRequest::Builder builder = message.initRoot<RpcRequest>();
+
+        VampnProto::buildRpcRequest_Finish(builder, request, m_mapper);
+        ReqId id = getId();
+        builder.getId().setNumber(id);
+        
+        auto arr = messageToFlatArray(message);
+        m_process->write(arr.asChars().begin(), arr.asChars().size());
+        
+        QByteArray buffer = readResponseBuffer();
+        capnp::FlatArrayMessageReader responseMessage(toArrayPtr(buffer));
+        RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>();
+
+        //!!! handle (explicit) error case
+
+        checkResponseType(reader, RpcResponse::Response::Which::FINISH, id);
+
+        Vamp::HostExt::ProcessResponse pr;
+        VampnProto::readFinishResponse(pr,
+                                       reader.getResponse().getFinish(),
+                                       m_mapper);
+
+        m_mapper.removePlugin(m_mapper.pluginToHandle(plugin));
+        delete plugin;
+        
+        return pr.features;
+    }
 
 private:
     QProcess *m_process;
     AssignedPluginHandleMapper m_mapper;
-    int getId() {
+    ReqId getId() {
         //!!! todo: mutex
         static ReqId m_nextId = 0;
         return m_nextId++;
     }
+
+    kj::ArrayPtr<const capnp::word>
+    toArrayPtr(QByteArray arr) {
+        size_t wordSize = sizeof(capnp::word);
+        capnp::word *dptr = reinterpret_cast<capnp::word *>(arr.data());
+        kj::ArrayPtr<const capnp::word> kptr(dptr, arr.size() / wordSize);
+        return kptr;
+    }
+
+    QByteArray
+    readResponseBuffer() { 
+        
+        QByteArray buffer;
+        size_t wordSize = sizeof(capnp::word);
+        bool complete = false;
+        
+        while (!complete) {
+
+            m_process->waitForReadyRead(1000);
+            qint64 byteCount = m_process->bytesAvailable();
+            qint64 wordCount = byteCount / wordSize;
+
+            if (!wordCount) {
+                if (m_process->state() == QProcess::NotRunning) {
+                    cerr << "ERROR: Subprocess exited: Load failed" << endl;
+                    throw std::runtime_error("Piper server exited unexpectedly");
+                }
+            } else {
+                buffer.append(m_process->read(wordCount * wordSize));
+                size_t haveWords = buffer.size() / wordSize;
+                size_t expectedWords =
+                    capnp::expectedSizeInWordsFromPrefix(toArrayPtr(buffer));
+
+                cerr << "haveWords = " << haveWords << ", expectedWords = " << expectedWords << endl;
+                
+                if (haveWords >= expectedWords) {
+                    if (haveWords > expectedWords) {
+                        cerr << "WARNING: obtained more data than expected ("
+                             << haveWords << " words, expected " << expectedWords
+                             << ")" << endl;
+                    }
+                    complete = true;
+                }
+            }
+        }
+
+        return buffer;
+    }
+
+    void
+    checkResponseType(const RpcResponse::Reader &r,
+                      RpcResponse::Response::Which type,
+                      ReqId id) {
+        
+        if (r.getResponse().which() != type) {
+            throw std::runtime_error("Wrong response type");
+        }
+        if (ReqId(r.getId().getNumber()) != id) {
+            throw std::runtime_error("Wrong response id");
+        }
+    }
 };
     
 }
 
+int main(int, char **)
+{
+    piper::PiperClient client;
+    Vamp::Plugin *plugin = client.load("vamp-example-plugins:zerocrossing", 16, 0);
+    if (!plugin->initialise(1, 4, 4)) {
+        cerr << "initialisation failed" << endl;
+    } else {
+        std::vector<float> buf = { 1.0, -1.0, 1.0, -1.0 };
+        float *bd = buf.data();
+        Vamp::Plugin::FeatureSet features = plugin->process
+            (&bd, Vamp::RealTime::zeroTime);
+        cerr << "results for output 0:" << endl;
+        auto fl(features[0]);
+        for (const auto &f: fl) {
+            cerr << f.values[0] << endl;
+        }
+    }
+    delete plugin;
+}
+