view vamp-server/server.cpp @ 117:5dffc5147176

Small simplification
author Chris Cannam <c.cannam@qmul.ac.uk>
date Thu, 27 Oct 2016 11:40:57 +0100
parents d15cb1151d76
children ff3fd8d1b2dc
line wrap: on
line source
/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*-  vi:set ts=8 sts=4 sw=4: */

#include "vamp-json/VampJson.h"
#include "vamp-capnp/VampnProto.h"
#include "vamp-support/RequestOrResponse.h"
#include "vamp-support/CountingPluginHandleMapper.h"
#include "vamp-support/LoaderRequests.h"

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

#include <capnp/serialize.h>

#include <map>
#include <set>

// Process id of this server, captured once during static initialisation.
// It is only used to tag the debug/log lines written to stderr, so that
// output from several server processes can be told apart.
#ifdef _WIN32
#include <process.h>
static int pid = _getpid();
#else
#include <unistd.h>
static int pid = getpid();
#endif

using namespace std;
using namespace json11;
using namespace piper_vamp;
using namespace Vamp;

//!!! This could be faster and lighter:
//  - Use Capnp structures directly rather than converting to vamp-support ones
//  - Use Vamp C API (vamp.h) directly rather than converting to C++
//!!! Doing the above for process() and finish() alone would be a good start

// Program name, used as a prefix in the usage text and in debug output.
static string myname = "piper-vamp-server";

// Print the server's version number on stdout and terminate the
// process with a successful exit status. Does not return.
static void version()
{
    static const char *const versionString = "1.0";
    cout << versionString << endl;
    exit(0);
}

// Print the command-line help text to stderr and terminate the process.
// Does not return.
//
// successful: when true the process exits with status 0 (help was
//             explicitly requested); otherwise it exits with status 2
//             (the command line was wrong).
static void usage(bool successful = false)
{
    cerr << "\n" << myname <<
        ": Load and run Vamp plugins in response to Piper messages\n\n"
        "    Usage: " << myname << " [-d] <format>\n"
        "           " << myname << " -v\n"
        "           " << myname << " -h\n\n"
        "    where\n"
        "       <format>: the format to read and write messages in (\"json\" or \"capnp\")\n"
        "       -d: also print debug information to stderr\n"
        "       -v: print version number to stdout and exit\n"
        "       -h: print this text to stderr and exit\n\n"
        "Expects Piper request messages in either Cap'n Proto or JSON format on stdin,\n"
        "and writes response messages in the same format to stdout.\n\n";

    exit(successful ? 0 : 2);
}

// Single shared table relating loaded plugins to the handles used in
// messages. The request readers, the response writers and
// handleRequest() all consult it; it also records whether a plugin has
// been configured and with which channel count and block size.
static CountingPluginHandleMapper mapper;

// Extract the RPC id from a Cap'n Proto request.
//
// The id is a union of a number, a string tag, or nothing at all; the
// result carries the matching RpcId type with the unused fields set to
// 0 / "". A default-constructed RpcId is returned if the union holds
// none of the known alternatives.
static RequestOrResponse::RpcId
readId(const piper::RpcRequest::Reader &r)
{
    using Which = piper::RpcRequest::Id::Which;

    const auto idReader = r.getId();

    // Declared up front and assigned in the case bodies, exactly as the
    // value conversions from the capnp reader types require.
    int numericId;
    string tagId;

    switch (idReader.which()) {
    case Which::NUMBER:
        numericId = idReader.getNumber();
        return { RequestOrResponse::RpcId::Number, numericId, "" };
    case Which::TAG:
        tagId = idReader.getTag();
        return { RequestOrResponse::RpcId::Tag, 0, tagId };
    case Which::NONE:
        return { RequestOrResponse::RpcId::Absent, 0, "" };
    }

    return {};
}

// Write an RPC id into a Cap'n Proto response, choosing the union
// member (number, tag or none) that corresponds to the id's type.
static void
buildId(piper::RpcResponse::Builder &b, const RequestOrResponse::RpcId &id)
{
    const auto kind = id.type;

    if (kind == RequestOrResponse::RpcId::Number) {
        b.getId().setNumber(id.number);
    } else if (kind == RequestOrResponse::RpcId::Tag) {
        b.getId().setTag(id.tag);
    } else if (kind == RequestOrResponse::RpcId::Absent) {
        b.getId().setNone();
    }
}

// Extract the RPC id from the "id" member of a JSON request.
//
// A numeric id yields a Number id, a string yields a Tag id, and
// anything else (including a missing member) yields an Absent id.
static RequestOrResponse::RpcId
readJsonId(const Json &j)
{
    RequestOrResponse::RpcId result;
    const Json &field = j["id"];

    if (field.is_number()) {
        result.type = RequestOrResponse::RpcId::Number;
        result.number = field.number_value();
        return result;
    }

    if (field.is_string()) {
        result.type = RequestOrResponse::RpcId::Tag;
        result.tag = field.string_value();
        return result;
    }

    result.type = RequestOrResponse::RpcId::Absent;
    return result;
}

// Convert an RPC id to the JSON value that should appear as the "id"
// member of a response: a number, a string, or JSON null when the id
// is absent.
static Json
writeJsonId(const RequestOrResponse::RpcId &id)
{
    switch (id.type) {
    case RequestOrResponse::RpcId::Number:
        return id.number;
    case RequestOrResponse::RpcId::Tag:
        return id.tag;
    default:
        return Json();
    }
}

// Parse one line of text as a JSON request and check its basic shape.
//
// input: the raw JSON text.
// err:   receives a description of the problem, if any. It is set
//        when the text is not valid JSON, when the top level is not an
//        object, when "method" is not a string, or when "params" is
//        present but not an object.
//
// Returns the parsed document. On a parse failure a null Json is
// returned; on a shape failure the parsed document is still returned
// and the caller must consult err.
static Json
convertRequestJson(string input, string &err)
{
    Json parsed = Json::parse(input, err);

    if (!err.empty()) {
        err = "invalid json: " + err;
        return {};
    }

    if (!parsed.is_object()) {
        err = "object expected at top level";
        return parsed;
    }

    if (!parsed["method"].is_string()) {
        err = "string expected for method field";
        return parsed;
    }

    const Json &params = parsed["params"];
    if (!(params.is_null() || params.is_object())) {
        err = "object expected for params field";
    }

    return parsed;
}

// Read one JSON request (a single line) from stdin and convert it to
// a RequestOrResponse.
//
// err: receives an error description if the line could not be parsed
//      or converted; the caller must check it, since the returned
//      object is not meaningful in that case.
//
// Returns a request of type RRType::NotValid, with err left untouched,
// when stdin has reached end of file.
RequestOrResponse
readRequestJson(string &err)
{
    RequestOrResponse request;
    request.direction = RequestOrResponse::Request;

    string line;
    if (!getline(cin, line)) {
        // End of input: reported through the NotValid type rather than
        // through err, because it is not an error.
        request.type = RRType::NotValid;
        return request;
    }

    Json doc = convertRequestJson(line, err);
    if (!err.empty()) {
        return {};
    }

    request.type = VampJson::getRequestResponseType(doc, err);
    if (!err.empty()) {
        return {};
    }

    request.id = readJsonId(doc);

    // Kept as a named local because it is handed to the Process
    // converter as an lvalue.
    VampJson::BufferSerialisation serialisation =
        VampJson::BufferSerialisation::Array;

    switch (request.type) {

    case RRType::List:
        // Nothing to store for a list request; the call only validates it.
        VampJson::toRpcRequest_List(doc, err);
        break;

    case RRType::Load:
        request.loadRequest = VampJson::toRpcRequest_Load(doc, err);
        break;

    case RRType::Configure:
        request.configurationRequest =
            VampJson::toRpcRequest_Configure(doc, mapper, err);
        break;

    case RRType::Process:
        request.processRequest =
            VampJson::toRpcRequest_Process(doc, mapper, serialisation, err);
        break;

    case RRType::Finish:
        request.finishRequest =
            VampJson::toRpcRequest_Finish(doc, mapper, err);
        break;

    case RRType::NotValid:
        break;
    }

    return request;
}

// Serialise a response as a single line of JSON on stdout.
//
// rr:        the response to write. If rr.success is false an error
//            object built from rr.errorText is written instead of the
//            type-specific payload.
// useBase64: selects Base64 rather than plain-array encoding for any
//            sample buffers in process/finish responses.
void
writeResponseJson(RequestOrResponse &rr, bool useBase64)
{
    VampJson::BufferSerialisation serialisation =
        VampJson::BufferSerialisation::Array;
    if (useBase64) {
        serialisation = VampJson::BufferSerialisation::Base64;
    }

    Json id = writeJsonId(rr.id);
    Json payload;

    if (rr.success) {

        switch (rr.type) {

        case RRType::List:
            payload = VampJson::fromRpcResponse_List(rr.listResponse, id);
            break;

        case RRType::Load:
            payload = VampJson::fromRpcResponse_Load
                (rr.loadResponse, mapper, id);
            break;

        case RRType::Configure:
            payload = VampJson::fromRpcResponse_Configure
                (rr.configurationResponse, mapper, id);
            break;

        case RRType::Process:
            payload = VampJson::fromRpcResponse_Process
                (rr.processResponse, mapper, serialisation, id);
            break;

        case RRType::Finish:
            payload = VampJson::fromRpcResponse_Finish
                (rr.finishResponse, mapper, serialisation, id);
            break;

        case RRType::NotValid:
            // Nothing to convert: the default (null) JSON value is written.
            break;
        }

    } else {
        payload = VampJson::fromError(rr.errorText, rr.type, id);
    }

    cout << payload.dump() << endl;
}

void
writeExceptionJson(const std::exception &e, RRType type)
{
    Json j = VampJson::fromError(e.what(), type, Json());
    cout << j.dump() << endl;
}

// Read one Cap'n Proto request message from stdin and convert it to a
// RequestOrResponse.
//
// Returns a request of type RRType::NotValid when no further input is
// available.
RequestOrResponse
readRequestCapnp()
{
    // The stream objects are function-local statics so that the same
    // buffered wrapper around stdin (fd 0) is reused by every call.
    static kj::FdInputStream stdinStream(0);
    static kj::BufferedInputStreamWrapper bufferedStdin(stdinStream);

    RequestOrResponse request;
    request.direction = RequestOrResponse::Request;

    if (bufferedStdin.tryGetReadBuffer() == nullptr) {
        // Nothing left to read.
        request.type = RRType::NotValid;
        return request;
    }

    // The message object must stay alive for as long as the reader
    // obtained from it is in use.
    capnp::InputStreamMessageReader incoming(bufferedStdin);
    piper::RpcRequest::Reader root = incoming.getRoot<piper::RpcRequest>();

    request.type = VampnProto::getRequestResponseType(root);
    request.id = readId(root);

    switch (request.type) {

    case RRType::List:
        // Nothing to store for a list request; the call only validates it.
        VampnProto::readRpcRequest_List(root);
        break;

    case RRType::Load:
        VampnProto::readRpcRequest_Load(request.loadRequest, root);
        break;

    case RRType::Configure:
        VampnProto::readRpcRequest_Configure
            (request.configurationRequest, root, mapper);
        break;

    case RRType::Process:
        VampnProto::readRpcRequest_Process
            (request.processRequest, root, mapper);
        break;

    case RRType::Finish:
        VampnProto::readRpcRequest_Finish
            (request.finishRequest, root, mapper);
        break;

    case RRType::NotValid:
        break;
    }

    return request;
}

// Serialise a response as a Cap'n Proto message on stdout (fd 1).
//
// If rr.success is false an error response built from rr.errorText is
// written instead of the type-specific payload. The request id is
// copied into the response in either case.
void
writeResponseCapnp(RequestOrResponse &rr)
{
    capnp::MallocMessageBuilder outgoing;
    piper::RpcResponse::Builder root =
        outgoing.initRoot<piper::RpcResponse>();

    buildId(root, rr.id);

    if (rr.success) {

        switch (rr.type) {

        case RRType::List:
            VampnProto::buildRpcResponse_List(root, rr.listResponse);
            break;

        case RRType::Load:
            VampnProto::buildRpcResponse_Load
                (root, rr.loadResponse, mapper);
            break;

        case RRType::Configure:
            VampnProto::buildRpcResponse_Configure
                (root, rr.configurationResponse, mapper);
            break;

        case RRType::Process:
            VampnProto::buildRpcResponse_Process
                (root, rr.processResponse, mapper);
            break;

        case RRType::Finish:
            VampnProto::buildRpcResponse_Finish
                (root, rr.finishResponse, mapper);
            break;

        case RRType::NotValid:
            break;
        }

    } else {
        VampnProto::buildRpcResponse_Error(root, rr.errorText, rr.type);
    }

    writeMessageToFd(1, outgoing);
}

void
writeExceptionCapnp(const std::exception &e, RRType type)
{
    capnp::MallocMessageBuilder message;
    piper::RpcResponse::Builder builder = message.initRoot<piper::RpcResponse>();
    VampnProto::buildRpcResponse_Exception(builder, e, type);
    
    writeMessageToFd(1, message);
}

// Carry out a single request and build the corresponding response.
//
// request: the decoded request (list, load, configure, process or
//          finish).
// debug:   when true, a line is logged to stderr after a plugin has
//          been loaded.
//
// Returns a response of the same type as the request. response.success
// is set to true only when the operation succeeded; otherwise it is
// left at whatever a default-constructed RequestOrResponse holds
// (presumably false -- confirm against RequestOrResponse).
//
// Throws std::runtime_error when a plugin is configured twice, when
// process is requested for an unconfigured plugin, or when the supplied
// channel count or block size does not match the configuration.
// Exceptions from the mapper or the plugin itself also propagate.
//
// Note that a Finish request does not delete the plugin: the caller
// does that after the response has been written.
RequestOrResponse
handleRequest(const RequestOrResponse &request, bool debug)
{
    RequestOrResponse response;
    response.direction = RequestOrResponse::Response;
    response.type = request.type;

    switch (request.type) {

    case RRType::List:
        response.listResponse = LoaderRequests().listPluginData();
        response.success = true;
        break;

    case RRType::Load:
        response.loadResponse = LoaderRequests().loadPlugin(request.loadRequest);
        // A null plugin means the load failed; success is then not set.
        if (response.loadResponse.plugin != nullptr) {
            mapper.addPlugin(response.loadResponse.plugin);
            if (debug) {
                cerr << "piper-vamp-server " << pid << ": loaded plugin, handle = " << mapper.pluginToHandle(response.loadResponse.plugin) << endl;
            }
            response.success = true;
        }
        break;
        
    case RRType::Configure:
    {
        auto &creq = request.configurationRequest;
        auto h = mapper.pluginToHandle(creq.plugin);
        if (mapper.isConfigured(h)) {
            throw runtime_error("plugin has already been configured");
        }

        response.configurationResponse = LoaderRequests().configurePlugin(creq);
        
        // An empty output list is treated as a failed configuration.
        if (!response.configurationResponse.outputs.empty()) {
            mapper.markConfigured
                (h, creq.configuration.channelCount, creq.configuration.blockSize);
            response.success = true;
        }
        break;
    }

    case RRType::Process:
    {
        auto &preq = request.processRequest;
        auto h = mapper.pluginToHandle(preq.plugin);
        if (!mapper.isConfigured(h)) {
            throw runtime_error("plugin has not been configured");
        }

        int channels = int(preq.inputBuffers.size());
        if (channels != mapper.getChannelCount(h)) {
            throw runtime_error("wrong number of channels supplied to process");
        }

        // The block size does not change while we validate the buffers.
        const auto blockSize = mapper.getBlockSize(h);

        // Table of per-channel sample pointers for the plugin. It is
        // held in a vector so that it is released on every exit path,
        // including when the plugin's process() call throws.
        vector<const float *> fbuffers(channels);
        for (int i = 0; i < channels; ++i) {
            if (int(preq.inputBuffers[i].size()) != blockSize) {
                throw runtime_error("wrong block size supplied to process");
            }
            fbuffers[i] = preq.inputBuffers[i].data();
        }

        response.processResponse.plugin = preq.plugin;
        response.processResponse.features =
            preq.plugin->process(fbuffers.data(), preq.timestamp);
        response.success = true;
        break;
    }

    case RRType::Finish:
    {
        auto &freq = request.finishRequest;
        response.finishResponse.plugin = freq.plugin;

        auto h = mapper.pluginToHandle(freq.plugin);
        // Finish can be called (to unload the plugin) even if the
        // plugin has never been configured or used. But we want to
        // make sure we call getRemainingFeatures only if we have
        // actually configured the plugin.
        if (mapper.isConfigured(h)) {
            response.finishResponse.features = freq.plugin->getRemainingFeatures();
        }

        // We do not delete the plugin here -- we need it in the
        // mapper when converting the features. It gets deleted in the
        // calling function.
        response.success = true;
        break;
    }

    case RRType::NotValid:
        break;
    }
    
    return response;
}

// Read the next request from stdin in the given wire format.
//
// format: "capnp" or "json".
//
// Throws std::runtime_error if the format is unknown, or if a JSON
// request could not be parsed or converted. A returned request of type
// RRType::NotValid signals end of input.
RequestOrResponse
readRequest(string format)
{
    if (format == "json") {
        string problem;
        auto parsed = readRequestJson(problem);
        if (!problem.empty()) {
            throw runtime_error(problem);
        }
        return parsed;
    }

    if (format == "capnp") {
        return readRequestCapnp();
    }

    throw runtime_error("unknown input format \"" + format + "\"");
}

// Write a response to stdout in the given wire format.
//
// format: "capnp" or "json". JSON responses are written with sample
//         buffers encoded as plain arrays (not Base64).
//
// Throws std::runtime_error if the format is unknown.
void
writeResponse(string format, RequestOrResponse &rr)
{
    if (format == "json") {
        writeResponseJson(rr, false);
        return;
    }

    if (format == "capnp") {
        writeResponseCapnp(rr);
        return;
    }

    throw runtime_error("unknown output format \"" + format + "\"");
}

// Report an exception to the client as an error response on stdout,
// in the given wire format.
//
// format: "capnp" or "json".
// e:      the exception whose message is reported.
// type:   the type of the request being handled when it was thrown.
//
// Throws std::runtime_error if the format is unknown.
void
writeException(string format, const std::exception &e, RRType type)
{
    if (format == "json") {
        writeExceptionJson(e, type);
        return;
    }

    if (format == "capnp") {
        writeExceptionCapnp(e, type);
        return;
    }

    throw runtime_error("unknown output format \"" + format + "\"");
}

int main(int argc, char **argv)
{
    if (argc != 2 && argc != 3) {
        usage();
    }

    bool debug = false;
    
    string arg = argv[1];
    if (arg == "-h") {
        if (argc == 2) {
            usage(true);
        } else {
            usage();
        }
    } else if (arg == "-v") {
        if (argc == 2) {
            version();
        } else {
            usage();
        }
    } else if (arg == "-d") {
        if (argc == 2) {
            usage();
        } else {
            debug = true;
            arg = argv[2];
        }
    }
    
    string format = arg;

    if (format != "capnp" && format != "json") {
        usage();
    }

    if (debug) {
        cerr << myname << " " << pid << ": waiting for format: " << format << endl;
    }

    while (true) {

        RequestOrResponse request;
        
        try {

            request = readRequest(format);
            
            // NotValid without an exception indicates EOF:
            if (request.type == RRType::NotValid) {
                if (debug) {
                    cerr << myname << " " << pid << ": eof reached, exiting" << endl;
                }
                break;
            }

            if (debug) {
                cerr << myname << " " << pid << ": request received, of type "
                     << int(request.type)
                     << endl;
            }

            RequestOrResponse response = handleRequest(request, debug);
            response.id = request.id;

            if (debug) {
                cerr << myname << " " << pid << ": request handled, writing response"
                     << endl;
            }
            
            writeResponse(format, response);

            if (debug) {
                cerr << myname << " " << pid << ": response written" << endl;
            }

            if (request.type == RRType::Finish) {
                auto h = mapper.pluginToHandle(request.finishRequest.plugin);
                if (debug) {
                    cerr << myname << " " << pid << ": deleting the plugin with handle " << h << endl;
                }
                mapper.removePlugin(h);
                delete request.finishRequest.plugin;
            }
            
        } catch (std::exception &e) {

            if (debug) {
                cerr << myname << " " << pid << ": error: " << e.what() << endl;
            }

            writeException(format, e, request.type);
            
            exit(1);
        }
    }

    exit(0);
}