CCF
Loading...
Searching...
No Matches
forwarder.h
Go to the documentation of this file.
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the Apache 2.0 License.
3#pragma once
4
7#include "enclave/rpc_map.h"
9#include "kv/kv_types.h"
10#include "node/node_to_node.h"
11
12namespace ccf
13{
14 class RpcContextImpl;
15
17 {
18 public:
// Interface for RPC frontends that can execute a request forwarded from
// another node. Forwarder::get_forwarder_handler dynamic_pointer_casts the
// handler found for a forwarded request to this type and rejects the
// request if the cast fails.
// NOTE(review): the class-name line (original line 16) and the virtual
// destructor (original line 19, listed in the cross-reference index as
// ~ForwardedRpcHandler) are missing from this listing.
20
// Executes the forwarded request held in fwd_ctx. Forwarder::recv_message
// calls this and then serialises the response from the same context
// (ctx->serialise_response()) to send back to the forwarding node.
21 virtual void process_forwarded(
22 std::shared_ptr<ccf::RpcContextImpl> fwd_ctx) = 0;
23 };
24
25 template <typename ChannelProxy>
27 {
28 private:
// Weak reference to the component that delivers replies to client
// sessions (reply_async); locked before each use.
29 std::weak_ptr<ccf::AbstractRPCResponder> rpcresponder;
// Node-to-node channels used to send and receive encrypted forwarded
// messages (send_encrypted / recv_encrypted).
30 std::shared_ptr<ChannelProxy> n2n_channels;
// Weak reference to the RPC handler map; locked in get_forwarder_handler
// to look up the handler for a forwarded request.
31 std::weak_ptr<ccf::RPCMap> rpc_map;
// This node's ID, set by initialize(). Not read anywhere else in the code
// visible in this listing.
32 NodeId self;
33
// Identifier correlating a forwarded command with its response. Each new
// command takes next_command_id, which forward_command post-increments
// while holding timeout_tasks_lock.
34 using ForwardedCommandId = ForwardedHeader_v2::ForwardedCommandId;
35 ForwardedCommandId next_command_id = 0;
36
// Bookkeeping for the timeout timer of one in-flight forwarded command.
// NOTE(review): the timer field (original line 39) is missing from this
// listing. recv_message reads it as `timer_entry` and passes it to
// cancel_forwarding_task.
37 struct TimeoutTask
38 {
// Thread the timer must be cancelled on. recv_message posts a
// CancelTimerMsg to this thread when it is not the current thread.
40 uint16_t thread_id;
41 };
42
// Pending timeout timers keyed by command id. forward_command inserts an
// entry for each forwarded command; recv_message erases it when the
// matching response arrives. In the code visible here it is only accessed
// while holding timeout_tasks_lock.
43 std::unordered_map<ForwardedCommandId, TimeoutTask> timeout_tasks;
44 ccf::pal::Mutex timeout_tasks_lock;
45
// Flag serialised into each forwarded command. It says whether the caller
// certificate (preceded by its size) follows in the payload.
46 using IsCallerCertForwarded = bool;
47
48 struct SendTimeoutErrorMsg
49 {
50 SendTimeoutErrorMsg(
51 Forwarder<ChannelProxy>* forwarder_,
52 const ccf::NodeId& to_,
53 size_t client_session_id_,
54 const std::chrono::milliseconds& timeout_) :
55 forwarder(forwarder_),
56 to(to_),
57 client_session_id(client_session_id_),
58 timeout(timeout_)
59 {}
60
61 Forwarder<ChannelProxy>* forwarder;
62 ccf::NodeId to;
63 size_t client_session_id;
64 std::chrono::milliseconds timeout;
65 };
66
// Payload of the thread message that asks another thread to cancel a
// forwarding timeout (handled by cancel_forwarding_task_cb).
// NOTE(review): the field on original line 69 is missing from this
// listing. recv_message sets it as `timer_entry`.
67 struct CancelTimerMsg
68 {
70 };
71
72 std::unique_ptr<::threading::Tmsg<SendTimeoutErrorMsg>>
73 create_timeout_error_task(
74 const ccf::NodeId& to,
75 size_t client_session_id,
76 const std::chrono::milliseconds& timeout)
77 {
78 return std::make_unique<::threading::Tmsg<SendTimeoutErrorMsg>>(
79 [](std::unique_ptr<::threading::Tmsg<SendTimeoutErrorMsg>> msg) {
80 msg->data.forwarder->send_timeout_error_response(
81 msg->data.to, msg->data.client_session_id, msg->data.timeout);
82 },
83 this,
84 to,
85 client_session_id,
86 timeout);
87 }
88
89 void send_timeout_error_response(
90 NodeId to,
91 size_t client_session_id,
92 const std::chrono::milliseconds& timeout)
93 {
94 auto rpc_responder_shared = rpcresponder.lock();
95 if (rpc_responder_shared)
96 {
97 auto response = ::http::Response(HTTP_STATUS_GATEWAY_TIMEOUT);
98 auto body = fmt::format(
99 "Request was forwarded to node {}, but no response was received "
100 "after {}ms",
101 to,
102 timeout.count());
103 response.set_body(body);
104 response.set_header(
105 http::headers::CONTENT_TYPE, http::headervalues::contenttype::TEXT);
106 rpc_responder_shared->reply_async(
107 client_session_id, false, response.build_response());
108 }
109 }
110
111 static void cancel_forwarding_task_cb(
112 std::unique_ptr<::threading::Tmsg<CancelTimerMsg>> msg)
113 {
114 cancel_forwarding_task(msg->data.timer_entry);
115 }
116
// Cancels a forwarding timeout timer. recv_message calls it directly when
// already on the timer's thread; otherwise it goes through
// cancel_forwarding_task_cb.
// NOTE(review): the parameter (original line 118) and body (original line
// 120) are missing from this listing. The cross-reference index suggests
// the body calls ThreadMessaging::cancel_timer_task; confirm.
117 static void cancel_forwarding_task(
119 {
121 }
122
123 public:
// Constructs a forwarder. The `Forwarder(` line (original line 124) is
// missing from this listing. rpcresponder and rpc_map are held weakly and
// locked on use; n2n_channels is shared. The local node ID is not set here;
// initialize() sets it.
// NOTE(review): the parameters rpcresponder and n2n_channels share their
// members' names. In `rpcresponder(rpcresponder)` the outer name is the
// member and the inner one is the parameter, so this is correct, but a
// trailing underscore (as on rpc_map_) would avoid confusion.
125 std::weak_ptr<ccf::AbstractRPCResponder> rpcresponder,
126 std::shared_ptr<ChannelProxy> n2n_channels,
127 std::weak_ptr<ccf::RPCMap> rpc_map_) :
128 rpcresponder(rpcresponder),
129 n2n_channels(n2n_channels),
130 rpc_map(rpc_map_)
131 {}
132
133 void initialize(const NodeId& self_)
134 {
135 self = self_;
136 }
137
// Forwards the client request held in rpc_ctx to node `to`. It also
// registers a timeout task, built by create_timeout_error_task, for the
// new command id. The `bool forward_command(` line (original line 138) and
// parts of the timeout registration (lines 175, 177) are missing from this
// listing.
//
// Plaintext payload, in write order:
//   client_session_id
//   include_caller (IsCallerCertForwarded)
//   caller_cert.size() (size_t), caller_cert bytes  -- only if non-empty
//   serialised request bytes
// The payload is sent encrypted as NodeMsgType::forwarded_msg with a
// ForwardedCommandHeader_v3 carrying the command id and the active view.
//
// Returns the result of n2n_channels->send_encrypted.
// Throws std::logic_error if the session's active_view is not set.
139 std::shared_ptr<ccf::RpcContextImpl> rpc_ctx,
140 const NodeId& to,
141 const std::vector<uint8_t>& caller_cert,
142 const std::chrono::milliseconds& timeout) override
143 {
144 auto session_ctx = rpc_ctx->get_session_context();
145
146 IsCallerCertForwarded include_caller = false;
// NOTE(review): `method` is not used anywhere in this function.
147 const auto method = rpc_ctx->get_method();
148 const auto& raw_request = rpc_ctx->get_serialised_request();
149 auto client_session_id = session_ctx->client_session_id;
150 size_t size = sizeof(client_session_id) + sizeof(IsCallerCertForwarded) +
151 raw_request.size();
152 if (!caller_cert.empty())
153 {
154 size += sizeof(size_t) + caller_cert.size();
155 include_caller = true;
156 }
157
158 std::vector<uint8_t> plain(size);
159 auto data_ = plain.data();
160 auto size_ = plain.size();
161 serialized::write(data_, size_, client_session_id);
162 serialized::write(data_, size_, include_caller);
163 if (include_caller)
164 {
165 serialized::write(data_, size_, caller_cert.size());
166 serialized::write(data_, size_, caller_cert.data(), caller_cert.size());
167 }
168 serialized::write(data_, size_, raw_request.data(), raw_request.size());
169
// Allocate the command id and register its timeout under the lock. This
// keeps ids unique and ensures the entry exists before the command is sent.
// NOTE(review): this registration happens before the active_view check
// below. If that check throws, nothing here removes the registered timeout;
// consider checking the view first.
170 ForwardedCommandId command_id;
171 {
172 std::lock_guard<ccf::pal::Mutex> guard(timeout_tasks_lock);
173 command_id = next_command_id++;
174 timeout_tasks[command_id] = {
176 create_timeout_error_task(to, client_session_id, timeout), timeout),
178 }
179
180 const auto view_opt = session_ctx->active_view;
181 if (!view_opt.has_value())
182 {
183 throw std::logic_error(
184 "Expected active_view to be set before forwarding");
185 }
186 ForwardedCommandHeader_v3 header(command_id, view_opt.value());
187
188 return n2n_channels->send_encrypted(
189 to, NodeMsgType::forwarded_msg, plain, header);
190 }
191
// Decrypts a forwarded command received from node `from` and rebuilds an
// HTTP RPC context for it, with the session marked as forwarded. TFwdHdr
// is the header type of the message version being received.
//
// Returns nullptr (after logging) if decryption throws std::logic_error or
// if building the context throws. Exceptions from parsing the decrypted
// payload (the serialized::read calls) are not caught here; they propagate
// to the caller.
192 template <typename TFwdHdr>
193 std::shared_ptr<::http::HttpRpcContext> recv_forwarded_command(
194 const NodeId& from, const uint8_t* data, size_t size)
195 {
196 std::pair<TFwdHdr, std::vector<uint8_t>> r;
197 try
198 {
199 LOG_TRACE_FMT("Receiving forwarded command of {} bytes", size);
200 LOG_TRACE_FMT(" => {:02x}", fmt::join(data, data + size, ""));
201
202 r = n2n_channels->template recv_encrypted<TFwdHdr>(from, data, size);
203 }
204 catch (const std::logic_error& err)
205 {
206 LOG_FAIL_FMT("Invalid forwarded command");
207 LOG_DEBUG_FMT("Invalid forwarded command: {}", err.what());
208 return nullptr;
209 }
210
// Parse the plaintext in the order forward_command writes it.
211 std::vector<uint8_t> caller_cert;
212 const auto& plain_ = r.second;
213 auto data_ = plain_.data();
214 auto size_ = plain_.size();
215 auto client_session_id = serialized::read<size_t>(data_, size_);
216 auto includes_caller =
217 serialized::read<IsCallerCertForwarded>(data_, size_);
218 if (includes_caller)
219 {
220 auto caller_size = serialized::read<size_t>(data_, size_);
221 caller_cert = serialized::read(data_, size_, caller_size);
222 }
// Everything left in the payload is the serialised request.
223 std::vector<uint8_t> raw_request = serialized::read(data_, size_, size_);
224
225 auto session =
226 std::make_shared<ccf::SessionContext>(client_session_id, caller_cert);
227 session->is_forwarded = true;
228
// Only the v3 command header carries the forwarding node's active view.
229 if constexpr (std::is_same_v<TFwdHdr, ForwardedCommandHeader_v3>)
230 {
231 ccf::View view = r.first.active_view;
232 session->active_view = view;
233 }
234
// NOTE(review): the opening of the context-construction call (original
// line 237) is missing from this listing. The cross-reference index
// suggests `return ::http::make_fwd_rpc_context(`; confirm.
235 try
236 {
238 session, raw_request, r.first.frame_format);
239 }
240 catch (const ::http::RequestTooLargeException& rexc)
241 {
242 LOG_FAIL_FMT("Forwarded request exceeded limit: {}", rexc.what());
243 return nullptr;
244 }
245 catch (const std::exception& err)
246 {
247 LOG_FAIL_FMT("Invalid forwarded request");
248 LOG_DEBUG_FMT("Invalid forwarded request: {}", err.what());
249 return nullptr;
250 }
251 }
252
// Sends a serialised response back to the node that forwarded the command.
// The `void send_forwarded_response(` line (original line 254) is missing
// from this listing. The payload is client_session_id followed by the
// response bytes, sent encrypted as NodeMsgType::forwarded_msg with
// `header`. A failed send is logged and is not reported to the caller.
253 template <typename TFwdHdr>
255 size_t client_session_id,
256 const NodeId& from_node,
257 const TFwdHdr& header,
258 const std::vector<uint8_t>& data)
259 {
260 std::vector<uint8_t> plain(sizeof(client_session_id) + data.size());
261 auto data_ = plain.data();
262 auto size_ = plain.size();
263 serialized::write(data_, size_, client_session_id);
264 serialized::write(data_, size_, data.data(), data.size());
265
266 if (!n2n_channels->send_encrypted(
267 from_node, NodeMsgType::forwarded_msg, plain, header))
268 {
269 LOG_FAIL_FMT("Failed to send forwarded response to {}", from_node);
270 }
271 }
272
// Result of decoding a forwarded response (see recv_forwarded_response).
// NOTE(review): the struct-name line (original line 273;
// ForwardedResponseResult per the cross-reference index) and the fields
// client_session_id (line 275) and should_terminate_session (line 277) are
// missing from this listing.
274 {
// Serialised response to relay to the client session.
276 std::vector<uint8_t> response_body;
278 };
279
// Decrypts a forwarded response received from node `from`. It then decodes
// the client session id and response body written by
// send_forwarded_response.
// Returns std::nullopt (after logging) if decryption throws
// std::logic_error. Payload parsing errors propagate to the caller.
280 template <typename TFwdHdr>
281 std::optional<ForwardedResponseResult> recv_forwarded_response(
282 const NodeId& from, const uint8_t* data, size_t size)
283 {
284 std::pair<TFwdHdr, std::vector<uint8_t>> r;
285 try
286 {
287 LOG_TRACE_FMT("Receiving response of {} bytes", size);
288 LOG_TRACE_FMT(" => {:02x}", fmt::join(data, data + size, ""));
289
290 r = n2n_channels->template recv_encrypted<TFwdHdr>(from, data, size);
291 }
292 catch (const std::logic_error& err)
293 {
294 LOG_FAIL_FMT("Invalid forwarded response");
295 LOG_DEBUG_FMT("Invalid forwarded response: {}", err.what());
296 return std::nullopt;
297 }
298
// The declaration of `ret` (original line 299) is missing from this
// listing. Only the v3 response header carries terminate_session. For
// older versions should_terminate_session keeps whatever value the struct
// initialises it to.
300 if constexpr (std::is_same_v<TFwdHdr, ForwardedResponseHeader_v3>)
301 {
302 ret.should_terminate_session = r.first.terminate_session;
303 }
304
305 const auto& plain_ = r.second;
306 auto data_ = plain_.data();
307 auto size_ = plain_.size();
308 ret.client_session_id = serialized::read<size_t>(data_, size_);
// Everything after the session id is the response body.
309 ret.response_body = serialized::read(data_, size_, size_);
310
311 return ret;
312 }
313
// Finds the handler for a forwarded request and returns it as a
// ForwardedRpcHandler. Returns nullptr (after logging) in three cases:
// ctx is null (recv_forwarded_command failed), the RPC map no longer
// exists, or the handler found is not a ForwardedRpcHandler.
// NOTE(review): ctx is taken by non-const reference although this function
// does not reassign it. Check whether ::http::fetch_rpc_handler needs a
// mutable reference before changing it to const&.
314 std::shared_ptr<ForwardedRpcHandler> get_forwarder_handler(
315 std::shared_ptr<::http::HttpRpcContext>& ctx)
316 {
317 if (ctx == nullptr)
318 {
319 LOG_FAIL_FMT("Failed to receive forwarded command");
320 return nullptr;
321 }
322
323 std::shared_ptr<ccf::RPCMap> rpc_map_shared = rpc_map.lock();
324 if (rpc_map_shared == nullptr)
325 {
326 LOG_FAIL_FMT("Failed to obtain RPCMap");
327 return nullptr;
328 }
329
330 std::shared_ptr<ccf::RpcHandler> search =
331 ::http::fetch_rpc_handler(ctx, rpc_map_shared);
332
333 auto fwd_handler = std::dynamic_pointer_cast<ForwardedRpcHandler>(search);
334 if (!fwd_handler)
335 {
// The opening of the logging call (original line 336) is missing from this
// listing.
337 "Failed to process forwarded command: handler is not a "
338 "ForwardedRpcHandler");
339 return nullptr;
340 }
341
342 return fwd_handler;
343 }
344
// Handles a forwarded message received from node `from`. Presumably this
// is the NodeMsgType::forwarded_msg traffic sent by forward_command and
// send_forwarded_response; confirm against the caller. It peeks the
// ForwardedMsg type and dispatches:
//  - forwarded commands (v1/v2/v3 headers): decrypt and rebuild the
//    request, execute it through the ForwardedRpcHandler, then send the
//    serialised response back to `from`. For v2/v3 the response header
//    echoes the command id; for v3 it also carries ctx->terminate_session.
//  - forwarded responses: in the branch that carries a command id, cancel
//    and erase that command's timeout first. Then decode the response and
//    relay it to the client session through reply_async.
// Any std::exception thrown while handling the message is logged and
// swallowed.
// NOTE(review): many lines of this function (the case labels and the
// opening lines of several multi-line calls) are missing from this listing.
345 void recv_message(const ccf::NodeId& from, const uint8_t* data, size_t size)
346 {
347 try
348 {
349 const auto forwarded_msg = serialized::peek<ForwardedMsg>(data, size);
351 "recv_message({}, {} bytes) (type={})",
352 from,
353 size,
354 (size_t)forwarded_msg);
355
356 switch (forwarded_msg)
357 {
// Forwarded command with a v1 header (case label not shown).
359 {
360 auto ctx =
361 recv_forwarded_command<ForwardedHeader_v1>(from, data, size);
362
363 auto fwd_handler = get_forwarder_handler(ctx);
364 if (fwd_handler == nullptr)
365 {
366 return;
367 }
368
369 // frame_format is deliberately unset, the forwarder ignores it
370 // and expects the same format they forwarded.
371 ForwardedHeader_v1 response_header{
373
// NOTE(review): in this branch the "Sending forwarded response" log comes
// before process_forwarded; the v2/v3 branches log after it.
374 LOG_DEBUG_FMT("Sending forwarded response to {}", from);
375 fwd_handler->process_forwarded(ctx);
376
378 ctx->get_session_context()->client_session_id,
379 from,
380 response_header,
381 ctx->serialise_response());
382 break;
383 }
384
// Forwarded command with a v2 header (case label not shown). The command
// id is peeked from the received header and echoed in the response header.
386 {
387 auto ctx =
388 recv_forwarded_command<ForwardedHeader_v2>(from, data, size);
389
390 auto fwd_handler = get_forwarder_handler(ctx);
391 if (fwd_handler == nullptr)
392 {
393 return;
394 }
395
396 const auto forwarded_hdr_v2 =
397 serialized::peek<ForwardedHeader_v2>(data, size);
398 const auto cmd_id = forwarded_hdr_v2.id;
399
400 fwd_handler->process_forwarded(ctx);
401
402 // frame_format is deliberately unset, the forwarder ignores it
403 // and expects the same format they forwarded.
404 ForwardedHeader_v2 response_header{
406
407 LOG_DEBUG_FMT("Sending forwarded response to {}", from);
408
410 ctx->get_session_context()->client_session_id,
411 from,
412 response_header,
413 ctx->serialise_response());
414 break;
415 }
416
// Forwarded command with a v3 header (case label not shown). The response
// header carries the command id and whether to terminate the session.
418 {
419 auto ctx = recv_forwarded_command<ForwardedCommandHeader_v3>(
420 from, data, size);
421
422 auto fwd_handler = get_forwarder_handler(ctx);
423 if (fwd_handler == nullptr)
424 {
425 return;
426 }
427
428 const auto forwarded_hdr_v3 =
429 serialized::peek<ForwardedCommandHeader_v3>(data, size);
430 const auto cmd_id = forwarded_hdr_v3.id;
431
432 fwd_handler->process_forwarded(ctx);
433
434 // frame_format is deliberately unset, the forwarder ignores it
435 // and expects the same format they forwarded.
436 ForwardedResponseHeader_v3 response_header(
437 cmd_id, ctx->terminate_session);
438
439 LOG_DEBUG_FMT("Sending forwarded response to {}", from);
440
442 ctx->get_session_context()->client_session_id,
443 from,
444 response_header,
445 ctx->serialise_response());
446 break;
447 }
448
// Forwarded response carrying a command id. The case labels on original
// lines 449-450 are not shown; presumably they are the v2 and v3
// responses. The command id is peeked through the v2 header layout. The
// lock_guard below is released at the end of this block, before control
// falls through to the generic response handling.
451 {
452 const auto forwarded_hdr_v2 =
453 serialized::peek<ForwardedHeader_v2>(data, size);
454 const auto cmd_id = forwarded_hdr_v2.id;
455
456 // Cancel and delete the corresponding timeout task, so it will no
457 // longer trigger a timeout error
458 std::lock_guard<ccf::pal::Mutex> guard(timeout_tasks_lock);
459 auto it = timeout_tasks.find(cmd_id);
460 if (it != timeout_tasks.end())
461 {
// Timers must be cancelled on the thread that registered them.
462 if (
463 ccf::threading::get_current_thread_id() != it->second.thread_id)
464 {
465 auto msg = std::make_unique<::threading::Tmsg<CancelTimerMsg>>(
466 &cancel_forwarding_task_cb);
467 msg->data.timer_entry = it->second.timer_entry;
468
470 it->second.thread_id, std::move(msg));
471 }
472 else
473 {
474 cancel_forwarding_task(it->second.timer_entry);
475 }
// NOTE(review): the iterator returned by erase is not used afterwards.
476 it = timeout_tasks.erase(it);
477 }
// NOTE(review): nothing in the code visible here erases a timeout_tasks
// entry when its timer fires (send_timeout_error_response does not). A
// response arriving after the timeout error was sent therefore appears to
// still find its entry above and be relayed as a second reply. That would
// leave this branch covering only unknown command ids. Confirm.
478 else
479 {
481 "Response for {} received too late - already sent timeout "
482 "error to client",
483 cmd_id);
484 return;
485 }
// NOTE(review): a C++17 [[fallthrough]]; after this block would make the
// intent explicit to the compiler.
486 // Deliberate fall-through
487 }
488
// Generic response handling (case label on original line 489 not shown).
// Decode with the header type that matches the message version; the
// version tests on original lines 492 and 497 are not shown. Then relay the
// response to the client session.
490 {
491 std::optional<ForwardedResponseResult> rep;
493 {
494 rep = recv_forwarded_response<ForwardedResponseHeader_v3>(
495 from, data, size);
496 }
498 {
499 rep =
500 recv_forwarded_response<ForwardedHeader_v2>(from, data, size);
501 }
502 else
503 {
504 rep =
505 recv_forwarded_response<ForwardedHeader_v1>(from, data, size);
506 }
507
508 if (!rep.has_value())
509 {
510 return;
511 }
512
514 "Sending forwarded response to RPC endpoint {}",
515 rep->client_session_id);
516
// should_terminate_session is passed through to the responder. Returning
// here or reaching `break` has the same effect: nothing follows the switch.
517 auto rpc_responder_shared = rpcresponder.lock();
518 if (
519 rpc_responder_shared &&
520 !rpc_responder_shared->reply_async(
521 rep->client_session_id,
522 rep->should_terminate_session,
523 std::move(rep->response_body)))
524 {
525 return;
526 }
527
528 break;
529 }
530
// Unknown message type: logged and ignored.
531 default:
532 {
533 LOG_FAIL_FMT("Unknown frontend msg type: {}", forwarded_msg);
534 break;
535 }
536 }
537 }
538 catch (const std::exception& e)
539 {
540 LOG_FAIL_FMT("Exception in {}", __PRETTY_FUNCTION__);
541 LOG_DEBUG_FMT("Error: {}", e.what());
542 return;
543 }
544 }
545 };
546}
Definition forwarder_types.h:22
Definition forwarder.h:17
virtual ~ForwardedRpcHandler()
Definition forwarder.h:19
virtual void process_forwarded(std::shared_ptr< ccf::RpcContextImpl > fwd_ctx)=0
Definition forwarder.h:27
void initialize(const NodeId &self_)
Definition forwarder.h:133
std::shared_ptr<::http::HttpRpcContext > recv_forwarded_command(const NodeId &from, const uint8_t *data, size_t size)
Definition forwarder.h:193
void recv_message(const ccf::NodeId &from, const uint8_t *data, size_t size)
Definition forwarder.h:345
Forwarder(std::weak_ptr< ccf::AbstractRPCResponder > rpcresponder, std::shared_ptr< ChannelProxy > n2n_channels, std::weak_ptr< ccf::RPCMap > rpc_map_)
Definition forwarder.h:124
void send_forwarded_response(size_t client_session_id, const NodeId &from_node, const TFwdHdr &header, const std::vector< uint8_t > &data)
Definition forwarder.h:254
std::shared_ptr< ForwardedRpcHandler > get_forwarder_handler(std::shared_ptr<::http::HttpRpcContext > &ctx)
Definition forwarder.h:314
std::optional< ForwardedResponseResult > recv_forwarded_response(const NodeId &from, const uint8_t *data, size_t size)
Definition forwarder.h:281
bool forward_command(std::shared_ptr< ccf::RpcContextImpl > rpc_ctx, const NodeId &to, const std::vector< uint8_t > &caller_cert, const std::chrono::milliseconds &timeout) override
Definition forwarder.h:138
Definition http_builder.h:200
static ThreadMessaging & instance()
Definition thread_messaging.h:283
bool cancel_timer_task(TaskQueue::TimerEntry timer_entry)
Definition thread_messaging.h:333
TaskQueue::TimerEntry add_task_after(std::unique_ptr< Tmsg< Payload > > msg, std::chrono::milliseconds ms)
Definition thread_messaging.h:326
void add_task(uint16_t tid, std::unique_ptr< Tmsg< Payload > > msg)
Definition thread_messaging.h:318
#define LOG_TRACE_FMT
Definition logger.h:356
#define LOG_DEBUG_FMT
Definition logger.h:357
#define LOG_FAIL_FMT
Definition logger.h:363
std::mutex Mutex
Definition locking.h:12
uint16_t get_current_thread_id()
Definition thread_local.cpp:15
Definition app_interface.h:14
std::shared_ptr<::http::HttpRpcContext > make_fwd_rpc_context(std::shared_ptr< ccf::SessionContext > s, const std::vector< uint8_t > &packed, ccf::FrameFormat frame_format)
Definition http_rpc_context.h:420
view
Definition signatures.h:54
@ forwarded_cmd_v3
Definition node_types.h:50
@ forwarded_cmd_v2
Definition node_types.h:44
@ forwarded_response_v1
Definition node_types.h:40
@ forwarded_response_v3
Definition node_types.h:51
@ forwarded_response_v2
Definition node_types.h:45
@ forwarded_cmd_v1
Definition node_types.h:39
uint64_t View
Definition tx_id.h:23
@ forwarded_msg
Definition node_types.h:22
void write(uint8_t *&data, size_t &size, const T &v)
Definition serialized.h:106
T read(const uint8_t *&data, size_t &size)
Definition serialized.h:59
Definition node_types.h:76
Definition node_types.h:64
Definition node_types.h:70
size_t ForwardedCommandId
Definition node_types.h:71
Definition node_types.h:92
Definition forwarder.h:274
size_t client_session_id
Definition forwarder.h:275
std::vector< uint8_t > response_body
Definition forwarder.h:276
bool should_terminate_session
Definition forwarder.h:277
Definition thread_messaging.h:85
Definition thread_messaging.h:27