From: Ben Pfaff Date: Fri, 11 Jul 2025 17:21:47 +0000 (-0700) Subject: continue crypto work X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Frust;p=pspp continue crypto work --- diff --git a/rust/pspp/src/crypto/mod.rs b/rust/pspp/src/crypto/mod.rs index e0af37251c..c750c42c4f 100644 --- a/rust/pspp/src/crypto/mod.rs +++ b/rust/pspp/src/crypto/mod.rs @@ -15,7 +15,7 @@ use cmac::{Cmac, Mac}; use smallvec::SmallVec; use std::{ fmt::Debug, - io::{Error as IoError, ErrorKind, Read}, + io::{Error as IoError, ErrorKind, Read, Seek}, }; use thiserror::Error as ThisError; @@ -36,6 +36,10 @@ pub enum Error { #[error("Not an encrypted file")] NotEncrypted, + /// Encrypted file has invalid length. + #[error("Encrypted file has invalid length {0} (expected 4 more than a multiple of 16).")] + InvalidLength(u64), + /// Unknown file type. #[error("Unknown file type {0:?}.")] UnknownFileType(String), @@ -64,7 +68,17 @@ struct EncryptedHeader { pub struct EncryptedFile { reader: R, file_type: FileType, - cipher_text: [u8; 16], + + /// Length of the ciphertext (excluding the 36-byte header). + length: u64, + + /// First block of ciphertext, for verifying that any password the user + /// tries is correct. + first_block: [u8; 16], + + /// Last block of ciphertext, for checking padding and determining the + /// plaintext length. + last_block: [u8; 16], } /// Type of encrypted file. @@ -82,7 +96,7 @@ pub enum FileType { impl EncryptedFile where - R: Read, + R: Read + Seek, { /// Opens `reader` as an encrypted file. /// @@ -108,12 +122,20 @@ where )) } }; - let mut cipher_text = [0; 16]; - reader.read_exact(&mut cipher_text)?; + let mut first_block = [0; 16]; + reader.read_exact(&mut first_block)?; + let length = reader.seek(std::io::SeekFrom::End(-16))? + 16; + if length < 36 + 16 || (length - 36) % 16 != 0 { + return Err(Error::InvalidLength(length + 36)); + } + let mut last_block = [0; 16]; + reader.read_exact(&mut last_block)?; Ok(Self { reader, file_type, - cipher_text, + length, + first_block, + last_block, }) } @@ -186,7 +208,7 @@ where // Decrypt first block to verify password. let mut out = [0; 16]; aes.decrypt_block_b2b( - &GenericArray::from_slice(&self.cipher_text), + &GenericArray::from_slice(&self.first_block), GenericArray::from_mut_slice(&mut out), ); static MAGIC: &[&[u8]] = &[ @@ -198,11 +220,17 @@ where if !MAGIC.iter().any(|magic| out.starts_with(*magic)) { return Err(self); } + + // Decrypt last block to check padding and get final length. + let Some(padding_length) = parse_padding(&self.last_block) else { + return Err(self); + }; + Ok(EncryptedReader::new( self.reader, aes, - self.cipher_text, self.file_type, + length - 36 - padding_length, )) } @@ -212,6 +240,15 @@ where } } +fn parse_padding(block: &[u8; 16]) -> Option { + let pad = block[15] as usize; + if (1..=16).contains(&pad) && block[16 - pad..].iter().all(|b| *b == pad as u8) { + Some(pad) + } else { + None + } +} + impl Debug for EncryptedFile where R: Read, @@ -226,29 +263,42 @@ where /// This implements [Read](std::io::Read) for an SPSS encrypted file. Obtain by /// [EncryptedFile::new] followed by [EncryptedFile::unlock]. pub struct EncryptedReader { + /// Underlying reader. reader: R, - eof: bool, + + /// AES-256 decryption key. aes: Aes256Dec, - plain_text: [u8; 16], + + /// Type of file. + file_type: FileType, + + /// Plaintext file length (not including the file header or padding). + length: u64, + + /// Plaintext data buffer. + buffer: Box<[u8; 4096]>, + + /// Plaintext offset of the byte in `buffer[0]`. A multiple of 16 less than + /// or equal to `length`. + start: u64, + + /// Number of bytes in buffer (`0 <= head <= 4096`). head: usize, + + /// Offset in buffer of the next byte to read (`head <= tail`). tail: usize, - cipher_text: [u8; 16], - file_type: FileType, } impl EncryptedReader { - fn new(reader: R, aes: Aes256Dec, cipher_text: [u8; 16], file_type: FileType) -> Self { + fn new(reader: R, aes: Aes256Dec, file_type: FileType, length: u64) -> Self { Self { reader, - eof: false, aes, - plain_text: [0; 16], - head: 0, - tail: 0, - cipher_text, file_type, + length, } } + fn decrypt(&self, cipher_text: &[u8; 16]) -> [u8; 16] { let mut out = [0; 16]; self.aes.decrypt_block_b2b( @@ -260,7 +310,7 @@ impl EncryptedReader { fn read_buffer(&mut self, buf: &mut [u8]) -> Result { let n = buf.len().min(self.head - self.tail); - buf[..n].copy_from_slice(&self.plain_text[self.tail..n + self.tail]); + buf[..n].copy_from_slice(&self.buffer[self.tail..n + self.tail]); self.tail += n; Ok(n) } @@ -298,30 +348,30 @@ impl EncryptedReader { } } -fn parse_padding(block: &[u8; 16]) -> std::io::Result { - let pad = block[15] as usize; - if (1..=16).contains(&pad) && block[16 - pad..].iter().all(|b| *b == pad as u8) { - Ok(pad) - } else { - Err(IoError::other(Error::InvalidPadding)) - } -} - impl Read for EncryptedReader where R: Read, { fn read(&mut self, buf: &mut [u8]) -> Result { if self.head != self.tail { - return self.read_buffer(buf); - } else if self.eof { + self.read_buffer(buf) + } else if self.start + self.head as u64 >= self.length { Ok(0) } else { - let retval = self.read_inner(buf); - if let Ok(0) | Err(_) = retval { - self.eof = true - }; - retval + self.start += self.head as u64; + self.head = 0; + self.tail = 0; + let n = self + .buffer + .len() + .min((self.length - self.start).next_multiple_of(16) as usize); + self.reader.read_exact(&mut self.buffer[..n])?; + for offset in (0..n).step_by(16) { + self.aes.decrypt_block(GenericArray::from_mut_slice( + &mut self.buffer[offset..offset + 16], + )); + } + self.tail = } } }