Mercurial > hg > piper-cpp
comparison vamp-client/PiperCapnpClient.h @ 90:6429a99abcad
Split out classes
author | Chris Cannam <c.cannam@qmul.ac.uk> |
---|---|
date | Thu, 13 Oct 2016 10:17:59 +0100 |
parents | |
children | c897c9a8daf1 |
comparison
equal
deleted
inserted
replaced
89:03ed2e0a6c8f | 90:6429a99abcad |
---|---|
1 | |
2 #ifndef PIPER_CAPNP_CLIENT_H | |
3 #define PIPER_CAPNP_CLIENT_H | |
4 | |
#include "PiperClient.h"
#include "SynchronousTransport.h"

#include "vamp-support/AssignedPluginHandleMapper.h"
#include "vamp-capnp/VampnProto.h"

#include <atomic>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
10 | |
11 namespace piper { //!!! change | |
12 | |
13 class PiperCapnpClient : public PiperStubPluginClientInterface | |
14 { | |
15 // unsigned to avoid undefined behaviour on possible wrap | |
16 typedef uint32_t ReqId; | |
17 | |
18 public: | |
19 PiperCapnpClient(SynchronousTransport *transport) : //!!! ownership? shared ptr? | |
20 m_transport(transport) { | |
21 } | |
22 | |
23 ~PiperCapnpClient() { | |
24 } | |
25 | |
26 //!!! obviously, factor out all repetitive guff | |
27 | |
28 //!!! list and load are supposed to be called by application code, | |
29 //!!! but the rest are only supposed to be called by the plugin -- | |
30 //!!! sort out the api here | |
31 | |
32 Vamp::Plugin * | |
33 load(std::string key, float inputSampleRate, int adapterFlags) { | |
34 | |
35 if (!m_transport->isOK()) { | |
36 throw std::runtime_error("Piper server failed to start"); | |
37 } | |
38 | |
39 Vamp::HostExt::LoadRequest request; | |
40 request.pluginKey = key; | |
41 request.inputSampleRate = inputSampleRate; | |
42 request.adapterFlags = adapterFlags; | |
43 | |
44 capnp::MallocMessageBuilder message; | |
45 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
46 | |
47 VampnProto::buildRpcRequest_Load(builder, request); | |
48 ReqId id = getId(); | |
49 builder.getId().setNumber(id); | |
50 | |
51 auto arr = messageToFlatArray(message); | |
52 | |
53 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
54 arr.asChars().size()); | |
55 | |
56 //!!! ... --> will also need some way to kill this process | |
57 //!!! (from another thread) | |
58 | |
59 auto karr = toKJArray(responseBuffer); | |
60 capnp::FlatArrayMessageReader responseMessage(karr); | |
61 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
62 | |
63 //!!! handle (explicit) error case | |
64 | |
65 checkResponseType(reader, RpcResponse::Response::Which::LOAD, id); | |
66 | |
67 const LoadResponse::Reader &lr = reader.getResponse().getLoad(); | |
68 | |
69 Vamp::HostExt::PluginStaticData psd; | |
70 Vamp::HostExt::PluginConfiguration defaultConfig; | |
71 VampnProto::readExtractorStaticData(psd, lr.getStaticData()); | |
72 VampnProto::readConfiguration(defaultConfig, lr.getDefaultConfiguration()); | |
73 | |
74 Vamp::Plugin *plugin = new PiperStubPlugin(this, | |
75 inputSampleRate, | |
76 psd, | |
77 defaultConfig); | |
78 | |
79 m_mapper.addPlugin(lr.getHandle(), plugin); | |
80 | |
81 return plugin; | |
82 }; | |
83 | |
84 protected: | |
85 virtual | |
86 Vamp::Plugin::OutputList | |
87 configure(PiperStubPlugin *plugin, | |
88 Vamp::HostExt::PluginConfiguration config) override { | |
89 | |
90 if (!m_transport->isOK()) { | |
91 throw std::runtime_error("Piper server failed to start"); | |
92 } | |
93 | |
94 Vamp::HostExt::ConfigurationRequest request; | |
95 request.plugin = plugin; | |
96 request.configuration = config; | |
97 | |
98 capnp::MallocMessageBuilder message; | |
99 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
100 | |
101 VampnProto::buildRpcRequest_Configure(builder, request, m_mapper); | |
102 ReqId id = getId(); | |
103 builder.getId().setNumber(id); | |
104 | |
105 auto arr = messageToFlatArray(message); | |
106 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
107 arr.asChars().size()); | |
108 auto karr = toKJArray(responseBuffer); | |
109 capnp::FlatArrayMessageReader responseMessage(karr); | |
110 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
111 | |
112 //!!! handle (explicit) error case | |
113 | |
114 checkResponseType(reader, RpcResponse::Response::Which::CONFIGURE, id); | |
115 | |
116 Vamp::HostExt::ConfigurationResponse cr; | |
117 VampnProto::readConfigurationResponse(cr, | |
118 reader.getResponse().getConfigure(), | |
119 m_mapper); | |
120 | |
121 return cr.outputs; | |
122 }; | |
123 | |
124 virtual | |
125 Vamp::Plugin::FeatureSet | |
126 process(PiperStubPlugin *plugin, | |
127 std::vector<std::vector<float> > inputBuffers, | |
128 Vamp::RealTime timestamp) override { | |
129 | |
130 if (!m_transport->isOK()) { | |
131 throw std::runtime_error("Piper server failed to start"); | |
132 } | |
133 | |
134 Vamp::HostExt::ProcessRequest request; | |
135 request.plugin = plugin; | |
136 request.inputBuffers = inputBuffers; | |
137 request.timestamp = timestamp; | |
138 | |
139 capnp::MallocMessageBuilder message; | |
140 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
141 | |
142 VampnProto::buildRpcRequest_Process(builder, request, m_mapper); | |
143 ReqId id = getId(); | |
144 builder.getId().setNumber(id); | |
145 | |
146 auto arr = messageToFlatArray(message); | |
147 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
148 arr.asChars().size()); | |
149 auto karr = toKJArray(responseBuffer); | |
150 capnp::FlatArrayMessageReader responseMessage(karr); | |
151 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
152 | |
153 //!!! handle (explicit) error case | |
154 | |
155 checkResponseType(reader, RpcResponse::Response::Which::PROCESS, id); | |
156 | |
157 Vamp::HostExt::ProcessResponse pr; | |
158 VampnProto::readProcessResponse(pr, | |
159 reader.getResponse().getProcess(), | |
160 m_mapper); | |
161 | |
162 return pr.features; | |
163 } | |
164 | |
165 virtual Vamp::Plugin::FeatureSet | |
166 finish(PiperStubPlugin *plugin) override { | |
167 | |
168 if (!m_transport->isOK()) { | |
169 throw std::runtime_error("Piper server failed to start"); | |
170 } | |
171 | |
172 Vamp::HostExt::FinishRequest request; | |
173 request.plugin = plugin; | |
174 | |
175 capnp::MallocMessageBuilder message; | |
176 RpcRequest::Builder builder = message.initRoot<RpcRequest>(); | |
177 | |
178 VampnProto::buildRpcRequest_Finish(builder, request, m_mapper); | |
179 ReqId id = getId(); | |
180 builder.getId().setNumber(id); | |
181 | |
182 auto arr = messageToFlatArray(message); | |
183 auto responseBuffer = m_transport->call(arr.asChars().begin(), | |
184 arr.asChars().size()); | |
185 auto karr = toKJArray(responseBuffer); | |
186 capnp::FlatArrayMessageReader responseMessage(karr); | |
187 RpcResponse::Reader reader = responseMessage.getRoot<RpcResponse>(); | |
188 | |
189 //!!! handle (explicit) error case | |
190 | |
191 checkResponseType(reader, RpcResponse::Response::Which::FINISH, id); | |
192 | |
193 Vamp::HostExt::ProcessResponse pr; | |
194 VampnProto::readFinishResponse(pr, | |
195 reader.getResponse().getFinish(), | |
196 m_mapper); | |
197 | |
198 m_mapper.removePlugin(m_mapper.pluginToHandle(plugin)); | |
199 | |
200 // Don't delete the plugin. It's the plugin that is supposed | |
201 // to be calling us here | |
202 | |
203 return pr.features; | |
204 } | |
205 | |
206 private: | |
207 AssignedPluginHandleMapper m_mapper; | |
208 ReqId getId() { | |
209 //!!! todo: mutex | |
210 static ReqId m_nextId = 0; | |
211 return m_nextId++; | |
212 } | |
213 | |
214 kj::Array<capnp::word> | |
215 toKJArray(const std::vector<char> &buffer) { | |
216 // We could do this whole thing with fewer copies, but let's | |
217 // see whether it matters first | |
218 size_t wordSize = sizeof(capnp::word); | |
219 size_t words = buffer.size() / wordSize; | |
220 kj::Array<capnp::word> karr(kj::heapArray<capnp::word>(words)); | |
221 memcpy(karr.begin(), buffer.data(), words * wordSize); | |
222 return karr; | |
223 } | |
224 | |
225 void | |
226 checkResponseType(const RpcResponse::Reader &r, | |
227 RpcResponse::Response::Which type, | |
228 ReqId id) { | |
229 | |
230 if (r.getResponse().which() != type) { | |
231 throw std::runtime_error("Wrong response type"); | |
232 } | |
233 if (ReqId(r.getId().getNumber()) != id) { | |
234 throw std::runtime_error("Wrong response id"); | |
235 } | |
236 } | |
237 | |
238 private: | |
239 SynchronousTransport *m_transport; //!!! I don't own this, but should I? | |
240 }; | |
241 | |
242 } | |
243 | |
244 #endif |