1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use core::fmt;
15 use core::marker::PhantomData;
16 use core::mem::ManuallyDrop;
17 use std::io::{self, Read, Write};
18 use std::panic::resume_unwind;
19 use std::ptr;
20 
21 use libc::c_int;
22 
23 use super::{InternalError, Ssl, SslError, SslErrorCode, SslRef};
24 use crate::c_openssl::bio::{self, get_error, get_panic, get_stream_mut, get_stream_ref};
25 use crate::c_openssl::error::ErrorStack;
26 use crate::c_openssl::ffi::ssl::{SSL_connect, SSL_set_bio, SSL_shutdown};
27 use crate::c_openssl::foreign::Foreign;
28 use crate::util::base64::encode;
29 use crate::util::c_openssl::bio::BioMethod;
30 use crate::util::c_openssl::error::VerifyError;
31 use crate::util::c_openssl::error::VerifyKind::PubKeyPinning;
32 use crate::util::c_openssl::ffi::ssl::SSL;
33 use crate::util::c_openssl::ffi::x509::{i2d_X509_PUBKEY, X509_free, X509_get_X509_PUBKEY, C_X509};
34 use crate::util::c_openssl::verify::sha256_digest;
35 
36 /// A TLS session over a stream.
37 pub struct SslStream<S> {
38     pub(crate) ssl: ManuallyDrop<Ssl>,
39     method: ManuallyDrop<BioMethod>,
40     pinned_pubkey: Option<String>,
41     p: PhantomData<S>,
42 }
43 
44 impl<S> SslStream<S> {
get_error(&mut self, err: c_int) -> SslError45     pub(crate) fn get_error(&mut self, err: c_int) -> SslError {
46         self.check_panic();
47         let code = self.ssl.get_error(err);
48         let internal = match code {
49             SslErrorCode::SSL => {
50                 let e = ErrorStack::get();
51                 Some(InternalError::Ssl(e))
52             }
53             SslErrorCode::SYSCALL => {
54                 let error = ErrorStack::get();
55                 if error.errors().is_empty() {
56                     self.get_bio_error().map(InternalError::Io)
57                 } else {
58                     Some(InternalError::Ssl(error))
59                 }
60             }
61             SslErrorCode::WANT_WRITE | SslErrorCode::WANT_READ => {
62                 self.get_bio_error().map(InternalError::Io)
63             }
64             _ => None,
65         };
66         SslError { code, internal }
67     }
68 
check_panic(&mut self)69     fn check_panic(&mut self) {
70         if let Some(err) = unsafe { get_panic::<S>(self.ssl.get_raw_bio()) } {
71             resume_unwind(err)
72         }
73     }
74 
get_bio_error(&mut self) -> Option<io::Error>75     fn get_bio_error(&mut self) -> Option<io::Error> {
76         unsafe { get_error::<S>(self.ssl.get_raw_bio()) }
77     }
78 
get_ref(&self) -> &S79     pub(crate) fn get_ref(&self) -> &S {
80         unsafe {
81             let bio = self.ssl.get_raw_bio();
82             get_stream_ref(bio)
83         }
84     }
85 
get_mut(&mut self) -> &mut S86     pub(crate) fn get_mut(&mut self) -> &mut S {
87         unsafe {
88             let bio = self.ssl.get_raw_bio();
89             get_stream_mut(bio)
90         }
91     }
92 
ssl(&self) -> &SslRef93     pub(crate) fn ssl(&self) -> &SslRef {
94         &self.ssl
95     }
96 }
97 
98 impl<S> fmt::Debug for SslStream<S>
99 where
100     S: fmt::Debug,
101 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result102     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103         write!(f, "stream[{:?}], {:?}", &self.get_ref(), &self.ssl())
104     }
105 }
106 
107 impl<S> Drop for SslStream<S> {
drop(&mut self)108     fn drop(&mut self) {
109         unsafe {
110             ManuallyDrop::drop(&mut self.ssl);
111             ManuallyDrop::drop(&mut self.method);
112         }
113     }
114 }
115 
116 impl<S: Read + Write> SslStream<S> {
ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError>117     pub(crate) fn ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError> {
118         if buf.is_empty() {
119             return Ok(0);
120         }
121         let ret = self.ssl.read(buf);
122         if ret > 0 {
123             Ok(ret as usize)
124         } else {
125             Err(self.get_error(ret))
126         }
127     }
128 
ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError>129     pub(crate) fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError> {
130         if buf.is_empty() {
131             return Ok(0);
132         }
133         let ret = self.ssl.write(buf);
134         if ret > 0 {
135             Ok(ret as usize)
136         } else {
137             Err(self.get_error(ret))
138         }
139     }
140 
new_base( ssl: Ssl, stream: S, pinned_pubkey: Option<String>, ) -> Result<Self, ErrorStack>141     pub(crate) fn new_base(
142         ssl: Ssl,
143         stream: S,
144         pinned_pubkey: Option<String>,
145     ) -> Result<Self, ErrorStack> {
146         unsafe {
147             let (bio, method) = bio::new(stream)?;
148             SSL_set_bio(ssl.as_ptr(), bio, bio);
149 
150             Ok(SslStream {
151                 ssl: ManuallyDrop::new(ssl),
152                 method: ManuallyDrop::new(method),
153                 pinned_pubkey,
154                 p: PhantomData,
155             })
156         }
157     }
158 
connect(&mut self) -> Result<(), SslError>159     pub(crate) fn connect(&mut self) -> Result<(), SslError> {
160         let ret = unsafe { SSL_connect(self.ssl.as_ptr()) };
161         if ret > 0 {
162             match &self.pinned_pubkey {
163                 None => {}
164                 Some(key) => {
165                     verify_server_cert(self.ssl.as_ptr(), key.as_str())?;
166                 }
167             }
168             Ok(())
169         } else {
170             Err(self.get_error(ret))
171         }
172     }
173 
shutdown(&mut self) -> Result<ShutdownResult, SslError>174     pub(crate) fn shutdown(&mut self) -> Result<ShutdownResult, SslError> {
175         unsafe {
176             match SSL_shutdown(self.ssl.as_ptr()) {
177                 0 => Ok(ShutdownResult::Sent),
178                 1 => Ok(ShutdownResult::Received),
179                 n => Err(self.get_error(n)),
180             }
181         }
182     }
183 }
184 
185 impl<S: Read + Write> Read for SslStream<S> {
186     // ssl_read
read(&mut self, buf: &mut [u8]) -> io::Result<usize>187     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
188         loop {
189             match self.ssl_read(buf) {
190                 Ok(n) => return Ok(n),
191                 // The TLS/SSL peer has closed the connection for writing by sending
192                 // the close_notify alert. No more data can be read.
193                 // Does not necessarily indicate that the underlying transport has been closed.
194                 Err(ref e) if e.code == SslErrorCode::ZERO_RETURN => return Ok(0),
195                 // A non-recoverable, fatal error in the SSL library occurred, usually a protocol
196                 // error.
197                 Err(ref e) if e.code == SslErrorCode::SYSCALL && e.get_io_error().is_none() => {
198                     return Ok(0)
199                 }
200                 // When the last operation was a read operation from a nonblocking BIO.
201                 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
202                 // Other error.
203                 Err(err) => {
204                     return Err(err
205                         .into_io_error()
206                         .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)))
207                 }
208             };
209         }
210     }
211 }
212 
213 impl<S: Read + Write> Write for SslStream<S> {
214     // ssl_write
write(&mut self, buf: &[u8]) -> io::Result<usize>215     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
216         loop {
217             match self.ssl_write(buf) {
218                 Ok(n) => return Ok(n),
219                 // When the last operation was a read operation from a nonblocking BIO.
220                 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
221                 Err(err) => {
222                     return Err(err
223                         .into_io_error()
224                         .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)));
225                 }
226             }
227         }
228     }
229 
230     // S.flush()
flush(&mut self) -> io::Result<()>231     fn flush(&mut self) -> io::Result<()> {
232         self.get_mut().flush()
233     }
234 }
235 
236 /// An SSL stream midway through the handshake process.
237 #[derive(Debug)]
238 pub(crate) struct MidHandshakeSslStream<S> {
239     pub(crate) _stream: SslStream<S>,
240     pub(crate) error: SslError,
241 }
242 
243 impl<S> MidHandshakeSslStream<S> {
error(&self) -> &SslError244     pub(crate) fn error(&self) -> &SslError {
245         &self.error
246     }
247 }
248 
/// Successful outcomes of `SslStream::shutdown`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ShutdownResult {
    /// `SSL_shutdown` returned `0`.
    Sent,
    /// `SSL_shutdown` returned `1`.
    Received,
}
254 
255 // TODO The SSLError thrown here is meaningless and has no information.
verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError>256 fn verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> {
257     #[cfg(feature = "c_openssl_3_0")]
258     use crate::util::c_openssl::ffi::ssl::SSL_get1_peer_certificate;
259     #[cfg(feature = "c_openssl_1_1")]
260     use crate::util::c_openssl::ffi::ssl::SSL_get_peer_certificate;
261 
262     let certificate = unsafe {
263         #[cfg(feature = "c_openssl_3_0")]
264         {
265             SSL_get1_peer_certificate(ssl)
266         }
267         #[cfg(feature = "c_openssl_1_1")]
268         {
269             SSL_get_peer_certificate(ssl)
270         }
271     };
272     if certificate.is_null() {
273         return Err(SslError {
274             code: SslErrorCode::SSL,
275             internal: Some(InternalError::Ssl(ErrorStack::get())),
276         });
277     }
278 
279     let size_1 = unsafe { i2d_X509_PUBKEY(X509_get_X509_PUBKEY(certificate), ptr::null_mut()) };
280     if size_1 < 1 {
281         unsafe { X509_free(certificate) };
282         return Err(SslError {
283             code: SslErrorCode::SSL,
284             internal: Some(InternalError::Ssl(ErrorStack::get())),
285         });
286     }
287     let key = vec![0u8; size_1 as usize];
288     let size_2 = unsafe { i2d_X509_PUBKEY(X509_get_X509_PUBKEY(certificate), &mut key.as_ptr()) };
289 
290     if size_1 != size_2 || size_2 <= 0 {
291         unsafe { X509_free(certificate) };
292         return Err(SslError {
293             code: SslErrorCode::SSL,
294             internal: Some(InternalError::Ssl(ErrorStack::get())),
295         });
296     }
297 
298     // sha256 length.
299     let mut digest = [0u8; 32];
300     unsafe { sha256_digest(key.as_slice(), size_2, &mut digest)? }
301 
302     compare_pinned_digest(&digest, pinned_key.as_bytes(), certificate)
303 }
304 
compare_pinned_digest( digest: &[u8], pinned_key: &[u8], certificate: *mut C_X509, ) -> Result<(), SslError>305 fn compare_pinned_digest(
306     digest: &[u8],
307     pinned_key: &[u8],
308     certificate: *mut C_X509,
309 ) -> Result<(), SslError> {
310     let base64_digest = encode(digest);
311     let mut user_bytes = pinned_key;
312 
313     let mut begin;
314     let mut end;
315     let prefix = b"sha256//";
316     let suffix = b";sha256//";
317     while !user_bytes.is_empty() {
318         begin = match user_bytes
319             .windows(prefix.len())
320             .position(|window| window == prefix)
321         {
322             None => {
323                 break;
324             }
325             Some(index) => index + 8,
326         };
327         end = match user_bytes
328             .windows(suffix.len())
329             .position(|window| window == suffix)
330         {
331             None => user_bytes.len(),
332             Some(index) => index,
333         };
334 
335         let bytes = &user_bytes[begin..end];
336         if bytes.eq(base64_digest.as_slice()) {
337             unsafe { X509_free(certificate) };
338             return Ok(());
339         }
340 
341         if end != user_bytes.len() {
342             user_bytes = &user_bytes[end + 1..];
343         } else {
344             user_bytes = &user_bytes[end..];
345         }
346     }
347 
348     unsafe { X509_free(certificate) };
349     Err(SslError {
350         code: SslErrorCode::SSL,
351         internal: Some(InternalError::User(VerifyError::from_msg(
352             PubKeyPinning,
353             "Pinned public key verification failed.",
354         ))),
355     })
356 }
357