1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use core::fmt;
15 use core::marker::PhantomData;
16 use core::mem::ManuallyDrop;
17 use std::io::{self, Read, Write};
18 use std::panic::resume_unwind;
19 use std::ptr;
20
21 use libc::c_int;
22
23 use super::{InternalError, Ssl, SslError, SslErrorCode, SslRef};
24 use crate::c_openssl::bio::{self, get_error, get_panic, get_stream_mut, get_stream_ref};
25 use crate::c_openssl::error::ErrorStack;
26 use crate::c_openssl::ffi::ssl::{SSL_connect, SSL_set_bio, SSL_shutdown};
27 use crate::c_openssl::foreign::Foreign;
28 use crate::util::base64::encode;
29 use crate::util::c_openssl::bio::BioMethod;
30 use crate::util::c_openssl::error::VerifyError;
31 use crate::util::c_openssl::error::VerifyKind::PubKeyPinning;
32 use crate::util::c_openssl::ffi::ssl::SSL;
33 use crate::util::c_openssl::ffi::x509::{i2d_X509_PUBKEY, X509_free, X509_get_X509_PUBKEY, C_X509};
34 use crate::util::c_openssl::verify::sha256_digest;
35
36 /// A TLS session over a stream.
37 pub struct SslStream<S> {
38 pub(crate) ssl: ManuallyDrop<Ssl>,
39 method: ManuallyDrop<BioMethod>,
40 pinned_pubkey: Option<String>,
41 p: PhantomData<S>,
42 }
43
44 impl<S> SslStream<S> {
get_error(&mut self, err: c_int) -> SslError45 pub(crate) fn get_error(&mut self, err: c_int) -> SslError {
46 self.check_panic();
47 let code = self.ssl.get_error(err);
48 let internal = match code {
49 SslErrorCode::SSL => {
50 let e = ErrorStack::get();
51 Some(InternalError::Ssl(e))
52 }
53 SslErrorCode::SYSCALL => {
54 let error = ErrorStack::get();
55 if error.errors().is_empty() {
56 self.get_bio_error().map(InternalError::Io)
57 } else {
58 Some(InternalError::Ssl(error))
59 }
60 }
61 SslErrorCode::WANT_WRITE | SslErrorCode::WANT_READ => {
62 self.get_bio_error().map(InternalError::Io)
63 }
64 _ => None,
65 };
66 SslError { code, internal }
67 }
68
check_panic(&mut self)69 fn check_panic(&mut self) {
70 if let Some(err) = unsafe { get_panic::<S>(self.ssl.get_raw_bio()) } {
71 resume_unwind(err)
72 }
73 }
74
get_bio_error(&mut self) -> Option<io::Error>75 fn get_bio_error(&mut self) -> Option<io::Error> {
76 unsafe { get_error::<S>(self.ssl.get_raw_bio()) }
77 }
78
get_ref(&self) -> &S79 pub(crate) fn get_ref(&self) -> &S {
80 unsafe {
81 let bio = self.ssl.get_raw_bio();
82 get_stream_ref(bio)
83 }
84 }
85
get_mut(&mut self) -> &mut S86 pub(crate) fn get_mut(&mut self) -> &mut S {
87 unsafe {
88 let bio = self.ssl.get_raw_bio();
89 get_stream_mut(bio)
90 }
91 }
92
ssl(&self) -> &SslRef93 pub(crate) fn ssl(&self) -> &SslRef {
94 &self.ssl
95 }
96 }
97
98 impl<S> fmt::Debug for SslStream<S>
99 where
100 S: fmt::Debug,
101 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 write!(f, "stream[{:?}], {:?}", &self.get_ref(), &self.ssl())
104 }
105 }
106
107 impl<S> Drop for SslStream<S> {
drop(&mut self)108 fn drop(&mut self) {
109 unsafe {
110 ManuallyDrop::drop(&mut self.ssl);
111 ManuallyDrop::drop(&mut self.method);
112 }
113 }
114 }
115
116 impl<S: Read + Write> SslStream<S> {
ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError>117 pub(crate) fn ssl_read(&mut self, buf: &[u8]) -> Result<usize, SslError> {
118 if buf.is_empty() {
119 return Ok(0);
120 }
121 let ret = self.ssl.read(buf);
122 if ret > 0 {
123 Ok(ret as usize)
124 } else {
125 Err(self.get_error(ret))
126 }
127 }
128
ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError>129 pub(crate) fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, SslError> {
130 if buf.is_empty() {
131 return Ok(0);
132 }
133 let ret = self.ssl.write(buf);
134 if ret > 0 {
135 Ok(ret as usize)
136 } else {
137 Err(self.get_error(ret))
138 }
139 }
140
new_base( ssl: Ssl, stream: S, pinned_pubkey: Option<String>, ) -> Result<Self, ErrorStack>141 pub(crate) fn new_base(
142 ssl: Ssl,
143 stream: S,
144 pinned_pubkey: Option<String>,
145 ) -> Result<Self, ErrorStack> {
146 unsafe {
147 let (bio, method) = bio::new(stream)?;
148 SSL_set_bio(ssl.as_ptr(), bio, bio);
149
150 Ok(SslStream {
151 ssl: ManuallyDrop::new(ssl),
152 method: ManuallyDrop::new(method),
153 pinned_pubkey,
154 p: PhantomData,
155 })
156 }
157 }
158
connect(&mut self) -> Result<(), SslError>159 pub(crate) fn connect(&mut self) -> Result<(), SslError> {
160 let ret = unsafe { SSL_connect(self.ssl.as_ptr()) };
161 if ret > 0 {
162 match &self.pinned_pubkey {
163 None => {}
164 Some(key) => {
165 verify_server_cert(self.ssl.as_ptr(), key.as_str())?;
166 }
167 }
168 Ok(())
169 } else {
170 Err(self.get_error(ret))
171 }
172 }
173
shutdown(&mut self) -> Result<ShutdownResult, SslError>174 pub(crate) fn shutdown(&mut self) -> Result<ShutdownResult, SslError> {
175 unsafe {
176 match SSL_shutdown(self.ssl.as_ptr()) {
177 0 => Ok(ShutdownResult::Sent),
178 1 => Ok(ShutdownResult::Received),
179 n => Err(self.get_error(n)),
180 }
181 }
182 }
183 }
184
185 impl<S: Read + Write> Read for SslStream<S> {
186 // ssl_read
read(&mut self, buf: &mut [u8]) -> io::Result<usize>187 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
188 loop {
189 match self.ssl_read(buf) {
190 Ok(n) => return Ok(n),
191 // The TLS/SSL peer has closed the connection for writing by sending
192 // the close_notify alert. No more data can be read.
193 // Does not necessarily indicate that the underlying transport has been closed.
194 Err(ref e) if e.code == SslErrorCode::ZERO_RETURN => return Ok(0),
195 // A non-recoverable, fatal error in the SSL library occurred, usually a protocol
196 // error.
197 Err(ref e) if e.code == SslErrorCode::SYSCALL && e.get_io_error().is_none() => {
198 return Ok(0)
199 }
200 // When the last operation was a read operation from a nonblocking BIO.
201 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
202 // Other error.
203 Err(err) => {
204 return Err(err
205 .into_io_error()
206 .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)))
207 }
208 };
209 }
210 }
211 }
212
213 impl<S: Read + Write> Write for SslStream<S> {
214 // ssl_write
write(&mut self, buf: &[u8]) -> io::Result<usize>215 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
216 loop {
217 match self.ssl_write(buf) {
218 Ok(n) => return Ok(n),
219 // When the last operation was a read operation from a nonblocking BIO.
220 Err(ref e) if e.code == SslErrorCode::WANT_READ && e.get_io_error().is_none() => {}
221 Err(err) => {
222 return Err(err
223 .into_io_error()
224 .unwrap_or_else(|err| io::Error::new(io::ErrorKind::Other, err)));
225 }
226 }
227 }
228 }
229
230 // S.flush()
flush(&mut self) -> io::Result<()>231 fn flush(&mut self) -> io::Result<()> {
232 self.get_mut().flush()
233 }
234 }
235
236 /// An SSL stream midway through the handshake process.
237 #[derive(Debug)]
238 pub(crate) struct MidHandshakeSslStream<S> {
239 pub(crate) _stream: SslStream<S>,
240 pub(crate) error: SslError,
241 }
242
243 impl<S> MidHandshakeSslStream<S> {
error(&self) -> &SslError244 pub(crate) fn error(&self) -> &SslError {
245 &self.error
246 }
247 }
248
/// Successful outcome of `SslStream::shutdown`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ShutdownResult {
    /// `SSL_shutdown` returned 0: our `close_notify` has been sent.
    Sent,
    /// `SSL_shutdown` returned 1: the peer's `close_notify` has been received as well.
    Received,
}
254
255 // TODO The SSLError thrown here is meaningless and has no information.
verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError>256 fn verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> {
257 #[cfg(feature = "c_openssl_3_0")]
258 use crate::util::c_openssl::ffi::ssl::SSL_get1_peer_certificate;
259 #[cfg(feature = "c_openssl_1_1")]
260 use crate::util::c_openssl::ffi::ssl::SSL_get_peer_certificate;
261
262 let certificate = unsafe {
263 #[cfg(feature = "c_openssl_3_0")]
264 {
265 SSL_get1_peer_certificate(ssl)
266 }
267 #[cfg(feature = "c_openssl_1_1")]
268 {
269 SSL_get_peer_certificate(ssl)
270 }
271 };
272 if certificate.is_null() {
273 return Err(SslError {
274 code: SslErrorCode::SSL,
275 internal: Some(InternalError::Ssl(ErrorStack::get())),
276 });
277 }
278
279 let size_1 = unsafe { i2d_X509_PUBKEY(X509_get_X509_PUBKEY(certificate), ptr::null_mut()) };
280 if size_1 < 1 {
281 unsafe { X509_free(certificate) };
282 return Err(SslError {
283 code: SslErrorCode::SSL,
284 internal: Some(InternalError::Ssl(ErrorStack::get())),
285 });
286 }
287 let key = vec![0u8; size_1 as usize];
288 let size_2 = unsafe { i2d_X509_PUBKEY(X509_get_X509_PUBKEY(certificate), &mut key.as_ptr()) };
289
290 if size_1 != size_2 || size_2 <= 0 {
291 unsafe { X509_free(certificate) };
292 return Err(SslError {
293 code: SslErrorCode::SSL,
294 internal: Some(InternalError::Ssl(ErrorStack::get())),
295 });
296 }
297
298 // sha256 length.
299 let mut digest = [0u8; 32];
300 unsafe { sha256_digest(key.as_slice(), size_2, &mut digest)? }
301
302 compare_pinned_digest(&digest, pinned_key.as_bytes(), certificate)
303 }
304
compare_pinned_digest( digest: &[u8], pinned_key: &[u8], certificate: *mut C_X509, ) -> Result<(), SslError>305 fn compare_pinned_digest(
306 digest: &[u8],
307 pinned_key: &[u8],
308 certificate: *mut C_X509,
309 ) -> Result<(), SslError> {
310 let base64_digest = encode(digest);
311 let mut user_bytes = pinned_key;
312
313 let mut begin;
314 let mut end;
315 let prefix = b"sha256//";
316 let suffix = b";sha256//";
317 while !user_bytes.is_empty() {
318 begin = match user_bytes
319 .windows(prefix.len())
320 .position(|window| window == prefix)
321 {
322 None => {
323 break;
324 }
325 Some(index) => index + 8,
326 };
327 end = match user_bytes
328 .windows(suffix.len())
329 .position(|window| window == suffix)
330 {
331 None => user_bytes.len(),
332 Some(index) => index,
333 };
334
335 let bytes = &user_bytes[begin..end];
336 if bytes.eq(base64_digest.as_slice()) {
337 unsafe { X509_free(certificate) };
338 return Ok(());
339 }
340
341 if end != user_bytes.len() {
342 user_bytes = &user_bytes[end + 1..];
343 } else {
344 user_bytes = &user_bytes[end..];
345 }
346 }
347
348 unsafe { X509_free(certificate) };
349 Err(SslError {
350 code: SslErrorCode::SSL,
351 internal: Some(InternalError::User(VerifyError::from_msg(
352 PubKeyPinning,
353 "Pinned public key verification failed.",
354 ))),
355 })
356 }
357