CCF
Loading...
Searching...
No Matches
tls_session.h
Go to the documentation of this file.
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the Apache 2.0 License.
3#pragma once
4
5#include "ccf/ds/logger.h"
6#include "ds/messaging.h"
7#include "ds/ring_buffer.h"
9#include "tcp/msg_types.h"
10#include "tls/context.h"
11#include "tls/tls.h"
12
13#include <exception>
14
15namespace ccf
16{
26
27 class TLSSession : public std::enable_shared_from_this<TLSSession>
28 {
29 public:
// Handler type for handshake failures; receives the formatted error message
// by rvalue reference (see on_handshake_error()).
30 using HandshakeErrorCB = std::function<void(std::string&&)>;
31
32 protected:
36
37 private:
// Bytes queued for transmission through the TLS context; appended by
// send_raw_thread()/send_buffered() and drained by flush().
38 std::vector<uint8_t> pending_write;
// Bytes received from the host by recv_buffered(), handed to the TLS layer
// on demand by handle_recv().
39 std::vector<uint8_t> pending_read;
40 // Decrypted data held back by read() (exact reads) and served first next time
41 std::vector<uint8_t> read_buffer;
42
// TLS context driving handshake(), read(), write() and close().
43 std::unique_ptr<tls::Context> ctx;
// Current session state; can_send()/can_recv() derive I/O permission from it.
44 SessionStatus status;
45
// Optional; when empty, on_handshake_error() only traces the message.
46 HandshakeErrorCB handshake_error_cb;
47
48 bool can_send()
49 {
50 // Closing endpoint should still be able to respond to clients (e.g. to
51 // report errors)
52 return status == ready || status == closing;
53 }
54
55 bool can_recv()
56 {
57 return status == ready || status == handshake;
58 }
59
60 struct SendRecvMsg
61 {
62 std::vector<uint8_t> data;
63 std::shared_ptr<TLSSession> self;
64 };
65
66 struct EmptyMsg
67 {
68 std::shared_ptr<TLSSession> self;
69 };
70
71 public:
// Constructs a session for connection session_id_ in the handshake state,
// taking ownership of the TLS context ctx_. The context's BIO callbacks are
// set to send_callback_openssl/recv_callback_openssl with `this` passed to
// set_bio(); those callbacks read it back via BIO_get_callback_arg().
73 int64_t session_id_,
74 ringbuffer::AbstractWriterFactory& writer_factory_,
75 std::unique_ptr<tls::Context> ctx_) :
// to_host is the writer used for all tcp_* messages sent to the host.
76 to_host(writer_factory_.create_writer_to_outside()),
77 session_id(session_id_),
78 ctx(std::move(ctx_)),
79 status(handshake)
80 {
84 ctx->set_bio(this, send_callback_openssl, recv_callback_openssl);
85 }
86
87 virtual ~TLSSession()
88 {
89 RINGBUFFER_WRITE_MESSAGE(::tcp::tcp_closed, to_host, session_id);
90 }
91
// Returns the current session state.
93 {
94 return status;
95 }
96
97 void on_handshake_error(std::string&& error_msg)
98 {
99 if (handshake_error_cb)
100 {
101 handshake_error_cb(std::move(error_msg));
102 }
103 else
104 {
105 LOG_TRACE_FMT("{}", error_msg);
106 }
107 }
108
// Installs the handler used by on_handshake_error(), replacing any previous
// one; cb is moved from.
110 {
111 handshake_error_cb = std::move(cb);
112 }
113
114 std::string hostname()
115 {
116 if (status != ready)
117 {
118 return {};
119 }
120
121 return ctx->host();
122 }
123
124 std::vector<uint8_t> peer_cert()
125 {
126 return ctx->peer_cert();
127 }
128
129 // Returns count N of bytes read, which will be the first N bytes of data,
130 // up to a maximum of size. If exact is true, will only return either size
131 // or 0 (when size bytes are not currently available). data may be accessed
132 // beyond N during operation, up to size, but only the first N should be
133 // used by caller.
// The handshake is advanced first and 0 is returned until status is ready;
// pending writes are flushed before reading. Bytes held in read_buffer (kept
// back by earlier exact reads) are served before ctx is asked for more.
// A close result stops the session as closed; any other negative result
// stops it as error and returns 0.
134 size_t read(uint8_t* data, size_t size, bool exact = false)
135 {
136 // This will return empty if the connection isn't
137 // ready, but it will not block on the handshake.
138 do_handshake();
139
140 if (status != ready)
141 {
142 LOG_TRACE_FMT("Not ready to read {} bytes", size);
143 return 0;
144 }
145
146 LOG_TRACE_FMT("Requesting up to {} bytes", size);
147
148 // Send pending writes.
149 flush();
150
151 size_t offset = 0;
152
153 if (read_buffer.size() > 0)
154 {
156 "Have existing read_buffer of size: {}", read_buffer.size());
157 offset = std::min(size, read_buffer.size());
158 ::memcpy(data, read_buffer.data(), offset);
159
160 if (offset < read_buffer.size())
161 read_buffer.erase(read_buffer.begin(), read_buffer.begin() + offset);
162 else
163 read_buffer.clear();
164
165 if (offset == size)
166 return size;
167
168 // NB: If we continue past here, read_buffer is empty
169 }
170
// Ask the TLS context for the remainder, writing after the buffered bytes.
171 auto r = ctx->read(data + offset, size - offset);
172 LOG_TRACE_FMT("ctx->read returned: {}", r);
173
174 switch (r)
175 {
176 case 0:
178 {
180 "TLS {} close on read: {}", session_id, ::tls::error_string(r));
181
182 stop(closed);
183
184 if (!exact)
185 {
186 // Hit an error, but may still have some useful data from the
187 // previous read_buffer
188 return offset;
189 }
190
191 return 0;
192 }
193
196 {
197 if (!exact)
198 {
199 return offset;
200 }
201
202 // May have read something but not enough - copy it into read_buffer
203 // for next call
204 read_buffer.insert(read_buffer.end(), data, data + offset);
205 return 0;
206 }
207
208 default:
209 {
210 }
211 }
212
// Negative codes not handled by the cases above are treated as fatal.
213 if (r < 0)
214 {
216 "TLS {} error on read: {}", session_id, ::tls::error_string(r));
217 stop(error);
218 return 0;
219 }
220
// r > 0 here: the first `total` bytes of data are now valid.
221 auto total = r + offset;
222
223 // We read _some_ data but not enough, and didn't get
224 // TLS_ERR_WANT_READ. Probably hit an internal size limit - try
225 // again
226 if (exact && (total < size))
227 {
229 "Asked for exactly {}, received {}, retrying", size, total);
230 read_buffer.insert(read_buffer.end(), data, data + total);
// The recursive call serves these bytes back from read_buffer and asks ctx
// only for the remainder.
231 return read(data, size, exact);
232 }
233
234 return total;
235 }
236
// Accepts bytes received from the host: they are appended to pending_read
// (consumed later by handle_recv()) only while can_recv() is true, and are
// otherwise dropped. The handshake is then advanced if still in progress.
// Throws std::runtime_error when called from an incorrect thread.
237 void recv_buffered(const uint8_t* data, size_t size)
238 {
240 {
241 throw std::runtime_error("Called recv_buffered from incorrect thread");
242 }
243
244 if (can_recv())
245 {
246 pending_read.insert(pending_read.end(), data, data + size);
247 }
248
249 do_handshake();
250 }
251
// Marks the session as closing and runs close_thread(), either inline or by
// posting a close_cb task to execution_thread. The posted message holds a
// shared_ptr to this session, so the session must be owned by a
// std::shared_ptr (shared_from_this()).
// NOTE(review): status is assigned before the thread check, i.e. possibly
// from a thread other than execution_thread, and SessionStatus is a plain
// member here — confirm this cannot race with the execution thread.
252 void close()
253 {
254 status = closing;
256 {
257 auto msg = std::make_unique<::threading::Tmsg<EmptyMsg>>(&close_cb);
258 msg->data.self = this->shared_from_this();
259
261 execution_thread, std::move(msg));
262 }
263 else
264 {
265 // Close inline immediately
266 close_thread();
267 }
268 }
269
270 static void close_cb(std::unique_ptr<::threading::Tmsg<EmptyMsg>> msg)
271 {
272 msg->data.self->close_thread();
273 }
274
// Performs the close; throws std::runtime_error when called from an
// incorrect thread. During the handshake the session is simply stopped as
// closed. When ready or closing, ctx->close() is attempted: the session is
// stopped as closed for the handled codes (including 0) and as error for
// anything else. Other states are left unchanged.
275 virtual void close_thread()
276 {
278 {
279 throw std::runtime_error("Called close_thread from incorrect thread");
280 }
281
282 switch (status)
283 {
284 case handshake:
285 {
286 LOG_TRACE_FMT("TLS {} closed during handshake", session_id);
287 stop(closed);
288 break;
289 }
290
291 case ready:
292 case closing:
293 {
294 int r = ctx->close();
295
296 switch (r)
297 {
300 {
301 LOG_TRACE_FMT("TLS {} has pending data ({})", session_id, r);
302 // FALLTHROUGH
303 }
304 case 0:
305 {
306 LOG_TRACE_FMT("TLS {} closed ({})", session_id, r);
307 stop(closed);
308 break;
309 }
310
311 default:
312 {
314 "TLS {} error on_close: {}",
317 stop(error);
318 break;
319 }
320 }
321 break;
322 }
323
// Already closed / failed: nothing to do.
324 default:
325 {
326 }
327 }
328 }
329
// Sends plaintext over the session. On the task path the bytes are copied
// into a SendRecvMsg, together with a shared_ptr to this session, and posted
// to execution_thread for send_raw_cb(). The caller's buffer therefore need
// not outlive the call. Otherwise send_raw_thread() runs inline.
// NOTE(review): presumably the branch condition compares the current thread
// with execution_thread, as in the other thread checks — confirm.
330 void send_raw(const uint8_t* data, size_t size)
331 {
333 {
334 auto msg =
335 std::make_unique<::threading::Tmsg<SendRecvMsg>>(&send_raw_cb);
336 msg->data.self = this->shared_from_this();
337 msg->data.data = std::vector<uint8_t>(data, data + size);
338
340 execution_thread, std::move(msg));
341 }
342 else
343 {
344 // Send inline immediately
345 send_raw_thread(data, size);
346 }
347 }
348
349 private:
350 static void send_raw_cb(std::unique_ptr<::threading::Tmsg<SendRecvMsg>> msg)
351 {
352 msg->data.self->send_raw_thread(
353 msg->data.data.data(), msg->data.data.size());
354 }
355
// Queues data for sending and, when possible, flushes it. Throws
// std::runtime_error when called from an incorrect thread. While the
// handshake is still running the data is only buffered. In states where
// can_send() is false the data is dropped.
356 void send_raw_thread(const uint8_t* data, size_t size)
357 {
359 {
360 throw std::runtime_error(
361 "Called send_raw_thread from incorrect thread");
362 }
363 // Writes as much of the data as possible. If the data cannot all
364 // be written now, we store the remainder. We
365 // will try to send pending writes again whenever flush() is called.
366 do_handshake();
367
368 if (status == handshake)
369 {
370 pending_write.insert(pending_write.end(), data, data + size);
371 return;
372 }
373
374 if (!can_send())
375 {
376 return;
377 }
378
379 pending_write.insert(pending_write.end(), data, data + size);
380
381 flush();
382 }
383
// Appends data to pending_write without sending it; a later flush() does the
// transmission. Unlike send_raw_thread(), the session status is not checked.
// Throws std::runtime_error when called from an incorrect thread.
384 void send_buffered(const std::vector<uint8_t>& data)
385 {
387 {
388 throw std::runtime_error("Called send_buffered from incorrect thread");
389 }
390
391 pending_write.insert(pending_write.end(), data.begin(), data.end());
392 }
393
// Advances the handshake, then, if can_send(), writes pending_write through
// write_some() until it is empty or a write returns 0 (retry later). Bytes
// accepted by each write are removed from the front of pending_write.
// Throws std::runtime_error when called from an incorrect thread.
394 void flush()
395 {
397 {
398 throw std::runtime_error("Called flush from incorrect thread");
399 }
400
401 do_handshake();
402
403 if (!can_send())
404 {
405 return;
406 }
407
408 while (pending_write.size() > 0)
409 {
410 auto r = write_some(pending_write);
411
412 if (r > 0)
413 {
414 pending_write.erase(pending_write.begin(), pending_write.begin() + r);
415 }
416 else if (r == 0)
417 {
418 break;
419 }
420 else
421 {
// NOTE(review): there is no `break` after stop(error). pending_write is not
// drained and can_send() is not re-checked, so if write_some() keeps
// returning a negative value this loop never terminates. A `break` here looks
// intended — confirm.
422 LOG_TRACE_FMT("TLS session {} error on flush: {}", session_id, -r);
423 stop(error);
424 }
425 }
426 }
427
// Advances the TLS handshake by one ctx->handshake() step while status is
// handshake; in any other state it returns immediately. rc == 0 moves the
// session to ready. Verification/certificate failures are reported through
// on_handshake_error() and stop the session as authfail. Unrecognised codes
// are reported the same way and stop it as error. Other visible outcomes are
// stop(closed) and a code that leaves the state unchanged.
428 void do_handshake()
429 {
430 // This should be called when additional data is written to the
431 // input buffer, until the handshake is complete.
432 if (status != handshake)
433 {
434 return;
435 }
436
437 auto rc = ctx->handshake();
438
439 switch (rc)
440 {
// Handshake complete.
441 case 0:
442 {
443 status = ready;
444 break;
445 }
446
449 break;
450
452 {
453 on_handshake_error(fmt::format(
454 "TLS {} verify error on handshake: {}",
457 stop(authfail);
458 break;
459 }
460
462 {
464 "TLS {} closed on handshake: {}",
467 stop(closed);
468 break;
469 }
470
472 {
473 auto err = ctx->get_verify_error();
474 on_handshake_error(fmt::format(
475 "TLS {} invalid cert on handshake: {} [{}]",
477 err,
479 stop(authfail);
480 return;
481 }
482
// Any other code is fatal.
483 default:
484 {
485 on_handshake_error(fmt::format(
486 "TLS {} error on handshake: {}",
489 stop(error);
490 break;
491 }
492 }
493 }
494
// Makes one ctx->write() attempt with the whole buffer. Returns the number
// of bytes accepted, 0 for the non-fatal codes mapped by the case labels
// before `return 0` (flush() treats 0 as "stop and retry later"), or any
// other result unchanged (negative values are treated by flush() as errors).
495 int write_some(const std::vector<uint8_t>& data)
496 {
497 auto r = ctx->write(data.data(), data.size());
498
499 switch (r)
500 {
503 return 0;
504
505 default:
506 return r;
507 }
508 }
509
510 void stop(SessionStatus status_)
511 {
512 switch (status)
513 {
514 case closed:
515 case authfail:
516 case error:
517 return;
518
519 default:
520 {
521 }
522 }
523
524 status = status_;
525
526 switch (status)
527 {
528 case closing:
529 case closed:
530 {
532 ::tcp::tcp_stop,
533 to_host,
535 std::string("Session closed"));
536 break;
537 }
538
539 case authfail:
540 {
542 ::tcp::tcp_stop,
543 to_host,
545 std::string("Authentication failed"));
546 }
547 case error:
548 {
550 ::tcp::tcp_stop, to_host, session_id, std::string("Error"));
551 break;
552 }
553
554 default:
555 {
556 }
557 }
558 }
559
// Forwards bytes the TLS layer wants sent (via send_callback) to the host as a
// single tcp_outbound message. Returns len on success, or TLS_WRITING when the
// ring buffer did not accept the message, so the caller retries later.
// NOTE(review): (int)len truncates if len exceeds INT_MAX — confirm TLS
// record sizes keep len small.
560 int handle_send(const uint8_t* buf, size_t len)
561 {
562 // Either write all of the data or none of it.
563 auto wrote = RINGBUFFER_TRY_WRITE_MESSAGE(
564 ::tcp::tcp_outbound,
565 to_host,
567 serializer::ByteRange{buf, len});
568
569 if (!wrote)
570 return TLS_WRITING;
571
572 return (int)len;
573 }
574
// Supplies up to len bytes from pending_read to the TLS layer (via
// recv_callback). Returns the number of bytes copied into buf, or
// TLS_READING when nothing is pending. Throws std::runtime_error when called
// from an incorrect thread.
575 int handle_recv(uint8_t* buf, size_t len)
576 {
578 {
579 throw std::runtime_error("Called handle_recv from incorrect thread");
580 }
581 if (pending_read.size() > 0)
582 {
583 // Serve from pending_read, which recv_buffered() fills with host bytes;
584 // hand over at most len bytes and keep the rest for the next call.
585 size_t rd = std::min(len, pending_read.size());
586 ::memcpy(buf, pending_read.data(), rd);
587
588 if (rd >= pending_read.size())
589 {
590 pending_read.clear();
591 }
592 else
593 {
594 pending_read.erase(pending_read.begin(), pending_read.begin() + rd);
595 }
596
597 return (int)rd;
598 }
599
600 return TLS_READING;
601 }
602
603 static int send_callback(void* ctx, const unsigned char* buf, size_t len)
604 {
605 return reinterpret_cast<TLSSession*>(ctx)->handle_send(buf, len);
606 }
607
608 static int recv_callback(void* ctx, unsigned char* buf, size_t len)
609 {
610 return reinterpret_cast<TLSSession*>(ctx)->handle_recv(buf, len);
611 }
612
613 // These callbacks below are complex, using the callbacks above and
614 // manipulating OpenSSL's BIO objects accordingly. This is just so we can
615 // emulate what MbedTLS used to do.
616 // Now that we have removed it from the code, we can move the callbacks
617 // above to handle BIOs directly and hopefully remove the complexity below.
618 // This work will be carried out in #3429.
619 static long send_callback_openssl(
620 BIO* b,
621 int oper,
622 const char* argp,
623 size_t len,
624 int argi,
625 long argl,
626 int ret,
627 size_t* processed)
628 {
629 // Unused arguments
630 (void)argi;
631 (void)argl;
632 (void)argp;
633
634 if (ret && len > 0 && oper == (BIO_CB_WRITE | BIO_CB_RETURN))
635 {
636 // Flush BIO so the "pipe doesn't clog", but we don't use the
637 // data here, because 'argp' already has it.
638 BIO_flush(b);
639 size_t pending = BIO_pending(b);
640 if (pending)
641 BIO_reset(b);
642
643 // Pipe object
644 void* ctx = (BIO_get_callback_arg(b));
645 int put = send_callback(ctx, (const uint8_t*)argp, len);
646
647 // WANTS_WRITE
648 if (put == TLS_WRITING)
649 {
650 BIO_set_retry_write(b);
651 LOG_TRACE_FMT("TLS Session::send_cb() : WANTS_WRITE");
652 *processed = 0;
653 return -1;
654 }
655 else
656 {
657 LOG_TRACE_FMT("TLS Session::send_cb() : Put {} bytes", put);
658 }
659
660 // Update the number of bytes to external users
661 *processed = put;
662 }
663
664 // Unless we detected an error, the return value is always the same as the
665 // original operation.
666 return ret;
667 }
668
// OpenSSL BIO callback for reads. A successful CTRL return is answered with 0
// early. For completed reads, bytes from the session's pending_read (via
// recv_callback) are copied into argp. If none are available, a read retry is
// requested. A short read reports the partial count. A full read is also
// written into the BIO with BIO_write_ex. Otherwise ret is passed through.
// NOTE(review): recv_callback -> handle_recv() can throw std::runtime_error;
// an exception escaping through OpenSSL's C frames is unsafe — confirm the
// thread check can never fire on this path.
669 static long recv_callback_openssl(
670 BIO* b,
671 int oper,
672 const char* argp,
673 size_t len,
674 int argi,
675 long argl,
676 int ret,
677 size_t* processed)
678 {
679 // Unused arguments
680 (void)argi;
681 (void)argl;
682
683 if (ret == 1 && oper == (BIO_CB_CTRL | BIO_CB_RETURN))
684 {
685 // This callback may be fired at the end of large batches of TLS frames
686 // on OpenSSL 3.x. Note that processed == nullptr in this case, hence
687 // the early exit.
688 return 0;
689 }
690
691 if (ret && (oper == (BIO_CB_READ | BIO_CB_RETURN)))
692 {
693 // Pipe object
694 void* ctx = (BIO_get_callback_arg(b));
// argp is the destination buffer of the read; const is cast away so the
// pending bytes can be copied into it.
695 int got = recv_callback(ctx, (uint8_t*)argp, len);
696
697 // WANTS_READ
698 if (got == TLS_READING)
699 {
700 BIO_set_retry_read(b);
701 LOG_TRACE_FMT("TLS Session::recv_cb() : WANTS_READ");
702 *processed = 0;
703 return -1;
704 }
705 else
706 {
708 "TLS Session::recv_cb() : Got {} bytes of {}", got, len);
709 }
710
711 // Short read: report the partial count as success (return 1), not WANT_READ
712 if ((size_t)got < len)
713 {
714 *processed = got;
715 return 1;
716 }
717
718 // Write to the actual BIO so SSL can use it
719 BIO_write_ex(b, argp, got, processed);
720
721 // The buffer should be enough, we can't return WANT_WRITE here
722 if ((size_t)got != *processed)
723 {
724 LOG_TRACE_FMT("TLS Session::recv_cb() : BIO error");
725 *processed = got;
726 return -1;
727 }
728
729 // If original return was -1 because it didn't find anything to read,
730 // return 1 to say we actually read something. This is common when the
731 // buffer is empty and needs an external read, so let's not log this.
732 if (got > 0 && ret < 0)
733 {
734 return 1;
735 }
736 }
737
738 // Unless we detected an error, the return value is always the same as the
739 // original operation.
740 return ret;
741 }
742 };
743}
Definition tls_session.h:28
SessionStatus get_status() const
Definition tls_session.h:92
void send_raw(const uint8_t *data, size_t size)
Definition tls_session.h:330
std::string hostname()
Definition tls_session.h:114
void recv_buffered(const uint8_t *data, size_t size)
Definition tls_session.h:237
virtual void close_thread()
Definition tls_session.h:275
std::function< void(std::string &&)> HandshakeErrorCB
Definition tls_session.h:30
size_t read(uint8_t *data, size_t size, bool exact=false)
Definition tls_session.h:134
TLSSession(int64_t session_id_, ringbuffer::AbstractWriterFactory &writer_factory_, std::unique_ptr< tls::Context > ctx_)
Definition tls_session.h:72
virtual ~TLSSession()
Definition tls_session.h:87
void on_handshake_error(std::string &&error_msg)
Definition tls_session.h:97
static void close_cb(std::unique_ptr<::threading::Tmsg< EmptyMsg > > msg)
Definition tls_session.h:270
void close()
Definition tls_session.h:252
std::vector< uint8_t > peer_cert()
Definition tls_session.h:124
ringbuffer::WriterPtr to_host
Definition tls_session.h:33
size_t execution_thread
Definition tls_session.h:35
void set_handshake_error_cb(HandshakeErrorCB &&cb)
Definition tls_session.h:109
::tcp::ConnID session_id
Definition tls_session.h:34
Definition ring_buffer_types.h:153
static ThreadMessaging & instance()
Definition thread_messaging.h:283
void add_task(uint16_t tid, std::unique_ptr< Tmsg< Payload > > msg)
Definition thread_messaging.h:318
uint16_t get_execution_thread(uint32_t i)
Definition thread_messaging.h:371
#define LOG_TRACE_FMT
Definition logger.h:356
uint16_t get_current_thread_id()
Definition thread_local.cpp:15
Definition app_interface.h:14
SessionStatus
Definition tls_session.h:18
@ closed
Definition tls_session.h:22
@ authfail
Definition tls_session.h:23
@ error
Definition tls_session.h:24
@ ready
Definition tls_session.h:20
@ closing
Definition tls_session.h:21
@ handshake
Definition tls_session.h:19
std::shared_ptr< AbstractWriter > WriterPtr
Definition ring_buffer_types.h:150
STL namespace.
int64_t ConnID
Definition msg_types.h:9
std::string error_string(int ec)
Definition tls.h:32
#define RINGBUFFER_TRY_WRITE_MESSAGE(MSG,...)
Definition ring_buffer_types.h:258
#define RINGBUFFER_WRITE_MESSAGE(MSG,...)
Definition ring_buffer_types.h:255
Definition serializer.h:27
Definition thread_messaging.h:27
#define TLS_ERR_X509_VERIFY
Definition tls.h:24
#define TLS_READING
Definition tls.h:14
#define TLS_ERR_WANT_WRITE
Definition tls.h:17
#define TLS_ERR_WANT_READ
Definition tls.h:16
#define TLS_WRITING
Definition tls.h:15
#define TLS_ERR_CONN_CLOSE_NOTIFY
Definition tls.h:18
#define TLS_ERR_NEED_CERT
Definition tls.h:19