1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use core::{cmp, ffi, fmt, str};
15 use std::ffi::{c_uchar, CString};
16 #[cfg(feature = "sync")]
17 use std::io::{Read, Write};
18 use std::ptr::null;
19 use std::slice::from_raw_parts;
20
21 use libc::{c_char, c_int, c_long, c_void};
22
23 #[cfg(feature = "sync")]
24 use super::error::HandshakeError;
25 #[cfg(feature = "sync")]
26 use super::SslStream;
27 use super::{SslContext, SslErrorCode};
28 use crate::c_openssl::check_ret;
29 use crate::c_openssl::ffi::bio::BIO;
30 use crate::c_openssl::ffi::ssl::{
31 SSL_ctrl, SSL_get0_param, SSL_get_error, SSL_get_rbio, SSL_get_verify_result, SSL_read,
32 SSL_state_string_long, SSL_write,
33 };
34 use crate::c_openssl::foreign::ForeignRef;
35 use crate::c_openssl::x509::{
36 X509VerifyParamRef, X509VerifyResult, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS,
37 };
38 use crate::util::c_openssl::check_ptr;
39 use crate::util::c_openssl::error::ErrorStack;
40 use crate::util::c_openssl::ffi::ssl::{SSL_free, SSL_get0_alpn_selected, SSL_new, SSL};
41 use crate::util::c_openssl::foreign::Foreign;
42
// Declares the owned `Ssl` handle and its companion `SslRef` around the raw
// OpenSSL `SSL` struct, registering `SSL_free` as the drop function.
// NOTE(review): the generated API (e.g. `from_ptr` / `as_ptr` and the
// owned-to-ref relationship) comes from the `foreign_type!` macro defined
// elsewhere — confirm there.
foreign_type!(
    type CStruct = SSL;
    fn drop = SSL_free;
    /// The main SSL/TLS structure.
    pub(crate) struct Ssl;
    // Presumably the borrowed counterpart of `Ssl`; session accessors such as
    // `get_error`, `read` and `write` are implemented on this type.
    pub(crate) struct SslRef;
);
50
51 impl Ssl {
new(ctx: &SslContext) -> Result<Ssl, ErrorStack>52 pub(crate) fn new(ctx: &SslContext) -> Result<Ssl, ErrorStack> {
53 unsafe {
54 let ptr = check_ptr(SSL_new(ctx.as_ptr()))?;
55 Ok(Ssl::from_ptr(ptr))
56 }
57 }
58
59 /// Client connect to Server.
60 /// only `sync` use.
61 #[cfg(feature = "sync")]
connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,62 pub(crate) fn connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
63 where
64 S: Read + Write,
65 {
66 use super::MidHandshakeSslStream;
67 use crate::c_openssl::ffi::ssl::SSL_connect;
68
69 let mut stream = SslStream::new_base(self, stream, None)?;
70 let ret = unsafe { SSL_connect(stream.ssl.as_ptr()) };
71 if ret > 0 {
72 Ok(stream)
73 } else {
74 let error = stream.get_error(ret);
75 match error.code {
76 SslErrorCode::WANT_READ | SslErrorCode::WANT_WRITE => {
77 Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
78 _stream: stream,
79 error,
80 }))
81 }
82 _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
83 _stream: stream,
84 error,
85 })),
86 }
87 }
88 }
89 }
90
91 impl SslRef {
get_error(&self, err: c_int) -> SslErrorCode92 pub(crate) fn get_error(&self, err: c_int) -> SslErrorCode {
93 unsafe { SslErrorCode::from_int(SSL_get_error(self.as_ptr(), err)) }
94 }
95
ssl_status(&self) -> &'static str96 fn ssl_status(&self) -> &'static str {
97 let status = unsafe {
98 let ptr = SSL_state_string_long(self.as_ptr());
99 ffi::CStr::from_ptr(ptr as *const _)
100 };
101 str::from_utf8(status.to_bytes()).unwrap_or_default()
102 }
103
verify_result(&self) -> X509VerifyResult104 pub(crate) fn verify_result(&self) -> X509VerifyResult {
105 unsafe { X509VerifyResult::from_raw(SSL_get_verify_result(self.as_ptr()) as c_int) }
106 }
107
get_raw_bio(&self) -> *mut BIO108 pub(crate) fn get_raw_bio(&self) -> *mut BIO {
109 unsafe { SSL_get_rbio(self.as_ptr()) }
110 }
111
read(&mut self, buf: &[u8]) -> c_int112 pub(crate) fn read(&mut self, buf: &[u8]) -> c_int {
113 let len = cmp::min(c_int::MAX as usize, buf.len()) as c_int;
114 unsafe { SSL_read(self.as_ptr(), buf.as_ptr() as *mut c_void, len) }
115 }
116
write(&mut self, buf: &[u8]) -> c_int117 pub(crate) fn write(&mut self, buf: &[u8]) -> c_int {
118 let len = cmp::min(c_int::MAX as usize, buf.len()) as c_int;
119 unsafe { SSL_write(self.as_ptr(), buf.as_ptr() as *const c_void, len) }
120 }
121
set_host_name_in_sni(&mut self, name: &str) -> Result<(), ErrorStack>122 pub(crate) fn set_host_name_in_sni(&mut self, name: &str) -> Result<(), ErrorStack> {
123 let name = match CString::new(name) {
124 Ok(name) => name,
125 Err(_) => return Err(ErrorStack::get()),
126 };
127 check_ret(
128 unsafe { ssl_set_tlsext_host_name(self.as_ptr(), name.as_ptr() as *mut _) } as c_int,
129 )
130 .map(|_| ())
131 }
132
param_mut(&mut self) -> &mut X509VerifyParamRef133 pub(crate) fn param_mut(&mut self) -> &mut X509VerifyParamRef {
134 unsafe { X509VerifyParamRef::from_ptr_mut(SSL_get0_param(self.as_ptr())) }
135 }
136
set_verify_hostname(&mut self, host_name: &str) -> Result<(), ErrorStack>137 pub(crate) fn set_verify_hostname(&mut self, host_name: &str) -> Result<(), ErrorStack> {
138 let param = self.param_mut();
139 param.set_hostflags(X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
140 match host_name.parse() {
141 Ok(ip) => param.set_ip(ip),
142 Err(_) => param.set_host(host_name),
143 }
144 }
145
negotiated_alpn_protocol(&self) -> Option<&[u8]>146 pub(crate) fn negotiated_alpn_protocol(&self) -> Option<&[u8]> {
147 let mut data = null() as *const c_uchar;
148 let mut len = 0_u32;
149 unsafe {
150 SSL_get0_alpn_selected(self.as_ptr(), &mut data, &mut len);
151 if data.is_null() {
152 None
153 } else {
154 Some(from_raw_parts(data, len as usize))
155 }
156 }
157 }
158 }
159
160 impl fmt::Debug for SslRef {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 write!(
163 f,
164 "Ssl[state: {}, verify result: {}]",
165 &self.ssl_status(),
166 &self.verify_result()
167 )
168 }
169 }
170
171 const SSL_CTRL_SET_TLSEXT_HOSTNAME: c_int = 0x37;
172 const TLSEXT_NAMETYPE_HOST_NAME: c_int = 0x0;
173
ssl_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long174 unsafe fn ssl_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long {
175 SSL_ctrl(
176 s,
177 SSL_CTRL_SET_TLSEXT_HOSTNAME,
178 TLSEXT_NAMETYPE_HOST_NAME as c_long,
179 name as *mut c_void,
180 )
181 }
182