annotate vamp-server/server.cpp @ 117:5dffc5147176

Small simplification
author Chris Cannam <c.cannam@qmul.ac.uk>
date Thu, 27 Oct 2016 11:40:57 +0100
parents d15cb1151d76
children ff3fd8d1b2dc
rev   line source
c@116 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
c@75 2
#include "vamp-json/VampJson.h"
#include "vamp-capnp/VampnProto.h"
#include "vamp-support/RequestOrResponse.h"
#include "vamp-support/CountingPluginHandleMapper.h"
#include "vamp-support/LoaderRequests.h"

#include <iostream>
#include <sstream>
#include <stdexcept>

#include <capnp/serialize.h>

#include <map>
#include <set>
#include <vector>
c@75 17
// Process id of this server, captured once at static-initialisation time
// and included in debug log lines so that output from several concurrent
// server processes can be told apart.
#ifdef _WIN32
#include <process.h>
static int currentProcessId() { return _getpid(); }
#else
#include <unistd.h>
static int currentProcessId() { return getpid(); }
#endif
static int pid = currentProcessId();
c@103 26
c@75 27 using namespace std;
c@116 28 using namespace json11;
c@97 29 using namespace piper_vamp;
c@75 30 using namespace Vamp;
c@75 31
c@102 32 //!!! This could be faster and lighter:
c@102 33 // - Use Capnp structures directly rather than converting to vamp-support ones
c@102 34 // - Use Vamp C API (vamp.h) directly rather than converting to C++
c@102 35 //!!! Doing the above for process() and finish() alone would be a good start
c@102 36
// Program name used in the usage text and as the prefix of debug output.
static string myname("piper-vamp-server");
c@116 38
// Print the program version to stdout and terminate with status 0.
static void version()
{
    const char *versionText = "1.0";
    cout << versionText << endl;
    exit(0);
}
c@116 44
c@116 45 static void usage(bool successful = false)
c@116 46 {
c@75 47 cerr << "\n" << myname <<
c@116 48 ": Load and run Vamp plugins in response to Piper messages\n\n"
c@116 49 " Usage: " << myname << " [-d] <format>\n"
c@116 50 " " << myname << " -v\n"
c@116 51 " " << myname << " -h\n\n"
c@116 52 " where\n"
c@116 53 " <format>: the format to read and write messages in (\"json\" or \"capnp\")\n"
c@116 54 " -d: also print debug information to stderr\n"
c@116 55 " -v: print version number to stdout and exit\n"
c@116 56 " -h: print this text to stderr and exit\n\n"
c@116 57 "Expects Piper request messages in either Cap'n Proto or JSON format on stdin,\n"
c@116 58 "and writes response messages in the same format to stdout.\n\n";
c@116 59 if (successful) exit(0);
c@116 60 else exit(2);
c@75 61 }
c@75 62
c@75 63 static CountingPluginHandleMapper mapper;
c@75 64
c@116 65 static RequestOrResponse::RpcId
c@116 66 readId(const piper::RpcRequest::Reader &r)
c@75 67 {
c@75 68 int number;
c@75 69 string tag;
c@75 70 switch (r.getId().which()) {
c@97 71 case piper::RpcRequest::Id::Which::NUMBER:
c@75 72 number = r.getId().getNumber();
c@75 73 return { RequestOrResponse::RpcId::Number, number, "" };
c@97 74 case piper::RpcRequest::Id::Which::TAG:
c@75 75 tag = r.getId().getTag();
c@75 76 return { RequestOrResponse::RpcId::Tag, 0, tag };
c@97 77 case piper::RpcRequest::Id::Which::NONE:
c@75 78 return { RequestOrResponse::RpcId::Absent, 0, "" };
c@75 79 }
c@75 80 return {};
c@75 81 }
c@75 82
c@116 83 static void
c@116 84 buildId(piper::RpcResponse::Builder &b, const RequestOrResponse::RpcId &id)
c@75 85 {
c@75 86 switch (id.type) {
c@75 87 case RequestOrResponse::RpcId::Number:
c@75 88 b.getId().setNumber(id.number);
c@75 89 break;
c@75 90 case RequestOrResponse::RpcId::Tag:
c@75 91 b.getId().setTag(id.tag);
c@75 92 break;
c@75 93 case RequestOrResponse::RpcId::Absent:
c@75 94 b.getId().setNone();
c@75 95 break;
c@75 96 }
c@75 97 }
c@75 98
c@116 99 static RequestOrResponse::RpcId
c@116 100 readJsonId(const Json &j)
c@116 101 {
c@116 102 RequestOrResponse::RpcId id;
c@116 103
c@116 104 if (j["id"].is_number()) {
c@116 105 id.type = RequestOrResponse::RpcId::Number;
c@116 106 id.number = j["id"].number_value();
c@116 107 } else if (j["id"].is_string()) {
c@116 108 id.type = RequestOrResponse::RpcId::Tag;
c@116 109 id.tag = j["id"].string_value();
c@116 110 } else {
c@116 111 id.type = RequestOrResponse::RpcId::Absent;
c@116 112 }
c@116 113
c@116 114 return id;
c@116 115 }
c@116 116
c@116 117 static Json
c@116 118 writeJsonId(const RequestOrResponse::RpcId &id)
c@116 119 {
c@116 120 if (id.type == RequestOrResponse::RpcId::Number) {
c@116 121 return id.number;
c@116 122 } else if (id.type == RequestOrResponse::RpcId::Tag) {
c@116 123 return id.tag;
c@116 124 } else {
c@116 125 return Json();
c@116 126 }
c@116 127 }
c@116 128
c@116 129 static Json
c@116 130 convertRequestJson(string input, string &err)
c@116 131 {
c@116 132 Json j = Json::parse(input, err);
c@116 133 if (err != "") {
c@116 134 err = "invalid json: " + err;
c@116 135 return {};
c@116 136 }
c@116 137 if (!j.is_object()) {
c@116 138 err = "object expected at top level";
c@116 139 } else if (!j["method"].is_string()) {
c@116 140 err = "string expected for method field";
c@116 141 } else if (!j["params"].is_null() && !j["params"].is_object()) {
c@116 142 err = "object expected for params field";
c@116 143 }
c@116 144 return j;
c@116 145 }
c@116 146
c@116 147 RequestOrResponse
c@116 148 readRequestJson(string &err)
c@116 149 {
c@116 150 RequestOrResponse rr;
c@116 151 rr.direction = RequestOrResponse::Request;
c@116 152
c@116 153 string input;
c@116 154 if (!getline(cin, input)) {
c@116 155 // the EOF case, not actually an error
c@116 156 rr.type = RRType::NotValid;
c@116 157 return rr;
c@116 158 }
c@116 159
c@116 160 Json j = convertRequestJson(input, err);
c@116 161 if (err != "") return {};
c@116 162
c@116 163 rr.type = VampJson::getRequestResponseType(j, err);
c@116 164 if (err != "") return {};
c@116 165
c@116 166 rr.id = readJsonId(j);
c@116 167
c@116 168 VampJson::BufferSerialisation serialisation =
c@116 169 VampJson::BufferSerialisation::Array;
c@116 170
c@116 171 switch (rr.type) {
c@116 172
c@116 173 case RRType::List:
c@116 174 VampJson::toRpcRequest_List(j, err); // type check only
c@116 175 break;
c@116 176 case RRType::Load:
c@116 177 rr.loadRequest = VampJson::toRpcRequest_Load(j, err);
c@116 178 break;
c@116 179 case RRType::Configure:
c@116 180 rr.configurationRequest = VampJson::toRpcRequest_Configure(j, mapper, err);
c@116 181 break;
c@116 182 case RRType::Process:
c@116 183 rr.processRequest = VampJson::toRpcRequest_Process(j, mapper, serialisation, err);
c@116 184 break;
c@116 185 case RRType::Finish:
c@116 186 rr.finishRequest = VampJson::toRpcRequest_Finish(j, mapper, err);
c@116 187 break;
c@116 188 case RRType::NotValid:
c@116 189 break;
c@116 190 }
c@116 191
c@116 192 return rr;
c@116 193 }
c@116 194
c@116 195 void
c@116 196 writeResponseJson(RequestOrResponse &rr, bool useBase64)
c@116 197 {
c@116 198 Json j;
c@116 199
c@116 200 VampJson::BufferSerialisation serialisation =
c@116 201 (useBase64 ?
c@116 202 VampJson::BufferSerialisation::Base64 :
c@116 203 VampJson::BufferSerialisation::Array);
c@116 204
c@116 205 Json id = writeJsonId(rr.id);
c@116 206
c@116 207 if (!rr.success) {
c@116 208
c@116 209 j = VampJson::fromError(rr.errorText, rr.type, id);
c@116 210
c@116 211 } else {
c@116 212
c@116 213 switch (rr.type) {
c@116 214
c@116 215 case RRType::List:
c@116 216 j = VampJson::fromRpcResponse_List(rr.listResponse, id);
c@116 217 break;
c@116 218 case RRType::Load:
c@116 219 j = VampJson::fromRpcResponse_Load(rr.loadResponse, mapper, id);
c@116 220 break;
c@116 221 case RRType::Configure:
c@116 222 j = VampJson::fromRpcResponse_Configure(rr.configurationResponse,
c@116 223 mapper, id);
c@116 224 break;
c@116 225 case RRType::Process:
c@116 226 j = VampJson::fromRpcResponse_Process
c@116 227 (rr.processResponse, mapper, serialisation, id);
c@116 228 break;
c@116 229 case RRType::Finish:
c@116 230 j = VampJson::fromRpcResponse_Finish
c@116 231 (rr.finishResponse, mapper, serialisation, id);
c@116 232 break;
c@116 233 case RRType::NotValid:
c@116 234 break;
c@116 235 }
c@116 236 }
c@116 237
c@116 238 cout << j.dump() << endl;
c@116 239 }
c@116 240
c@116 241 void
c@116 242 writeExceptionJson(const std::exception &e, RRType type)
c@116 243 {
c@116 244 Json j = VampJson::fromError(e.what(), type, Json());
c@116 245 cout << j.dump() << endl;
c@116 246 }
c@116 247
c@75 248 RequestOrResponse
c@75 249 readRequestCapnp()
c@75 250 {
c@75 251 RequestOrResponse rr;
c@75 252 rr.direction = RequestOrResponse::Request;
c@75 253
c@75 254 static kj::FdInputStream stream(0); // stdin
c@75 255 static kj::BufferedInputStreamWrapper buffered(stream);
c@75 256
c@75 257 if (buffered.tryGetReadBuffer() == nullptr) {
c@116 258 rr.type = RRType::NotValid;
c@116 259 return rr;
c@75 260 }
c@75 261
c@97 262 capnp::InputStreamMessageReader message(buffered);
c@97 263 piper::RpcRequest::Reader reader = message.getRoot<piper::RpcRequest>();
c@75 264
c@75 265 rr.type = VampnProto::getRequestResponseType(reader);
c@75 266 rr.id = readId(reader);
c@75 267
c@75 268 switch (rr.type) {
c@75 269
c@75 270 case RRType::List:
c@116 271 VampnProto::readRpcRequest_List(reader); // type check only
c@116 272 break;
c@75 273 case RRType::Load:
c@116 274 VampnProto::readRpcRequest_Load(rr.loadRequest, reader);
c@116 275 break;
c@75 276 case RRType::Configure:
c@116 277 VampnProto::readRpcRequest_Configure(rr.configurationRequest,
c@116 278 reader, mapper);
c@116 279 break;
c@75 280 case RRType::Process:
c@116 281 VampnProto::readRpcRequest_Process(rr.processRequest, reader, mapper);
c@116 282 break;
c@75 283 case RRType::Finish:
c@116 284 VampnProto::readRpcRequest_Finish(rr.finishRequest, reader, mapper);
c@116 285 break;
c@75 286 case RRType::NotValid:
c@116 287 break;
c@75 288 }
c@75 289
c@75 290 return rr;
c@75 291 }
c@75 292
c@75 293 void
c@75 294 writeResponseCapnp(RequestOrResponse &rr)
c@75 295 {
c@97 296 capnp::MallocMessageBuilder message;
c@97 297 piper::RpcResponse::Builder builder = message.initRoot<piper::RpcResponse>();
c@75 298
c@75 299 buildId(builder, rr.id);
c@75 300
c@75 301 if (!rr.success) {
c@75 302
c@116 303 VampnProto::buildRpcResponse_Error(builder, rr.errorText, rr.type);
c@75 304
c@75 305 } else {
c@116 306
c@116 307 switch (rr.type) {
c@75 308
c@116 309 case RRType::List:
c@116 310 VampnProto::buildRpcResponse_List(builder, rr.listResponse);
c@116 311 break;
c@116 312 case RRType::Load:
c@116 313 VampnProto::buildRpcResponse_Load(builder, rr.loadResponse, mapper);
c@116 314 break;
c@116 315 case RRType::Configure:
c@116 316 VampnProto::buildRpcResponse_Configure(builder, rr.configurationResponse, mapper);
c@116 317 break;
c@116 318 case RRType::Process:
c@116 319 VampnProto::buildRpcResponse_Process(builder, rr.processResponse, mapper);
c@116 320 break;
c@116 321 case RRType::Finish:
c@116 322 VampnProto::buildRpcResponse_Finish(builder, rr.finishResponse, mapper);
c@116 323 break;
c@116 324 case RRType::NotValid:
c@116 325 break;
c@116 326 }
c@75 327 }
c@75 328
c@75 329 writeMessageToFd(1, message);
c@75 330 }
c@75 331
c@75 332 void
c@75 333 writeExceptionCapnp(const std::exception &e, RRType type)
c@75 334 {
c@97 335 capnp::MallocMessageBuilder message;
c@97 336 piper::RpcResponse::Builder builder = message.initRoot<piper::RpcResponse>();
c@75 337 VampnProto::buildRpcResponse_Exception(builder, e, type);
c@75 338
c@75 339 writeMessageToFd(1, message);
c@75 340 }
c@75 341
c@75 342 RequestOrResponse
c@116 343 handleRequest(const RequestOrResponse &request, bool debug)
c@75 344 {
c@75 345 RequestOrResponse response;
c@75 346 response.direction = RequestOrResponse::Response;
c@75 347 response.type = request.type;
c@75 348
c@75 349 switch (request.type) {
c@75 350
c@75 351 case RRType::List:
c@116 352 response.listResponse = LoaderRequests().listPluginData();
c@116 353 response.success = true;
c@116 354 break;
c@75 355
c@75 356 case RRType::Load:
c@116 357 response.loadResponse = LoaderRequests().loadPlugin(request.loadRequest);
c@116 358 if (response.loadResponse.plugin != nullptr) {
c@116 359 mapper.addPlugin(response.loadResponse.plugin);
c@116 360 if (debug) {
c@116 361 cerr << "piper-vamp-server " << pid << ": loaded plugin, handle = " << mapper.pluginToHandle(response.loadResponse.plugin) << endl;
c@116 362 }
c@116 363 response.success = true;
c@116 364 }
c@116 365 break;
c@116 366
c@75 367 case RRType::Configure:
c@75 368 {
c@116 369 auto &creq = request.configurationRequest;
c@116 370 auto h = mapper.pluginToHandle(creq.plugin);
c@116 371 if (mapper.isConfigured(h)) {
c@116 372 throw runtime_error("plugin has already been configured");
c@116 373 }
c@75 374
c@116 375 response.configurationResponse = LoaderRequests().configurePlugin(creq);
c@116 376
c@116 377 if (!response.configurationResponse.outputs.empty()) {
c@116 378 mapper.markConfigured
c@116 379 (h, creq.configuration.channelCount, creq.configuration.blockSize);
c@116 380 response.success = true;
c@116 381 }
c@116 382 break;
c@75 383 }
c@75 384
c@75 385 case RRType::Process:
c@75 386 {
c@116 387 auto &preq = request.processRequest;
c@116 388 auto h = mapper.pluginToHandle(preq.plugin);
c@116 389 if (!mapper.isConfigured(h)) {
c@116 390 throw runtime_error("plugin has not been configured");
c@116 391 }
c@75 392
c@116 393 int channels = int(preq.inputBuffers.size());
c@116 394 if (channels != mapper.getChannelCount(h)) {
c@116 395 throw runtime_error("wrong number of channels supplied to process");
c@116 396 }
c@116 397
c@116 398 const float **fbuffers = new const float *[channels];
c@116 399 for (int i = 0; i < channels; ++i) {
c@116 400 if (int(preq.inputBuffers[i].size()) != mapper.getBlockSize(h)) {
c@116 401 delete[] fbuffers;
c@116 402 throw runtime_error("wrong block size supplied to process");
c@116 403 }
c@116 404 fbuffers[i] = preq.inputBuffers[i].data();
c@116 405 }
c@75 406
c@116 407 response.processResponse.plugin = preq.plugin;
c@116 408 response.processResponse.features =
c@116 409 preq.plugin->process(fbuffers, preq.timestamp);
c@116 410 response.success = true;
c@75 411
c@116 412 delete[] fbuffers;
c@116 413 break;
c@75 414 }
c@75 415
c@75 416 case RRType::Finish:
c@75 417 {
c@116 418 auto &freq = request.finishRequest;
c@116 419 response.finishResponse.plugin = freq.plugin;
c@77 420
c@116 421 auto h = mapper.pluginToHandle(freq.plugin);
c@77 422 // Finish can be called (to unload the plugin) even if the
c@77 423 // plugin has never been configured or used. But we want to
c@77 424 // make sure we call getRemainingFeatures only if we have
c@77 425 // actually configured the plugin.
c@116 426 if (mapper.isConfigured(h)) {
c@77 427 response.finishResponse.features = freq.plugin->getRemainingFeatures();
c@116 428 }
c@75 429
c@116 430 // We do not delete the plugin here -- we need it in the
c@116 431 // mapper when converting the features. It gets deleted in the
c@116 432 // calling function.
c@116 433 response.success = true;
c@116 434 break;
c@75 435 }
c@75 436
c@75 437 case RRType::NotValid:
c@116 438 break;
c@75 439 }
c@75 440
c@75 441 return response;
c@75 442 }
c@75 443
c@116 444 RequestOrResponse
c@116 445 readRequest(string format)
c@75 446 {
c@116 447 if (format == "capnp") {
c@116 448 return readRequestCapnp();
c@116 449 } else if (format == "json") {
c@116 450 string err;
c@116 451 auto result = readRequestJson(err);
c@116 452 if (err != "") throw runtime_error(err);
c@116 453 else return result;
c@116 454 } else {
c@116 455 throw runtime_error("unknown input format \"" + format + "\"");
c@116 456 }
c@116 457 }
c@116 458
c@116 459 void
c@116 460 writeResponse(string format, RequestOrResponse &rr)
c@116 461 {
c@116 462 if (format == "capnp") {
c@116 463 writeResponseCapnp(rr);
c@116 464 } else if (format == "json") {
c@116 465 writeResponseJson(rr, false);
c@116 466 } else {
c@116 467 throw runtime_error("unknown output format \"" + format + "\"");
c@116 468 }
c@116 469 }
c@116 470
c@116 471 void
c@116 472 writeException(string format, const std::exception &e, RRType type)
c@116 473 {
c@116 474 if (format == "capnp") {
c@116 475 writeExceptionCapnp(e, type);
c@116 476 } else if (format == "json") {
c@116 477 writeExceptionJson(e, type);
c@116 478 } else {
c@116 479 throw runtime_error("unknown output format \"" + format + "\"");
c@116 480 }
c@116 481 }
c@116 482
c@116 483 int main(int argc, char **argv)
c@116 484 {
c@116 485 if (argc != 2 && argc != 3) {
c@116 486 usage();
c@75 487 }
c@75 488
c@116 489 bool debug = false;
c@112 490
c@116 491 string arg = argv[1];
c@116 492 if (arg == "-h") {
c@116 493 if (argc == 2) {
c@116 494 usage(true);
c@116 495 } else {
c@116 496 usage();
c@116 497 }
c@116 498 } else if (arg == "-v") {
c@116 499 if (argc == 2) {
c@116 500 version();
c@116 501 } else {
c@116 502 usage();
c@116 503 }
c@116 504 } else if (arg == "-d") {
c@116 505 if (argc == 2) {
c@116 506 usage();
c@116 507 } else {
c@116 508 debug = true;
c@116 509 arg = argv[2];
c@116 510 }
c@116 511 }
c@116 512
c@116 513 string format = arg;
c@116 514
c@116 515 if (format != "capnp" && format != "json") {
c@116 516 usage();
c@116 517 }
c@116 518
c@116 519 if (debug) {
c@116 520 cerr << myname << " " << pid << ": waiting for format: " << format << endl;
c@116 521 }
c@116 522
c@75 523 while (true) {
c@75 524
c@116 525 RequestOrResponse request;
c@116 526
c@116 527 try {
c@75 528
c@116 529 request = readRequest(format);
c@116 530
c@116 531 // NotValid without an exception indicates EOF:
c@116 532 if (request.type == RRType::NotValid) {
c@116 533 if (debug) {
c@116 534 cerr << myname << " " << pid << ": eof reached, exiting" << endl;
c@116 535 }
c@116 536 break;
c@116 537 }
c@75 538
c@116 539 if (debug) {
c@116 540 cerr << myname << " " << pid << ": request received, of type "
c@116 541 << int(request.type)
c@116 542 << endl;
c@116 543 }
c@75 544
c@116 545 RequestOrResponse response = handleRequest(request, debug);
c@75 546 response.id = request.id;
c@75 547
c@116 548 if (debug) {
c@116 549 cerr << myname << " " << pid << ": request handled, writing response"
c@116 550 << endl;
c@116 551 }
c@116 552
c@116 553 writeResponse(format, response);
c@75 554
c@116 555 if (debug) {
c@116 556 cerr << myname << " " << pid << ": response written" << endl;
c@116 557 }
c@75 558
c@116 559 if (request.type == RRType::Finish) {
c@116 560 auto h = mapper.pluginToHandle(request.finishRequest.plugin);
c@116 561 if (debug) {
c@116 562 cerr << myname << " " << pid << ": deleting the plugin with handle " << h << endl;
c@116 563 }
c@116 564 mapper.removePlugin(h);
c@116 565 delete request.finishRequest.plugin;
c@116 566 }
c@116 567
c@116 568 } catch (std::exception &e) {
c@75 569
c@116 570 if (debug) {
c@116 571 cerr << myname << " " << pid << ": error: " << e.what() << endl;
c@116 572 }
c@75 573
c@116 574 writeException(format, e, request.type);
c@116 575
c@116 576 exit(1);
c@116 577 }
c@75 578 }
c@75 579
c@75 580 exit(0);
c@75 581 }