c@31: c@31: #include "VampnProto.h" c@31: c@31: #include "bits/RequestOrResponse.h" c@31: c@31: #include c@31: #include c@31: #include c@31: c@32: #include c@32: #include c@32: c@31: using namespace std; c@31: using namespace vampipe; c@32: using namespace Vamp; c@32: using namespace Vamp::HostExt; c@31: c@31: void usage() c@31: { c@31: string myname = "vampipe-server"; c@31: cerr << "\n" << myname << c@31: ": Load and run Vamp plugins in response to messages from stdin\n\n" c@31: " Usage: " << myname << "\n\n" c@31: "Expects Vamp request messages in Cap'n Proto packed format on stdin,\n" c@31: "and writes Vamp response messages in the same format to stdout.\n\n"; c@31: c@31: exit(2); c@31: } c@31: c@32: class Mapper : public PluginHandleMapper c@32: { c@32: public: c@32: Mapper() : m_nextHandle(1) { } c@32: c@32: void addPlugin(Plugin *p) { c@32: if (m_rplugins.find(p) == m_rplugins.end()) { c@32: int32_t h = m_nextHandle++; c@32: m_plugins[h] = p; c@32: m_rplugins[p] = h; c@32: } c@32: } c@33: c@33: void removePlugin(int32_t h) { c@33: if (m_plugins.find(h) == m_plugins.end()) { c@33: throw NotFound(); c@33: } c@33: Plugin *p = m_plugins[h]; c@33: m_plugins.erase(h); c@34: if (isConfigured(h)) { c@34: m_configuredPlugins.erase(h); c@34: m_channelCounts.erase(h); c@34: } c@33: m_rplugins.erase(p); c@33: } c@32: c@32: int32_t pluginToHandle(Plugin *p) { c@32: if (m_rplugins.find(p) == m_rplugins.end()) { c@32: throw NotFound(); c@32: } c@32: return m_rplugins[p]; c@32: } c@32: c@32: Plugin *handleToPlugin(int32_t h) { c@32: if (m_plugins.find(h) == m_plugins.end()) { c@32: throw NotFound(); c@32: } c@32: return m_plugins[h]; c@32: } c@32: c@33: bool isConfigured(int32_t h) { c@33: return m_configuredPlugins.find(h) != m_configuredPlugins.end(); c@32: } c@32: c@34: void markConfigured(int32_t h, int channelCount, int blockSize) { c@33: m_configuredPlugins.insert(h); c@34: m_channelCounts[h] = channelCount; c@34: m_blockSizes[h] = blockSize; c@34: } c@34: c@34: int getChannelCount(int32_t h) { c@34: if (m_channelCounts.find(h) == m_channelCounts.end()) { c@34: throw NotFound(); c@34: } c@34: return m_channelCounts[h]; c@34: } c@34: c@34: int getBlockSize(int32_t h) { c@34: if (m_blockSizes.find(h) == m_blockSizes.end()) { c@34: throw NotFound(); c@34: } c@34: return m_blockSizes[h]; c@32: } c@32: c@32: private: c@32: int32_t m_nextHandle; // NB plugin handle type must fit in JSON number c@32: map m_plugins; c@32: map m_rplugins; c@33: set m_configuredPlugins; c@34: map m_channelCounts; c@34: map m_blockSizes; c@32: }; c@32: c@32: static Mapper mapper; c@32: c@31: RequestOrResponse c@31: readRequestCapnp() c@31: { c@31: RequestOrResponse rr; c@31: rr.direction = RequestOrResponse::Request; c@31: c@33: static kj::FdInputStream stream(0); // stdin c@33: static kj::BufferedInputStreamWrapper buffered(stream); c@33: c@33: if (buffered.tryGetReadBuffer() == nullptr) { c@33: rr.type = RRType::NotValid; c@33: return rr; c@33: } c@33: c@33: ::capnp::InputStreamMessageReader message(buffered); c@31: VampRequest::Reader reader = message.getRoot(); c@31: c@31: rr.type = VampnProto::getRequestResponseType(reader); c@31: c@31: switch (rr.type) { c@31: c@31: case RRType::List: c@31: VampnProto::readVampRequest_List(reader); // type check only c@31: break; c@31: case RRType::Load: c@31: VampnProto::readVampRequest_Load(rr.loadRequest, reader); c@31: break; c@31: case RRType::Configure: c@32: VampnProto::readVampRequest_Configure(rr.configurationRequest, c@32: reader, mapper); c@31: break; c@31: case RRType::Process: c@32: VampnProto::readVampRequest_Process(rr.processRequest, reader, mapper); c@31: break; c@31: case RRType::Finish: c@32: VampnProto::readVampRequest_Finish(rr.finishPlugin, reader, mapper); c@31: break; c@31: case RRType::NotValid: c@31: break; c@31: } c@31: c@31: return rr; c@31: } c@31: c@31: void c@31: writeResponseCapnp(RequestOrResponse &rr) c@31: { c@31: ::capnp::MallocMessageBuilder message; c@31: VampResponse::Builder builder = message.initRoot(); c@31: c@31: switch (rr.type) { c@31: c@31: case RRType::List: c@31: VampnProto::buildVampResponse_List(builder, "", rr.listResponse); c@31: break; c@31: case RRType::Load: c@32: VampnProto::buildVampResponse_Load(builder, rr.loadResponse, mapper); c@31: break; c@31: case RRType::Configure: c@31: VampnProto::buildVampResponse_Configure(builder, rr.configurationResponse); c@31: break; c@31: case RRType::Process: c@31: VampnProto::buildVampResponse_Process(builder, rr.processResponse); c@31: break; c@31: case RRType::Finish: c@31: VampnProto::buildVampResponse_Finish(builder, rr.finishResponse); c@31: break; c@31: case RRType::NotValid: c@31: break; c@31: } c@31: c@33: writeMessageToFd(1, message); c@31: } c@31: c@31: RequestOrResponse c@34: handleRequest(const RequestOrResponse &request) c@31: { c@31: RequestOrResponse response; c@31: response.direction = RequestOrResponse::Response; c@32: response.type = request.type; c@32: c@32: auto loader = PluginLoader::getInstance(); c@32: c@32: switch (request.type) { c@32: c@32: case RRType::List: c@32: response.listResponse = loader->listPluginData(); c@32: response.success = true; c@32: break; c@32: c@32: case RRType::Load: c@32: response.loadResponse = loader->loadPlugin(request.loadRequest); c@32: if (response.loadResponse.plugin != nullptr) { c@32: mapper.addPlugin(response.loadResponse.plugin); c@32: response.success = true; c@32: } c@32: break; c@32: c@33: case RRType::Configure: c@33: { c@34: auto &creq = request.configurationRequest; c@34: auto h = mapper.pluginToHandle(creq.plugin); c@33: if (mapper.isConfigured(h)) { c@33: throw runtime_error("plugin has already been configured"); c@33: } c@33: c@34: response.configurationResponse = loader->configurePlugin(creq); c@33: c@33: if (!response.configurationResponse.outputs.empty()) { c@34: mapper.markConfigured c@34: (h, creq.configuration.channelCount, creq.configuration.blockSize); c@33: response.success = true; c@33: } c@33: break; c@33: } c@33: c@33: case RRType::Process: c@33: { c@33: auto &preq = request.processRequest; c@34: auto h = mapper.pluginToHandle(preq.plugin); c@34: if (!mapper.isConfigured(h)) { c@34: throw runtime_error("plugin has not been configured"); c@34: } c@34: c@33: int channels = int(preq.inputBuffers.size()); c@34: if (channels != mapper.getChannelCount(h)) { c@34: throw runtime_error("wrong number of channels supplied to process"); c@34: } c@34: c@33: const float **fbuffers = new const float *[channels]; c@33: for (int i = 0; i < channels; ++i) { c@34: if (int(preq.inputBuffers[i].size()) != mapper.getBlockSize(h)) { c@34: delete[] fbuffers; c@34: throw runtime_error("wrong block size supplied to process"); c@34: } c@33: fbuffers[i] = preq.inputBuffers[i].data(); c@33: } c@33: c@33: response.processResponse.features = c@33: preq.plugin->process(fbuffers, preq.timestamp); c@33: response.success = true; c@33: c@33: delete[] fbuffers; c@33: break; c@33: } c@33: c@33: case RRType::Finish: c@33: { c@33: auto h = mapper.pluginToHandle(request.finishPlugin); c@33: c@33: response.finishResponse.features = c@33: request.finishPlugin->getRemainingFeatures(); c@33: c@33: mapper.removePlugin(h); c@33: delete request.finishPlugin; c@33: response.success = true; c@33: break; c@33: } c@33: c@33: case RRType::NotValid: c@33: break; c@32: } c@32: c@31: return response; c@31: } c@31: c@31: int main(int argc, char **argv) c@31: { c@31: if (argc != 1) { c@31: usage(); c@31: } c@31: c@31: while (true) { c@31: c@31: try { c@31: c@31: RequestOrResponse request = readRequestCapnp(); c@31: c@33: cerr << "vampipe-server: request received, of type " c@33: << int(request.type) c@33: << endl; c@33: c@31: // NotValid without an exception indicates EOF: c@33: if (request.type == RRType::NotValid) { c@33: cerr << "vampipe-server: eof reached" << endl; c@33: break; c@33: } c@31: c@34: RequestOrResponse response = handleRequest(request); c@33: c@34: cerr << "vampipe-server: request handled, writing response" c@33: << endl; c@31: c@31: writeResponseCapnp(response); c@33: c@33: cerr << "vampipe-server: response written" << endl; c@31: c@31: } catch (std::exception &e) { c@33: cerr << "vampipe-server: error: " << e.what() << endl; c@31: exit(1); c@31: } c@31: } c@31: c@31: exit(0); c@31: }