1
use std::{
2
    collections::HashMap,
3
    fmt::{Debug, Display},
4
};
5

            
6
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
7

            
8
use super::{btree::BTreeEntry, read_chunk, BinarySerialization, PagedWriter};
9
use crate::{
10
    chunk_cache::CacheEntry,
11
    error::Error,
12
    io::File,
13
    tree::{btree::NodeInclusion, key_entry::PositionIndex},
14
    vault::AnyVault,
15
    AbortError, ArcBytes, ChunkCache, ErrorKind,
16
};
17

            
18
/// An interior B-Tree node. Does not contain values directly, and instead
/// points to a node located on-disk elsewhere.
#[derive(Clone, Debug)]
pub struct Interior<Index, ReducedIndex> {
    /// The key with the highest sort value within the pointed-at node.
    pub key: ArcBytes<'static>,
    /// The location of the node: either an on-disk position or an
    /// already-loaded entry.
    pub position: Pointer<Index, ReducedIndex>,
    /// The reduced statistics for the pointed-at node.
    pub stats: ReducedIndex,
}
29

            
30
/// A pointer to a location on-disk. May also contain the node already loaded.
///
/// [`Pointer::load`] upgrades an `OnDisk` pointer into `Loaded` in place.
#[derive(Clone, Debug)]
pub enum Pointer<Index, ReducedIndex> {
    /// The position on-disk of the node.
    OnDisk(u64),
    /// An in-memory node that may have previously been saved on-disk.
    Loaded {
        /// The position on-disk of the node, if it was previously saved.
        previous_location: Option<u64>,
        /// The loaded B-Tree entry.
        entry: Box<BTreeEntry<Index, ReducedIndex>>,
    },
}
43

            
44
impl<
45
        Index: BinarySerialization + Debug + Clone + 'static,
46
        ReducedIndex: BinarySerialization + Debug + Clone + 'static,
47
    > Pointer<Index, ReducedIndex>
48
{
49
    /// Attempts to load the node from disk. If the node is already loaded, this
50
    /// function does nothing.
51
    #[allow(clippy::missing_panics_doc)] // Currently the only panic is if the types don't match, which shouldn't happen due to these nodes always being accessed through a root.
52
1700935
    pub fn load(
53
1700935
        &mut self,
54
1700935
        file: &mut dyn File,
55
1700935
        validate_crc: bool,
56
1700935
        vault: Option<&dyn AnyVault>,
57
1700935
        cache: Option<&ChunkCache>,
58
1700935
        current_order: Option<usize>,
59
1700935
    ) -> Result<(), Error> {
60
1700935
        match self {
61
999201
            Self::OnDisk(position) => {
62
999201
                let entry = match read_chunk(*position, validate_crc, file, vault, cache)? {
63
929419
                    CacheEntry::ArcBytes(mut buffer) => {
64
929419
                        // It's worthless to store this node in the cache
65
929419
                        // because if we mutate, we'll be rewritten.
66
929419
                        Box::new(BTreeEntry::deserialize_from(&mut buffer, current_order)?)
67
                    }
68
69782
                    CacheEntry::Decoded(node) => node
69
69782
                        .as_ref()
70
69782
                        .as_any()
71
69782
                        .downcast_ref::<Box<BTreeEntry<Index, ReducedIndex>>>()
72
69782
                        .unwrap()
73
69782
                        .clone(),
74
                };
75
999201
                *self = Self::Loaded {
76
999201
                    entry,
77
999201
                    previous_location: Some(*position),
78
999201
                };
79
            }
80
701734
            Self::Loaded { .. } => {}
81
        }
82
1700935
        Ok(())
83
1700935
    }
84

            
85
    /// Returns the previously-[`load()`ed](Self::load) entry.
86
161144
    pub fn get(&mut self) -> Option<&BTreeEntry<Index, ReducedIndex>> {
87
161144
        match self {
88
            Self::OnDisk(_) => None,
89
161144
            Self::Loaded { entry, .. } => Some(entry),
90
        }
91
161144
    }
92

            
93
    /// Returns the previously-[`load()`ed](Self::load) entry as a mutable reference.
94
2653385
    pub fn get_mut(&mut self) -> Option<&mut BTreeEntry<Index, ReducedIndex>> {
95
2653385
        match self {
96
            Self::OnDisk(_) => None,
97
2653385
            Self::Loaded { entry, .. } => Some(entry.as_mut()),
98
        }
99
2653385
    }
100

            
101
    /// Returns the position on-disk of the node being pointed at, if the node
102
    /// has been saved before.
103
    #[must_use]
104
325324
    pub fn position(&self) -> Option<u64> {
105
325324
        match self {
106
            Self::OnDisk(location) => Some(*location),
107
            Self::Loaded {
108
325324
                previous_location, ..
109
325324
            } => *previous_location,
110
        }
111
325324
    }
112

            
113
    /// Loads the pointed at node, if necessary, and invokes `callback` with the
114
    /// loaded node. This is useful in situations where the node isn't needed to
115
    /// be accessed mutably.
116
    #[allow(clippy::missing_panics_doc)]
117
1089104
    pub fn map_loaded_entry<
118
1089104
        Output,
119
1089104
        CallerError: Display + Debug,
120
1089104
        Cb: FnOnce(
121
1089104
            &BTreeEntry<Index, ReducedIndex>,
122
1089104
            &mut dyn File,
123
1089104
        ) -> Result<Output, AbortError<CallerError>>,
124
1089104
    >(
125
1089104
        &self,
126
1089104
        file: &mut dyn File,
127
1089104
        vault: Option<&dyn AnyVault>,
128
1089104
        cache: Option<&ChunkCache>,
129
1089104
        current_order: Option<usize>,
130
1089104
        callback: Cb,
131
1089104
    ) -> Result<Output, AbortError<CallerError>> {
132
1089104
        match self {
133
1089104
            Self::OnDisk(position) => match read_chunk(*position, false, file, vault, cache)? {
134
992584
                CacheEntry::ArcBytes(mut buffer) => {
135
992584
                    let decoded = BTreeEntry::deserialize_from(&mut buffer, current_order)?;
136

            
137
992584
                    let result = callback(&decoded, file);
138
992584
                    if let (Some(cache), Some(file_id)) = (cache, file.id().id()) {
139
                        cache.replace_with_decoded(file_id, *position, Box::new(decoded));
140
992584
                    }
141
992584
                    result
142
                }
143
96520
                CacheEntry::Decoded(value) => {
144
96520
                    let entry = value
145
96520
                        .as_ref()
146
96520
                        .as_any()
147
96520
                        .downcast_ref::<Box<BTreeEntry<Index, ReducedIndex>>>()
148
96520
                        .unwrap();
149
96520
                    callback(entry, file)
150
                }
151
            },
152
            Self::Loaded { entry, .. } => callback(entry, file),
153
        }
154
1089104
    }
155
}
156

            
157
impl<
158
        Index: Clone + PositionIndex + BinarySerialization + Debug + 'static,
159
        ReducedIndex: Clone + BinarySerialization + Debug + 'static,
160
    > Interior<Index, ReducedIndex>
161
{
162
    /// Returns a new instance
163
28545
    pub fn new<Reducer: super::Reducer<Index, ReducedIndex>>(
164
28545
        entry: BTreeEntry<Index, ReducedIndex>,
165
28545
        reducer: &Reducer,
166
28545
    ) -> Self {
167
28545
        let key = entry.max_key().clone();
168
28545

            
169
28545
        Self {
170
28545
            key,
171
28545
            stats: entry.stats(reducer),
172
28545
            position: Pointer::Loaded {
173
28545
                previous_location: None,
174
28545
                entry: Box::new(entry),
175
28545
            },
176
28545
        }
177
28545
    }
178

            
179
    #[allow(clippy::too_many_arguments)]
180
    pub(crate) fn copy_data_to<Callback>(
181
        &mut self,
182
        include_nodes: NodeInclusion,
183
        file: &mut dyn File,
184
        copied_chunks: &mut HashMap<u64, u64>,
185
        writer: &mut PagedWriter<'_>,
186
        vault: Option<&dyn AnyVault>,
187
        scratch: &mut Vec<u8>,
188
        index_callback: &mut Callback,
189
    ) -> Result<bool, Error>
190
    where
191
        Callback: FnMut(
192
            &ArcBytes<'static>,
193
            &mut Index,
194
            &mut dyn File,
195
            &mut HashMap<u64, u64>,
196
            &mut PagedWriter<'_>,
197
            Option<&dyn AnyVault>,
198
        ) -> Result<bool, Error>,
199
    {
200
650648
        self.position.load(file, true, vault, None, None)?;
201
650648
        let node = self.position.get_mut().unwrap();
202
650648
        let mut any_data_copied = node.copy_data_to(
203
650648
            include_nodes,
204
650648
            file,
205
650648
            copied_chunks,
206
650648
            writer,
207
650648
            vault,
208
650648
            scratch,
209
650648
            index_callback,
210
650648
        )?;
211

            
212
        // Serialize if we are supposed to
213
650648
        let position = if include_nodes.should_include() {
214
325324
            any_data_copied = true;
215
325324
            scratch.clear();
216
325324
            node.serialize_to(scratch, writer)?;
217
325324
            Some(writer.write_chunk(scratch)?)
218
        } else {
219
325324
            self.position.position()
220
        };
221

            
222
        // Remove the node from memory to save RAM during the compaction process.
223
650648
        if let Some(position) = position {
224
650648
            self.position = Pointer::OnDisk(position);
225
650648
        }
226

            
227
650648
        Ok(any_data_copied)
228
650648
    }
229
}
230

            
231
impl<
        Index: Clone + BinarySerialization + Debug + 'static,
        ReducedIndex: Clone + BinarySerialization + Debug + 'static,
    > BinarySerialization for Interior<Index, ReducedIndex>
{
    /// Serializes this interior entry to `writer`, first persisting the
    /// pointed-at child node (via `paged_writer`) if it needs to be written.
    ///
    /// On-disk layout appended to `writer`: 2-byte big-endian key length,
    /// key bytes, 8-byte big-endian child location, then the serialized
    /// `stats` index. Returns the number of bytes appended.
    fn serialize_to(
        &mut self,
        writer: &mut Vec<u8>,
        paged_writer: &mut PagedWriter<'_>,
    ) -> Result<usize, Error> {
        // Take ownership of the pointer so a Loaded entry can be consumed;
        // a placeholder OnDisk(0) sits in self.position until it is restored
        // below.
        // NOTE(review): if a `?` below returns early, self.position remains
        // the OnDisk(0) placeholder — presumably the whole operation is
        // abandoned on error, but worth confirming.
        let mut pointer = Pointer::OnDisk(0);
        std::mem::swap(&mut pointer, &mut self.position);
        let location_on_disk = match pointer {
            Pointer::OnDisk(position) => position,
            Pointer::Loaded {
                mut entry,
                previous_location,
            } => match (entry.dirty, previous_location) {
                // Serialize if dirty, or if this node hasn't been on-disk before.
                (true, _) | (_, None) => {
                    entry.dirty = false;
                    // The child is serialized onto the tail of `writer`, written
                    // out as its own chunk, then truncated back off so `writer`
                    // only ever accumulates this entry's own bytes.
                    let old_writer_length = writer.len();
                    entry.serialize_to(writer, paged_writer)?;
                    let position =
                        paged_writer.write_chunk(&writer[old_writer_length..writer.len()])?;
                    writer.truncate(old_writer_length);
                    // The entry was just materialized; give it to the cache so
                    // readers can skip deserializing it.
                    if let (Some(cache), Some(file_id)) =
                        (paged_writer.cache, paged_writer.id().id())
                    {
                        cache.replace_with_decoded(file_id, position, entry);
                    }
                    position
                }
                // Clean and previously saved: reuse the existing location.
                (false, Some(position)) => position,
            },
        };
        self.position = Pointer::OnDisk(location_on_disk);
        let mut bytes_written = 0;
        // Write the key
        let key_len = u16::try_from(self.key.len()).map_err(|_| ErrorKind::KeyTooLarge)?;
        writer.write_u16::<BigEndian>(key_len)?;
        writer.extend_from_slice(&self.key);
        bytes_written += 2 + key_len as usize;

        writer.write_u64::<BigEndian>(location_on_disk)?;
        bytes_written += 8;

        bytes_written += self.stats.serialize_to(writer, paged_writer)?;

        Ok(bytes_written)
    }

    /// Deserializes an interior entry from `reader`, expecting the layout
    /// produced by [`serialize_to`](Self::serialize_to). The child node is
    /// not loaded; `position` is left as [`Pointer::OnDisk`].
    fn deserialize_from(
        reader: &mut ArcBytes<'_>,
        current_order: Option<usize>,
    ) -> Result<Self, Error> {
        let key_len = reader.read_u16::<BigEndian>()? as usize;
        // Guard against a corrupt length prefix before attempting the read.
        if key_len > reader.len() {
            return Err(Error::data_integrity(format!(
                "key length {} found but only {} bytes remaining",
                key_len,
                reader.len()
            )));
        }
        let key = reader.read_bytes(key_len)?.into_owned();

        let position = reader.read_u64::<BigEndian>()?;
        let stats = ReducedIndex::deserialize_from(reader, current_order)?;

        Ok(Self {
            key,
            position: Pointer::OnDisk(position),
            stats,
        })
    }
}