changeset 91:c897c9a8daf1

Implement reset()
author Chris Cannam <c.cannam@qmul.ac.uk>
date Thu, 13 Oct 2016 11:33:19 +0100
parents 6429a99abcad
children 21f8af53eaf0
files vamp-client/PiperCapnpClient.h vamp-client/PiperClient.h vamp-client/PiperStubPlugin.h vamp-server/convert.cpp vamp-server/server.cpp vamp-support/AssignedPluginHandleMapper.h
diffstat 6 files changed, 98 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/vamp-client/PiperCapnpClient.h	Thu Oct 13 10:17:59 2016 +0100
+++ b/vamp-client/PiperCapnpClient.h	Thu Oct 13 11:33:19 2016 +0100
@@ -36,6 +36,28 @@
             throw std::runtime_error("Piper server failed to start");
         }
 
+        Vamp::HostExt::PluginStaticData psd;
+        Vamp::HostExt::PluginConfiguration defaultConfig;
+        PluginHandleMapper::Handle handle =
+            serverLoad(key, inputSampleRate, adapterFlags, psd, defaultConfig);
+
+        Vamp::Plugin *plugin = new PiperStubPlugin(this,
+                                                   key,
+                                                   inputSampleRate,
+                                                   adapterFlags,
+                                                   psd,
+                                                   defaultConfig);
+
+        m_mapper.addPlugin(handle, plugin);
+
+        return plugin;
+    }
+    
+    PluginHandleMapper::Handle
+    serverLoad(std::string key, float inputSampleRate, int adapterFlags,
+               Vamp::HostExt::PluginStaticData &psd,
+               Vamp::HostExt::PluginConfiguration &defaultConfig) {
+
         Vamp::HostExt::LoadRequest request;
         request.pluginKey = key;
         request.inputSampleRate = inputSampleRate;
@@ -48,7 +70,7 @@
         ReqId id = getId();
         builder.getId().setNumber(id);
 
-        auto arr = messageToFlatArray(message);
+        auto arr = capnp::messageToFlatArray(message);
 
         auto responseBuffer = m_transport->call(arr.asChars().begin(),
                                                 arr.asChars().size());
@@ -65,20 +87,9 @@
         checkResponseType(reader, RpcResponse::Response::Which::LOAD, id);
         
         const LoadResponse::Reader &lr = reader.getResponse().getLoad();
-
-        Vamp::HostExt::PluginStaticData psd;
-        Vamp::HostExt::PluginConfiguration defaultConfig;
         VampnProto::readExtractorStaticData(psd, lr.getStaticData());
         VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration());
-        
-        Vamp::Plugin *plugin = new PiperStubPlugin(this,
-                                                   inputSampleRate,
-                                                   psd,
-                                                   defaultConfig);
-
-        m_mapper.addPlugin(lr.getHandle(), plugin);
-
-        return plugin;
+        return lr.getHandle();
     };     
 
 protected:
@@ -102,7 +113,7 @@
         ReqId id = getId();
         builder.getId().setNumber(id);
         
-        auto arr = messageToFlatArray(message);
+        auto arr = capnp::messageToFlatArray(message);
         auto responseBuffer = m_transport->call(arr.asChars().begin(),
                                                 arr.asChars().size());
 	auto karr = toKJArray(responseBuffer);
@@ -143,7 +154,7 @@
         ReqId id = getId();
         builder.getId().setNumber(id);
         
-        auto arr = messageToFlatArray(message);
+        auto arr = capnp::messageToFlatArray(message);
         auto responseBuffer = m_transport->call(arr.asChars().begin(),
                                                 arr.asChars().size());
 	auto karr = toKJArray(responseBuffer);
@@ -179,7 +190,7 @@
         ReqId id = getId();
         builder.getId().setNumber(id);
         
-        auto arr = messageToFlatArray(message);
+        auto arr = capnp::messageToFlatArray(message);
         auto responseBuffer = m_transport->call(arr.asChars().begin(),
                                                 arr.asChars().size());
 	auto karr = toKJArray(responseBuffer);
@@ -203,6 +214,33 @@
         return pr.features;
     }
 
+    virtual void
+    reset(PiperStubPlugin *plugin,
+          Vamp::HostExt::PluginConfiguration config) override {
+
+        // Reload the plugin on the server side, and configure it as requested
+        
+        if (!m_transport->isOK()) {
+            throw std::runtime_error("Piper server failed to start");
+        }
+
+        if (m_mapper.havePlugin(plugin)) {
+            (void)finish(plugin); // server-side unload
+        }
+
+        Vamp::HostExt::PluginStaticData psd;
+        Vamp::HostExt::PluginConfiguration defaultConfig;
+        PluginHandleMapper::Handle handle =
+            serverLoad(plugin->getPluginKey(),
+                       plugin->getInputSampleRate(),
+                       plugin->getAdapterFlags(),
+                       psd, defaultConfig);
+
+        m_mapper.addPlugin(handle, plugin);
+
+        (void)configure(plugin, config);
+    }
+    
 private:
     AssignedPluginHandleMapper m_mapper;
     ReqId getId() {
--- a/vamp-client/PiperClient.h	Thu Oct 13 10:17:59 2016 +0100
+++ b/vamp-client/PiperClient.h	Thu Oct 13 11:33:19 2016 +0100
@@ -26,6 +26,11 @@
 
     virtual Vamp::Plugin::FeatureSet
     finish(PiperStubPlugin *plugin) = 0;
+
+    virtual
+    void
+    reset(PiperStubPlugin *plugin,
+          Vamp::HostExt::PluginConfiguration config) = 0;
 };
 
 }
--- a/vamp-client/PiperStubPlugin.h	Thu Oct 13 10:17:59 2016 +0100
+++ b/vamp-client/PiperStubPlugin.h	Thu Oct 13 11:33:19 2016 +0100
@@ -21,11 +21,15 @@
     
 public:
     PiperStubPlugin(PiperStubPluginClientInterface *client,
+                    std::string pluginKey,
                     float inputSampleRate,
+                    int adapterFlags,
                     Vamp::HostExt::PluginStaticData psd,
                     Vamp::HostExt::PluginConfiguration defaultConfig) :
         Plugin(inputSampleRate),
         m_client(client),
+        m_key(pluginKey),
+        m_adapterFlags(adapterFlags),
         m_state(Loaded),
         m_psd(psd),
         m_defaultConfig(defaultConfig),
@@ -37,7 +41,7 @@
 	    (void)m_client->finish(this);
         }
     }
-
+    
     virtual std::string getIdentifier() const {
         return m_psd.basic.identifier;
     }
@@ -119,8 +123,15 @@
     }
 
     virtual void reset() {
-        //!!! hm, how to deal with this? there is no reset() in Piper!
-        throw "Please do not call this function again.";
+        
+        if (m_state == Loaded) {
+            // reset is a no-op if the plugin hasn't been initialised yet
+            return;
+        }
+        
+        m_client->reset(this, m_config);
+
+        m_state = Configured;
     }
 
     virtual InputDomain getInputDomain() const {
@@ -198,9 +209,25 @@
 
         return m_client->finish(this);
     }
+
+    // Not Plugin methods, but needed by the PiperClient to support reloads:
+    
+    virtual float getInputSampleRate() const {
+        return m_inputSampleRate;
+    }
+
+    virtual std::string getPluginKey() const {
+        return m_key;
+    }
+
+    virtual int getAdapterFlags() const {
+        return m_adapterFlags;
+    }
     
 private:
     PiperStubPluginClientInterface *m_client;
+    std::string m_key;
+    int m_adapterFlags;
     State m_state;
     Vamp::HostExt::PluginStaticData m_psd;
     OutputList m_outputs;
--- a/vamp-server/convert.cpp	Thu Oct 13 10:17:59 2016 +0100
+++ b/vamp-server/convert.cpp	Thu Oct 13 11:33:19 2016 +0100
@@ -8,6 +8,8 @@
 #include <sstream>
 #include <stdexcept>
 
+#include <capnp/serialize.h>
+
 using namespace std;
 using namespace json11;
 using namespace piper;
--- a/vamp-server/server.cpp	Thu Oct 13 10:17:59 2016 +0100
+++ b/vamp-server/server.cpp	Thu Oct 13 11:33:19 2016 +0100
@@ -7,6 +7,8 @@
 #include <sstream>
 #include <stdexcept>
 
+#include <capnp/serialize.h>
+
 #include <map>
 #include <set>
 
@@ -175,6 +177,7 @@
 	response.loadResponse = loader->loadPlugin(request.loadRequest);
 	if (response.loadResponse.plugin != nullptr) {
 	    mapper.addPlugin(response.loadResponse.plugin);
+            cerr << "loaded plugin, handle = " << mapper.pluginToHandle(response.loadResponse.plugin) << endl;
 	    response.success = true;
 	}
 	break;
--- a/vamp-support/AssignedPluginHandleMapper.h	Thu Oct 13 10:17:59 2016 +0100
+++ b/vamp-support/AssignedPluginHandleMapper.h	Thu Oct 13 11:33:19 2016 +0100
@@ -77,6 +77,10 @@
 	}
 	m_rplugins.erase(p);
     }
+
+    bool havePlugin(Vamp::Plugin *p) {
+        return (m_rplugins.find(p) != m_rplugins.end());
+    }
     
     Handle pluginToHandle(Vamp::Plugin *p) const noexcept {
 	if (m_rplugins.find(p) == m_rplugins.end()) {