1
use std::{
2
    io::ErrorKind,
3
    ops::Deref,
4
    path::{Path, PathBuf},
5
};
6

            
7
use crate::Vault;
8

            
9
// TODO this should be shared between nebari and bonsaidb-core.
10

            
11
/// A path under the system temporary directory reserved for a test.
///
/// The wrapped path is `std::env::temp_dir().join(name)`; the directory at that
/// path is removed when this value is dropped.
#[derive(Debug)]
pub struct TestDirectory(pub PathBuf);

impl TestDirectory {
    /// Returns a `TestDirectory` for `temp_dir()/name`, first deleting anything
    /// left behind at that path (e.g. by an earlier run).
    ///
    /// This does not create the directory itself.
    ///
    /// # Panics
    ///
    /// Panics if a leftover entry exists at the path but cannot be removed.
    pub fn new<S: AsRef<Path>>(name: S) -> Self {
        let path = std::env::temp_dir().join(name);
        // Attempt the removal directly rather than checking `exists()` first:
        // this avoids a check-then-act race, and a missing directory is the
        // normal case rather than an error (same policy as `Drop`).
        match std::fs::remove_dir_all(&path) {
            Ok(()) => {}
            Err(err) if err.kind() == ErrorKind::NotFound => {}
            Err(err) => panic!("error clearing temporary directory: {:?}", err),
        }
        Self(path)
    }
}
22

            
23
impl Drop for TestDirectory {
24
    fn drop(&mut self) {
25
57
        if let Err(err) = std::fs::remove_dir_all(&self.0) {
26
            if err.kind() != ErrorKind::NotFound {
27
                eprintln!("Failed to clean up temporary folder: {:?}", err);
28
            }
29
57
        }
30
57
    }
31
}
32

            
33
impl AsRef<Path> for TestDirectory {
34
57
    fn as_ref(&self) -> &Path {
35
57
        &self.0
36
57
    }
37
}
38

            
39
impl Deref for TestDirectory {
40
    type Target = PathBuf;
41

            
42
57
    fn deref(&self) -> &Self::Target {
43
57
        &self.0
44
57
    }
45
}
46

            
47
/// A toy `Vault` whose "encryption" shifts every byte by a fixed amount.
///
/// It is trivially reversible and provides no real confidentiality.
#[derive(Debug)]
pub struct RotatorVault {
    /// Amount added to each byte (wrapping) when encrypting.
    rotation_amount: u8,
}

impl RotatorVault {
    /// Builds a vault that rotates bytes by `rotation_amount`.
    ///
    /// This is a `const fn`, so it can be used to initialise constants and statics.
    pub const fn new(rotation_amount: u8) -> Self {
        RotatorVault { rotation_amount }
    }
}
57

            
58
impl Vault for RotatorVault {
59
    type Error = NotEncrypted;
60
12008
    fn encrypt(&self, payload: &[u8]) -> Result<Vec<u8>, NotEncrypted> {
61
12008
        let mut output = Vec::with_capacity(payload.len() + 4);
62
12008
        output.extend(b"rotv");
63
5236986
        output.extend(payload.iter().map(|c| c.wrapping_add(self.rotation_amount)));
64
12008
        Ok(output)
65
12008
    }
66

            
67
75786
    fn decrypt(&self, payload: &[u8]) -> Result<Vec<u8>, NotEncrypted> {
68
75786
        if payload.len() < 4 {
69
            return Err(NotEncrypted);
70
75786
        }
71
75786
        let (header, payload) = payload.split_at(4);
72
75786
        if header != b"rotv" {
73
4
            return Err(NotEncrypted);
74
75782
        }
75
75782

            
76
75782
        Ok(payload
77
75782
            .iter()
78
347250568
            .map(|c| c.wrapping_sub(self.rotation_amount))
79
75782
            .collect())
80
75786
    }
81
}
82

            
83
#[derive(thiserror::Error, Debug)]
84
#[error("not an encrypted payload")]
85
pub struct NotEncrypted;