Mercurial > hg > piper-cpp
view vamp-server/server.cpp @ 116:d15cb1151d76
Add JSON support directly to the server. Had hoped to avoid this (using Capnp as canonical in the server and then converting externally as necessary) but it's just too useful for debugging purposes when bundled with client app
author | Chris Cannam <c.cannam@qmul.ac.uk> |
---|---|
date | Thu, 27 Oct 2016 11:39:41 +0100 |
parents | b418b583fd3b |
children | ff3fd8d1b2dc |
line wrap: on
line source
/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */

#include "vamp-json/VampJson.h"
#include "vamp-capnp/VampnProto.h"
#include "vamp-support/RequestOrResponse.h"
#include "vamp-support/CountingPluginHandleMapper.h"
#include "vamp-support/LoaderRequests.h"

#include <cstdlib>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

#include <capnp/serialize.h>

#include <map>
#include <set>
#include <vector>

// pid for logging
#ifdef _WIN32
#include <process.h>
static int pid = _getpid();
#else
#include <unistd.h>
static int pid = getpid();
#endif

using namespace std;
using namespace json11;
using namespace piper_vamp;
using namespace Vamp;

//!!! This could be faster and lighter:
//   - Use Capnp structures directly rather than converting to vamp-support ones
//   - Use Vamp C API (vamp.h) directly rather than converting to C++
//!!! Doing the above for process() and finish() alone would be a good start

// Program name used as a prefix in usage text and debug log lines.
static string myname = "piper-vamp-server";

// Print the version number to stdout and exit with status 0.
static void version()
{
    cout << "1.0" << endl;
    exit(0);
}

// Print the usage text to stderr and exit. The exit status is 0 if
// "successful" is true (help was explicitly requested) and 2
// otherwise. This function does not return.
static void usage(bool successful = false)
{
    cerr << "\n" << myname <<
        ": Load and run Vamp plugins in response to Piper messages\n\n"
        " Usage: " << myname << " [-d] <format>\n"
        " " << myname << " -v\n"
        " " << myname << " -h\n\n"
        " where\n"
        " <format>: the format to read and write messages in (\"json\" or \"capnp\")\n"
        " -d: also print debug information to stderr\n"
        " -v: print version number to stdout and exit\n"
        " -h: print this text to stderr and exit\n\n"
        "Expects Piper request messages in either Cap'n Proto or JSON format on stdin,\n"
        "and writes response messages in the same format to stdout.\n\n";
    if (successful) exit(0);
    else exit(2);
}

// Process-wide mapping between plugin pointers and the handles used
// in messages. It also records, per handle, whether the plugin has
// been configured and with what channel count and block size.
static CountingPluginHandleMapper mapper;

// Extract the request id (number, tag, or none) from a Cap'n Proto
// request. Returns a default-constructed id if the union holds a
// discriminant that is not handled here.
static RequestOrResponse::RpcId readId(const piper::RpcRequest::Reader &r)
{
    int number;
    string tag;
    switch (r.getId().which()) {
    case piper::RpcRequest::Id::Which::NUMBER:
        number = r.getId().getNumber();
        return { RequestOrResponse::RpcId::Number, number, "" };
    case piper::RpcRequest::Id::Which::TAG:
        tag = r.getId().getTag();
        return { RequestOrResponse::RpcId::Tag, 0, tag };
    case piper::RpcRequest::Id::Which::NONE:
        return { RequestOrResponse::RpcId::Absent, 0, "" };
    }
    return {};
}

// Copy the given id (number, tag, or none) into a Cap'n Proto
// response under construction.
static void buildId(piper::RpcResponse::Builder &b,
                    const RequestOrResponse::RpcId &id)
{
    switch (id.type) {
    case RequestOrResponse::RpcId::Number:
        b.getId().setNumber(id.number);
        break;
    case RequestOrResponse::RpcId::Tag:
        b.getId().setTag(id.tag);
        break;
    case RequestOrResponse::RpcId::Absent:
        b.getId().setNone();
        break;
    }
}

// Extract the request id from the "id" field of a JSON request. A
// numeric id gives a Number id, a string id gives a Tag id, and
// anything else (including a missing field) gives an Absent id.
static RequestOrResponse::RpcId readJsonId(const Json &j)
{
    RequestOrResponse::RpcId id;

    if (j["id"].is_number()) {
        id.type = RequestOrResponse::RpcId::Number;
        id.number = j["id"].number_value();
    } else if (j["id"].is_string()) {
        id.type = RequestOrResponse::RpcId::Tag;
        id.tag = j["id"].string_value();
    } else {
        id.type = RequestOrResponse::RpcId::Absent;
    }

    return id;
}

// Convert an id to its JSON representation: a number, a string, or
// a default-constructed Json value if the id is absent.
static Json writeJsonId(const RequestOrResponse::RpcId &id)
{
    if (id.type == RequestOrResponse::RpcId::Number) {
        return id.number;
    } else if (id.type == RequestOrResponse::RpcId::Tag) {
        return id.tag;
    } else {
        return Json();
    }
}

// Parse a single JSON request from the given text and check its
// top-level shape: it must be an object with a string "method"
// field and, if "params" is present, an object "params" field.
//
// On a parse failure, err is set to a message prefixed with
// "invalid json: " and an empty Json is returned. On a shape
// failure, err is set but the parsed value is still returned, so
// callers must check err rather than the return value.
static Json convertRequestJson(string input, string &err)
{
    Json j = Json::parse(input, err);
    if (err != "") {
        err = "invalid json: " + err;
        return {};
    }
    if (!j.is_object()) {
        err = "object expected at top level";
    } else if (!j["method"].is_string()) {
        err = "string expected for method field";
    } else if (!j["params"].is_null() && !j["params"].is_object()) {
        err = "object expected for params field";
    }
    return j;
}

// Read one line from stdin and interpret it as a JSON request.
//
// If no line can be read, returns a request of type NotValid with
// err left untouched (this is the EOF case, not an error). If the
// line cannot be parsed or classified, err is set and a
// default-constructed value is returned. The per-type conversion
// functions may also set err; the caller is expected to check it.
RequestOrResponse readRequestJson(string &err)
{
    RequestOrResponse rr;
    rr.direction = RequestOrResponse::Request;

    string input;
    if (!getline(cin, input)) {
        // the EOF case, not actually an error
        rr.type = RRType::NotValid;
        return rr;
    }

    Json j = convertRequestJson(input, err);
    if (err != "") return {};

    rr.type = VampJson::getRequestResponseType(j, err);
    if (err != "") return {};

    rr.id = readJsonId(j);

    VampJson::BufferSerialisation serialisation =
        VampJson::BufferSerialisation::Array;

    switch (rr.type) {

    case RRType::List:
        VampJson::toRpcRequest_List(j, err); // type check only
        break;
    case RRType::Load:
        rr.loadRequest = VampJson::toRpcRequest_Load(j, err);
        break;
    case RRType::Configure:
        rr.configurationRequest =
            VampJson::toRpcRequest_Configure(j, mapper, err);
        break;
    case RRType::Process:
        rr.processRequest =
            VampJson::toRpcRequest_Process(j, mapper, serialisation, err);
        break;
    case RRType::Finish:
        rr.finishRequest =
            VampJson::toRpcRequest_Finish(j, mapper, err);
        break;
    case RRType::NotValid:
        break;
    }

    return rr;
}

// Write the given response to stdout as JSON followed by a newline.
// If rr.success is false an error response is written instead of
// the type-specific one. useBase64 selects Base64 rather than Array
// buffer serialisation for the feature data in process and finish
// responses.
void writeResponseJson(RequestOrResponse &rr, bool useBase64)
{
    Json j;

    VampJson::BufferSerialisation serialisation =
        (useBase64 ?
         VampJson::BufferSerialisation::Base64 :
         VampJson::BufferSerialisation::Array);

    Json id = writeJsonId(rr.id);

    if (!rr.success) {

        j = VampJson::fromError(rr.errorText, rr.type, id);

    } else {

        switch (rr.type) {

        case RRType::List:
            j = VampJson::fromRpcResponse_List(rr.listResponse, id);
            break;
        case RRType::Load:
            j = VampJson::fromRpcResponse_Load(rr.loadResponse, mapper, id);
            break;
        case RRType::Configure:
            j = VampJson::fromRpcResponse_Configure(rr.configurationResponse,
                                                    mapper, id);
            break;
        case RRType::Process:
            j = VampJson::fromRpcResponse_Process
                (rr.processResponse, mapper, serialisation, id);
            break;
        case RRType::Finish:
            j = VampJson::fromRpcResponse_Finish
                (rr.finishResponse, mapper, serialisation, id);
            break;
        case RRType::NotValid:
            break;
        }
    }

    // endl flushes the stream, so the response is pushed out as soon
    // as it has been written
    cout << j.dump() << endl;
}

// Write a JSON error response built from the given exception's
// what() text to stdout. No request id is attached: a
// default-constructed Json is passed as the id.
void writeExceptionJson(const std::exception &e, RRType type)
{
    Json j = VampJson::fromError(e.what(), type, Json());
    cout << j.dump() << endl;
}

// Read one Cap'n Proto request message from file descriptor 0.
//
// The input streams are function-local statics, so the buffered
// wrapper persists from one call to the next. If no read buffer is
// available, a request of type NotValid is returned (the caller
// treats NotValid without an exception as EOF).
RequestOrResponse readRequestCapnp()
{
    RequestOrResponse rr;
    rr.direction = RequestOrResponse::Request;

    static kj::FdInputStream stream(0); // stdin
    static kj::BufferedInputStreamWrapper buffered(stream);

    if (buffered.tryGetReadBuffer() == nullptr) {
        rr.type = RRType::NotValid;
        return rr;
    }

    capnp::InputStreamMessageReader message(buffered);
    piper::RpcRequest::Reader reader = message.getRoot<piper::RpcRequest>();

    rr.type = VampnProto::getRequestResponseType(reader);
    rr.id = readId(reader);

    switch (rr.type) {

    case RRType::List:
        VampnProto::readRpcRequest_List(reader); // type check only
        break;
    case RRType::Load:
        VampnProto::readRpcRequest_Load(rr.loadRequest, reader);
        break;
    case RRType::Configure:
        VampnProto::readRpcRequest_Configure(rr.configurationRequest,
                                             reader, mapper);
        break;
    case RRType::Process:
        VampnProto::readRpcRequest_Process(rr.processRequest, reader, mapper);
        break;
    case RRType::Finish:
        VampnProto::readRpcRequest_Finish(rr.finishRequest, reader, mapper);
        break;
    case RRType::NotValid:
        break;
    }

    return rr;
}

// Write the given response to file descriptor 1 as a Cap'n Proto
// message. If rr.success is false an error response is built from
// rr.errorText instead of the type-specific one.
void writeResponseCapnp(RequestOrResponse &rr)
{
    capnp::MallocMessageBuilder message;
    piper::RpcResponse::Builder builder = message.initRoot<piper::RpcResponse>();

    buildId(builder, rr.id);

    if (!rr.success) {

        VampnProto::buildRpcResponse_Error(builder, rr.errorText, rr.type);

    } else {

        switch (rr.type) {

        case RRType::List:
            VampnProto::buildRpcResponse_List(builder, rr.listResponse);
            break;
        case RRType::Load:
            VampnProto::buildRpcResponse_Load(builder, rr.loadResponse, mapper);
            break;
        case RRType::Configure:
            VampnProto::buildRpcResponse_Configure(builder,
                                                   rr.configurationResponse,
                                                   mapper);
            break;
        case RRType::Process:
            VampnProto::buildRpcResponse_Process(builder, rr.processResponse,
                                                 mapper);
            break;
        case RRType::Finish:
            VampnProto::buildRpcResponse_Finish(builder, rr.finishResponse,
                                                mapper);
            break;
        case RRType::NotValid:
            break;
        }
    }

    writeMessageToFd(1, message);
}

// Write a Cap'n Proto exception response built from the given
// exception to file descriptor 1. No request id is set on it.
void writeExceptionCapnp(const std::exception &e, RRType type)
{
    capnp::MallocMessageBuilder message;
    piper::RpcResponse::Builder builder = message.initRoot<piper::RpcResponse>();
    VampnProto::buildRpcResponse_Exception(builder, e, type);

    writeMessageToFd(1, message);
}

// Carry out the given request and return the corresponding
// response. The response has the same type as the request;
// response.success is set to true only on the paths that complete
// successfully. The response id is not filled in here.
//
// Throws runtime_error if a Configure request names a plugin that
// is already configured, or if a Process request names a plugin
// that is not configured or supplies the wrong number of channels
// or the wrong block size.
//
// If debug is true, a line is logged to stderr when a plugin is
// loaded.
RequestOrResponse handleRequest(const RequestOrResponse &request, bool debug)
{
    RequestOrResponse response;
    response.direction = RequestOrResponse::Response;
    response.type = request.type;

    switch (request.type) {

    case RRType::List:
        response.listResponse = LoaderRequests().listPluginData();
        response.success = true;
        break;

    case RRType::Load:
        response.loadResponse = LoaderRequests().loadPlugin(request.loadRequest);
        if (response.loadResponse.plugin != nullptr) {
            mapper.addPlugin(response.loadResponse.plugin);
            if (debug) {
                cerr << myname << " " << pid
                     << ": loaded plugin, handle = "
                     << mapper.pluginToHandle(response.loadResponse.plugin)
                     << endl;
            }
            response.success = true;
        }
        break;

    case RRType::Configure:
    {
        auto &creq = request.configurationRequest;
        auto h = mapper.pluginToHandle(creq.plugin);
        if (mapper.isConfigured(h)) {
            throw runtime_error("plugin has already been configured");
        }

        response.configurationResponse = LoaderRequests().configurePlugin(creq);

        // An empty output list is treated as a failed configuration:
        // the plugin is not marked configured and success is not set
        if (!response.configurationResponse.outputs.empty()) {
            mapper.markConfigured
                (h, creq.configuration.channelCount, creq.configuration.blockSize);
            response.success = true;
        }
        break;
    }

    case RRType::Process:
    {
        auto &preq = request.processRequest;
        auto h = mapper.pluginToHandle(preq.plugin);
        if (!mapper.isConfigured(h)) {
            throw runtime_error("plugin has not been configured");
        }

        int channels = int(preq.inputBuffers.size());

        if (channels != mapper.getChannelCount(h)) {
            throw runtime_error("wrong number of channels supplied to process");
        }

        // One pointer per channel, each referring to the sample data
        // held in the request. The vector owns the pointer array, so
        // it is released on every exit path, including when one of
        // the checks below or the plugin's process() call throws.
        vector<const float *> fbuffers(channels);
        for (int i = 0; i < channels; ++i) {
            if (int(preq.inputBuffers[i].size()) != mapper.getBlockSize(h)) {
                throw runtime_error("wrong block size supplied to process");
            }
            fbuffers[i] = preq.inputBuffers[i].data();
        }

        response.processResponse.plugin = preq.plugin;
        response.processResponse.features =
            preq.plugin->process(fbuffers.data(), preq.timestamp);
        response.success = true;
        break;
    }

    case RRType::Finish:
    {
        auto &freq = request.finishRequest;
        response.finishResponse.plugin = freq.plugin;

        auto h = mapper.pluginToHandle(freq.plugin);

        // Finish can be called (to unload the plugin) even if the
        // plugin has never been configured or used. But we want to
        // make sure we call getRemainingFeatures only if we have
        // actually configured the plugin.
        if (mapper.isConfigured(h)) {
            response.finishResponse.features =
                freq.plugin->getRemainingFeatures();
        }

        // We do not delete the plugin here -- we need it in the
        // mapper when converting the features. It gets deleted in the
        // calling function.
        response.success = true;
        break;
    }

    case RRType::NotValid:
        break;
    }

    return response;
}

// Read one request from stdin in the given format ("capnp" or
// "json"). Throws runtime_error if the format is unknown or if the
// JSON reader reports an error.
RequestOrResponse readRequest(string format)
{
    if (format == "capnp") {
        return readRequestCapnp();
    } else if (format == "json") {
        string err;
        auto result = readRequestJson(err);
        if (err != "") throw runtime_error(err);
        else return result;
    } else {
        throw runtime_error("unknown input format \"" + format + "\"");
    }
}

// Write the given response to stdout in the given format ("capnp"
// or "json"). JSON responses are written with Array rather than
// Base64 buffer serialisation. Throws runtime_error if the format
// is unknown.
void writeResponse(string format, RequestOrResponse &rr)
{
    if (format == "capnp") {
        writeResponseCapnp(rr);
    } else if (format == "json") {
        writeResponseJson(rr, false);
    } else {
        throw runtime_error("unknown output format \"" + format + "\"");
    }
}

// Write an error response describing the given exception to stdout
// in the given format ("capnp" or "json"). Throws runtime_error if
// the format is unknown.
void writeException(string format, const std::exception &e, RRType type)
{
    if (format == "capnp") {
        writeExceptionCapnp(e, type);
    } else if (format == "json") {
        writeExceptionJson(e, type);
    } else {
        throw runtime_error("unknown output format \"" + format + "\"");
    }
}

// Entry point. Parses the command line (see usage()), then reads
// requests, handles them, and writes responses in a loop until EOF
// is reached (exit status 0) or an exception is thrown (an error
// response is written, then exit status 1).
int main(int argc, char **argv)
{
    if (argc != 2 && argc != 3) {
        usage();
    }

    bool debug = false;

    string arg = argv[1];
    if (arg == "-h") {
        if (argc == 2) {
            usage(true);
        } else {
            usage();
        }
    } else if (arg == "-v") {
        if (argc == 2) {
            version();
        } else {
            usage();
        }
    } else if (arg == "-d") {
        if (argc == 2) {
            usage();
        } else {
            debug = true;
            arg = argv[2];
        }
    }

    string format = arg;

    if (format != "capnp" && format != "json") {
        usage();
    }

    if (debug) {
        cerr << myname << " " << pid << ": waiting for format: " << format << endl;
    }

    while (true) {

        // Declared outside the try block so that the catch handler
        // can report the type of the request that failed
        RequestOrResponse request;

        try {

            request = readRequest(format);

            // NotValid without an exception indicates EOF:
            if (request.type == RRType::NotValid) {
                if (debug) {
                    cerr << myname << " " << pid
                         << ": eof reached, exiting" << endl;
                }
                break;
            }

            if (debug) {
                cerr << myname << " " << pid
                     << ": request received, of type "
                     << int(request.type)
                     << endl;
            }

            RequestOrResponse response = handleRequest(request, debug);
            response.id = request.id;

            if (debug) {
                cerr << myname << " " << pid
                     << ": request handled, writing response"
                     << endl;
            }

            writeResponse(format, response);

            if (debug) {
                cerr << myname << " " << pid << ": response written" << endl;
            }

            // The plugin must stay registered in the mapper until the
            // finish response has been written, so it is only removed
            // and deleted here, after writeResponse
            if (request.type == RRType::Finish) {
                auto h = mapper.pluginToHandle(request.finishRequest.plugin);
                if (debug) {
                    cerr << myname << " " << pid
                         << ": deleting the plugin with handle " << h << endl;
                }
                mapper.removePlugin(h);
                delete request.finishRequest.plugin;
            }

        } catch (const std::exception &e) {

            if (debug) {
                cerr << myname << " " << pid << ": error: " << e.what() << endl;
            }

            writeException(format, e, request.type);

            exit(1);
        }
    }

    exit(0);
}