1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use core::{cmp, ffi, fmt, str};
15 use std::ffi::{c_uchar, CString};
16 #[cfg(feature = "sync")]
17 use std::io::{Read, Write};
18 use std::ptr::null;
19 use std::slice::from_raw_parts;
20 
21 use libc::{c_char, c_int, c_long, c_void};
22 
23 #[cfg(feature = "sync")]
24 use super::error::HandshakeError;
25 #[cfg(feature = "sync")]
26 use super::SslStream;
27 use super::{SslContext, SslErrorCode};
28 use crate::c_openssl::check_ret;
29 use crate::c_openssl::ffi::bio::BIO;
30 use crate::c_openssl::ffi::ssl::{
31     SSL_ctrl, SSL_get0_param, SSL_get_error, SSL_get_rbio, SSL_get_verify_result, SSL_read,
32     SSL_state_string_long, SSL_write,
33 };
34 use crate::c_openssl::foreign::ForeignRef;
35 use crate::c_openssl::x509::{
36     X509VerifyParamRef, X509VerifyResult, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS,
37 };
38 use crate::util::c_openssl::check_ptr;
39 use crate::util::c_openssl::error::ErrorStack;
40 use crate::util::c_openssl::ffi::ssl::{SSL_free, SSL_get0_alpn_selected, SSL_new, SSL};
41 use crate::util::c_openssl::foreign::Foreign;
42 
// Declares the crate-private `Ssl` and `SslRef` types on top of the raw
// OpenSSL `SSL` struct, with `SSL_free` registered as the drop function.
// NOTE(review): presumably `Ssl` is the owning handle and `SslRef` the borrowed
// view, as is conventional for this kind of macro — confirm against the
// `foreign_type!` definition.
foreign_type!(
    type CStruct = SSL;
    fn drop = SSL_free;
    /// The main SSL/TLS structure.
    pub(crate) struct Ssl;
    pub(crate) struct SslRef;
);
50 
51 impl Ssl {
new(ctx: &SslContext) -> Result<Ssl, ErrorStack>52     pub(crate) fn new(ctx: &SslContext) -> Result<Ssl, ErrorStack> {
53         unsafe {
54             let ptr = check_ptr(SSL_new(ctx.as_ptr()))?;
55             Ok(Ssl::from_ptr(ptr))
56         }
57     }
58 
59     /// Client connect to Server.
60     /// only `sync` use.
61     #[cfg(feature = "sync")]
connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,62     pub(crate) fn connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
63     where
64         S: Read + Write,
65     {
66         use super::MidHandshakeSslStream;
67         use crate::c_openssl::ffi::ssl::SSL_connect;
68 
69         let mut stream = SslStream::new_base(self, stream, None)?;
70         let ret = unsafe { SSL_connect(stream.ssl.as_ptr()) };
71         if ret > 0 {
72             Ok(stream)
73         } else {
74             let error = stream.get_error(ret);
75             match error.code {
76                 SslErrorCode::WANT_READ | SslErrorCode::WANT_WRITE => {
77                     Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
78                         _stream: stream,
79                         error,
80                     }))
81                 }
82                 _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
83                     _stream: stream,
84                     error,
85                 })),
86             }
87         }
88     }
89 }
90 
91 impl SslRef {
get_error(&self, err: c_int) -> SslErrorCode92     pub(crate) fn get_error(&self, err: c_int) -> SslErrorCode {
93         unsafe { SslErrorCode::from_int(SSL_get_error(self.as_ptr(), err)) }
94     }
95 
ssl_status(&self) -> &'static str96     fn ssl_status(&self) -> &'static str {
97         let status = unsafe {
98             let ptr = SSL_state_string_long(self.as_ptr());
99             ffi::CStr::from_ptr(ptr as *const _)
100         };
101         str::from_utf8(status.to_bytes()).unwrap_or_default()
102     }
103 
verify_result(&self) -> X509VerifyResult104     pub(crate) fn verify_result(&self) -> X509VerifyResult {
105         unsafe { X509VerifyResult::from_raw(SSL_get_verify_result(self.as_ptr()) as c_int) }
106     }
107 
get_raw_bio(&self) -> *mut BIO108     pub(crate) fn get_raw_bio(&self) -> *mut BIO {
109         unsafe { SSL_get_rbio(self.as_ptr()) }
110     }
111 
read(&mut self, buf: &[u8]) -> c_int112     pub(crate) fn read(&mut self, buf: &[u8]) -> c_int {
113         let len = cmp::min(c_int::MAX as usize, buf.len()) as c_int;
114         unsafe { SSL_read(self.as_ptr(), buf.as_ptr() as *mut c_void, len) }
115     }
116 
write(&mut self, buf: &[u8]) -> c_int117     pub(crate) fn write(&mut self, buf: &[u8]) -> c_int {
118         let len = cmp::min(c_int::MAX as usize, buf.len()) as c_int;
119         unsafe { SSL_write(self.as_ptr(), buf.as_ptr() as *const c_void, len) }
120     }
121 
set_host_name_in_sni(&mut self, name: &str) -> Result<(), ErrorStack>122     pub(crate) fn set_host_name_in_sni(&mut self, name: &str) -> Result<(), ErrorStack> {
123         let name = match CString::new(name) {
124             Ok(name) => name,
125             Err(_) => return Err(ErrorStack::get()),
126         };
127         check_ret(
128             unsafe { ssl_set_tlsext_host_name(self.as_ptr(), name.as_ptr() as *mut _) } as c_int,
129         )
130         .map(|_| ())
131     }
132 
param_mut(&mut self) -> &mut X509VerifyParamRef133     pub(crate) fn param_mut(&mut self) -> &mut X509VerifyParamRef {
134         unsafe { X509VerifyParamRef::from_ptr_mut(SSL_get0_param(self.as_ptr())) }
135     }
136 
set_verify_hostname(&mut self, host_name: &str) -> Result<(), ErrorStack>137     pub(crate) fn set_verify_hostname(&mut self, host_name: &str) -> Result<(), ErrorStack> {
138         let param = self.param_mut();
139         param.set_hostflags(X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
140         match host_name.parse() {
141             Ok(ip) => param.set_ip(ip),
142             Err(_) => param.set_host(host_name),
143         }
144     }
145 
negotiated_alpn_protocol(&self) -> Option<&[u8]>146     pub(crate) fn negotiated_alpn_protocol(&self) -> Option<&[u8]> {
147         let mut data = null() as *const c_uchar;
148         let mut len = 0_u32;
149         unsafe {
150             SSL_get0_alpn_selected(self.as_ptr(), &mut data, &mut len);
151             if data.is_null() {
152                 None
153             } else {
154                 Some(from_raw_parts(data, len as usize))
155             }
156         }
157     }
158 }
159 
160 impl fmt::Debug for SslRef {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result161     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162         write!(
163             f,
164             "Ssl[state: {}, verify result: {}]",
165             &self.ssl_status(),
166             &self.verify_result()
167         )
168     }
169 }
170 
// Command code passed as the second argument of `SSL_ctrl` by
// `ssl_set_tlsext_host_name`.
// NOTE(review): presumably mirrors OpenSSL's `SSL_CTRL_SET_TLSEXT_HOSTNAME`
// macro — verify the value against the OpenSSL headers in use.
const SSL_CTRL_SET_TLSEXT_HOSTNAME: c_int = 0x37;
// Name-type value passed (as `c_long`) as the third argument of `SSL_ctrl` by
// `ssl_set_tlsext_host_name`.
// NOTE(review): presumably mirrors OpenSSL's `TLSEXT_NAMETYPE_host_name` —
// verify against the OpenSSL headers in use.
const TLSEXT_NAMETYPE_HOST_NAME: c_int = 0x0;
173 
ssl_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long174 unsafe fn ssl_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long {
175     SSL_ctrl(
176         s,
177         SSL_CTRL_SET_TLSEXT_HOSTNAME,
178         TLSEXT_NAMETYPE_HOST_NAME as c_long,
179         name as *mut c_void,
180     )
181 }
182