1
use std::{
2
    borrow::Cow,
3
    cmp::Ordering,
4
    fs::OpenOptions,
5
    io::{SeekFrom, Write},
6
    ops::{Bound, RangeBounds},
7
    path::Path,
8
    sync::Arc,
9
};
10

            
11
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
12

            
13
use super::{State, TransactionHandle};
14
use crate::{
15
    error::Error,
16
    io::{File, FileManager, FileOp, ManagedFile, OpenableFile, OperableFile},
17
    transaction::TransactionId,
18
    vault::AnyVault,
19
    ArcBytes, Context, ErrorKind,
20
};
21

            
22
const PAGE_SIZE: usize = 1024;
23

            
24
/// A transaction log that records changes for one or more trees.
25
pub struct TransactionLog<File: ManagedFile> {
26
    vault: Option<Arc<dyn AnyVault>>,
27
    state: State,
28
    log: <File::Manager as FileManager>::FileHandle,
29
}
30

            
31
impl<File: ManagedFile> TransactionLog<File> {
32
    /// Opens a transaction log for reading.
33
405
    pub fn read(
34
405
        log_path: &Path,
35
405
        state: State,
36
405
        context: Context<File::Manager>,
37
405
    ) -> Result<Self, Error> {
38
405
        let log = context.file_manager.read(log_path)?;
39
405
        Ok(Self {
40
405
            vault: context.vault,
41
405
            state,
42
405
            log,
43
405
        })
44
405
    }
45

            
46
    /// Opens a transaction log for writing.
47
6033
    pub fn open(
48
6033
        log_path: &Path,
49
6033
        state: State,
50
6033
        context: Context<File::Manager>,
51
6033
    ) -> Result<Self, Error> {
52
6033
        let log = context.file_manager.append(log_path)?;
53
6033
        Ok(Self {
54
6033
            vault: context.vault,
55
6033
            state,
56
6033
            log,
57
6033
        })
58
6033
    }
59

            
60
    /// Returns the total size of the transaction log file.
61
    pub fn total_size(&self) -> u64 {
62
        self.state.len()
63
    }
64

            
65
    /// Initializes `state` to contain the information about the transaction log
66
    /// located at `log_path`.
67
6037
    pub fn initialize_state(state: &State, context: &Context<File::Manager>) -> Result<(), Error> {
68
6037
        let mut log_length = match context.file_manager.file_length(state.path()) {
69
6007
            Ok(length) => length,
70
            Err(Error {
71
30
                kind: ErrorKind::Io(err),
72
                ..
73
30
            }) if err.kind() == std::io::ErrorKind::NotFound => 0,
74
            Err(other) => return Err(other),
75
        };
76
6037
        if log_length == 0 {
77
31
            state.initialize(TransactionId(0), 0);
78
31
            return Ok(());
79
6006
        }
80
6006

            
81
6006
        let excess_length = log_length % PAGE_SIZE as u64;
82
6006
        if excess_length > 0 {
83
            // Truncate the file to the proper page size. This should only happen in a recovery situation.
84
            eprintln!(
85
                "Transaction log has {} extra bytes. Truncating.",
86
                excess_length
87
            );
88
            let file = OpenOptions::new()
89
                .append(true)
90
                .write(true)
91
                .open(state.path())?;
92
            log_length -= excess_length;
93
            file.set_len(log_length)?;
94
            file.sync_all()?;
95
6006
        }
96

            
97
6006
        let mut file = context.file_manager.read(state.path())?;
98
6006
        file.execute(StateInitializer {
99
6006
            state,
100
6006
            log_length,
101
6006
            vault: context.vault(),
102
6006
        })
103
6037
    }
104

            
105
    /// Logs one or more transactions. After this call returns, the transaction
106
    /// log is guaranteed to be fully written to disk.
107
    ///
108
    /// # Errors
109
    ///
110
    /// Returns [`ErrorKind::TransactionPushedOutOfOrder`] if `handles` is out of
111
    /// order, or if any handle contains an id older than one already written to
112
    /// the log.
113
53059
    pub fn push(&mut self, handles: Vec<LogEntry<'static>>) -> Result<(), Error> {
114
53059
        self.log.execute(LogWriter {
115
53059
            state: self.state.clone(),
116
53059
            vault: self.vault.clone(),
117
53059
            transactions: handles,
118
53059
        })
119
53059
    }
120

            
121
    /// Returns the executed transaction with the id provided. Returns None if not found.
122
    pub fn get(&mut self, id: TransactionId) -> Result<Option<LogEntry<'static>>, Error> {
123
16012
        match self.log.execute(EntryFetcher {
124
16012
            id,
125
16012
            state: &self.state,
126
16012
            vault: self.vault.as_deref(),
127
16012
        })? {
128
15697
            ScanResult::Found { entry, .. } => Ok(Some(entry)),
129
315
            ScanResult::NotFound { .. } => Ok(None),
130
        }
131
16012
    }
132

            
133
    /// Logs one or more transactions. After this call returns, the transaction
134
    /// log is guaranteed to be fully written to disk.
135
417
    pub fn scan<Callback: FnMut(LogEntry<'static>) -> bool>(
136
417
        &mut self,
137
417
        ids: impl RangeBounds<TransactionId>,
138
417
        callback: Callback,
139
417
    ) -> Result<(), Error> {
140
417
        self.log.execute(EntryScanner {
141
417
            ids,
142
417
            callback,
143
417
            state: &self.state,
144
417
            vault: self.vault.as_deref(),
145
417
        })
146
417
    }
147

            
148
    /// Closes the transaction log.
149
6000
    pub fn close(self) -> Result<(), Error> {
150
6000
        self.log.close()
151
6000
    }
152

            
153
    /// Begins a new transaction, exclusively locking `trees`.
154
16012
    pub fn new_transaction<
155
16012
        'a,
156
16012
        I: IntoIterator<Item = &'a [u8], IntoIter = II>,
157
16012
        II: ExactSizeIterator<Item = &'a [u8]>,
158
16012
    >(
159
16012
        &self,
160
16012
        trees: I,
161
16012
    ) -> TransactionHandle {
162
16012
        self.state.new_transaction(trees)
163
16012
    }
164

            
165
    /// Returns the current state of the log.
166
16026
    pub fn state(&self) -> &State {
167
16026
        &self.state
168
16026
    }
169
}
170

            
171
struct StateInitializer<'a> {
172
    state: &'a State,
173
    log_length: u64,
174
    vault: Option<&'a dyn AnyVault>,
175
}
176

            
177
impl<'a> FileOp<Result<(), Error>> for StateInitializer<'a> {
178
6006
    fn execute(self, log: &mut dyn File) -> Result<(), Error> {
179
6006
        // Scan back block by block until we find a page header with a value of 1.
180
6006
        let block_start = self.log_length - PAGE_SIZE as u64;
181
6006
        let mut scratch_buffer = Vec::new();
182
6006
        scratch_buffer.reserve(PAGE_SIZE);
183
6006
        scratch_buffer.resize(4, 0);
184

            
185
6002
        let last_transaction =
186
6006
            match scan_for_transaction(log, &mut scratch_buffer, block_start, false, self.vault)? {
187
6002
                ScanResult::Found { entry, .. } => entry,
188
                ScanResult::NotFound { .. } => {
189
                    return Err(Error::data_integrity(
190
                        "No entries found in an existing transaction log",
191
                    ))
192
                }
193
            };
194

            
195
6002
        self.state.initialize(last_transaction.id, self.log_length);
196
6002
        Ok(())
197
6006
    }
198
}
199

            
200
pub enum ScanResult {
201
    Found {
202
        entry: LogEntry<'static>,
203
        position: u64,
204
        length: u64,
205
    },
206
    NotFound {
207
        nearest_position: u64,
208
    },
209
}
210

            
211
444162
fn scan_for_transaction(
212
444162
    log: &mut dyn File,
213
444162
    scratch_buffer: &mut Vec<u8>,
214
444162
    mut block_start: u64,
215
444162
    scan_forward: bool,
216
444162
    vault: Option<&dyn AnyVault>,
217
444162
) -> Result<ScanResult, Error> {
218
444162
    if scratch_buffer.len() < 4 {
219
16426
        scratch_buffer.resize(4, 0);
220
427736
    }
221
444162
    let file_length = log.length()?;
222
996068
    Ok(loop {
223
996068
        if block_start >= file_length {
224
400
            return Ok(ScanResult::NotFound {
225
400
                nearest_position: block_start,
226
400
            });
227
995668
        }
228
995668
        log.seek(SeekFrom::Start(block_start))?;
229
        // Read the page header
230
995668
        log.read_exact(&mut scratch_buffer[0..4])?;
231
        #[allow(clippy::match_on_vec_items)]
232
995668
        match scratch_buffer[0] {
233
            0 => {
234
551906
                if block_start == 0 {
235
                    break ScanResult::NotFound {
236
                        nearest_position: 0,
237
                    };
238
551906
                }
239
551906
                if scan_forward {
240
                    block_start += PAGE_SIZE as u64;
241
551906
                } else {
242
551906
                    block_start -= PAGE_SIZE as u64;
243
551906
                }
244
551906
                continue;
245
            }
246
            1 => {
247
                // The length is the next 3 bytes.
248
443762
                let length = (scratch_buffer[1] as usize) << 16
249
443762
                    | (scratch_buffer[2] as usize) << 8
250
443762
                    | scratch_buffer[3] as usize;
251
443762
                scratch_buffer.resize(length, 0);
252
443762
                let mut initial_page = true;
253
443762
                let mut bytes_to_read = length;
254
443762
                let mut offset = 0;
255
1881492
                while bytes_to_read > 0 {
256
1437730
                    let page_header_length = if initial_page {
257
                        // The initial page has 4 bytes at the start, which we've already read.
258
443762
                        initial_page = false;
259
443762
                        4
260
                    } else {
261
                        // Subsequent pages have a 0 byte at the start of the
262
                        // page, denoting that it's not a valid page header. We
263
                        // need to skip that byte, so that the read call reads
264
                        // the stored data, not the header byte.
265
993968
                        log.seek(SeekFrom::Current(1))?;
266
993968
                        1
267
                    };
268

            
269
1437730
                    let page_length = (PAGE_SIZE - page_header_length).min(length - offset);
270
1437730
                    log.read_exact(&mut scratch_buffer[offset..offset + page_length])?;
271
1437730
                    offset += page_length;
272
1437730
                    bytes_to_read -= page_length;
273
                }
274

            
275
443762
                let payload = &scratch_buffer[0..length];
276
443762
                let decrypted = match &vault {
277
75780
                    Some(vault) => Cow::Owned(vault.decrypt(payload)?),
278
367982
                    None => Cow::Borrowed(payload),
279
                };
280
443758
                let entry = LogEntry::deserialize(&decrypted)
281
443758
                    .map_err(Error::data_integrity)?
282
443758
                    .into_owned();
283
443758
                break ScanResult::Found {
284
443758
                    entry,
285
443758
                    position: block_start,
286
443758
                    length: length as u64,
287
443758
                };
288
            }
289
            _ => unreachable!("corrupt transaction log"),
290
        }
291
    })
292
444162
}
293

            
294
#[allow(clippy::redundant_pub_crate)]
295
pub(crate) struct EntryFetcher<'a> {
296
    pub state: &'a State,
297
    pub id: TransactionId,
298
    pub vault: Option<&'a dyn AnyVault>,
299
}
300

            
301
impl<'a> FileOp<Result<ScanResult, Error>> for EntryFetcher<'a> {
302
16020
    fn execute(self, log: &mut dyn File) -> Result<ScanResult, Error> {
303
16020
        let mut scratch = Vec::with_capacity(PAGE_SIZE);
304
16020
        fetch_entry(log, &mut scratch, self.state, self.id, self.vault)
305
16020
    }
306
}
307

            
308
16431
fn fetch_entry(
309
16431
    log: &mut dyn File,
310
16431
    scratch_buffer: &mut Vec<u8>,
311
16431
    state: &State,
312
16431
    id: TransactionId,
313
16431
    vault: Option<&dyn AnyVault>,
314
16431
) -> Result<ScanResult, Error> {
315
16431
    if !id.valid() {
316
6
        return Ok(ScanResult::NotFound {
317
6
            nearest_position: 0,
318
6
        });
319
16425
    }
320
16425

            
321
16425
    let mut upper_id = state.next_transaction_id();
322
16425
    let mut upper_location = state.len();
323
16425
    if upper_id <= id {
324
5
        return Ok(ScanResult::NotFound {
325
5
            nearest_position: upper_location,
326
5
        });
327
16420
    }
328
16420
    let mut lower_id = None;
329
16420
    let mut lower_location = None;
330
    loop {
331
422977
        let guessed_location = if let Some(page) =
332
423280
            guess_page(id, lower_location, lower_id, upper_location, upper_id)
333
        {
334
422977
            page
335
        } else {
336
303
            return Ok(ScanResult::NotFound {
337
303
                nearest_position: upper_location,
338
303
            });
339
        };
340
422977
        if guessed_location == upper_location {
341
1
            return Ok(ScanResult::NotFound {
342
1
                nearest_position: upper_location,
343
1
            });
344
422976
        }
345
422976

            
346
422976
        // load the transaction at this location
347
422976
        #[allow(clippy::cast_possible_wrap)]
348
422976
        let scan_forward = guessed_location >= upper_location;
349
422976
        match scan_for_transaction(log, scratch_buffer, guessed_location, scan_forward, vault)? {
350
            ScanResult::Found {
351
422976
                entry,
352
422976
                position,
353
422976
                length,
354
422976
            } => {
355
422976
                state.note_transaction_id_status(entry.id, Some(position));
356
422976
                match entry.id.cmp(&id) {
357
                    Ordering::Less => {
358
10716
                        if lower_id.is_none() || entry.id > lower_id.unwrap() {
359
10710
                            lower_id = Some(entry.id);
360
10710
                            lower_location = Some(position);
361
10710
                        } else {
362
6
                            return Ok(ScanResult::NotFound {
363
6
                                nearest_position: position,
364
6
                            });
365
                        }
366
                    }
367
                    Ordering::Equal => {
368
16110
                        return Ok(ScanResult::Found {
369
16110
                            entry,
370
16110
                            position,
371
16110
                            length,
372
16110
                        });
373
                    }
374
                    Ordering::Greater => {
375
396150
                        if entry.id < upper_id {
376
396150
                            upper_id = entry.id;
377
396150
                            upper_location = position;
378
396150
                        } else {
379
                            return Ok(ScanResult::NotFound {
380
                                nearest_position: position,
381
                            });
382
                        }
383
                    }
384
                }
385
            }
386
            ScanResult::NotFound { nearest_position } => {
387
                return Ok(ScanResult::NotFound { nearest_position });
388
            }
389
        }
390
    }
391
16431
}
392

            
393
pub struct EntryScanner<
394
    'a,
395
    Range: RangeBounds<TransactionId>,
396
    Callback: FnMut(LogEntry<'static>) -> bool,
397
> {
398
    pub state: &'a State,
399
    pub ids: Range,
400
    pub vault: Option<&'a dyn AnyVault>,
401
    pub callback: Callback,
402
}
403

            
404
impl<'a, Range, Callback> FileOp<Result<(), Error>> for EntryScanner<'a, Range, Callback>
405
where
406
    Range: RangeBounds<TransactionId>,
407
    Callback: FnMut(LogEntry<'static>) -> bool,
408
{
409
417
    fn execute(mut self, log: &mut dyn File) -> Result<(), Error> {
410
417
        let mut scratch = Vec::with_capacity(PAGE_SIZE);
411
417
        let (start_location, start_transaction, start_length) = match self.ids.start_bound() {
412
411
            Bound::Included(start_key) | Bound::Excluded(start_key) => {
413
411
                match fetch_entry(log, &mut scratch, self.state, *start_key, self.vault)? {
414
                    ScanResult::Found {
415
411
                        entry,
416
411
                        position,
417
411
                        length,
418
411
                    } => (position, Some(entry), length),
419
                    ScanResult::NotFound { nearest_position } => (nearest_position, None, 0),
420
                }
421
            }
422
6
            Bound::Unbounded => (0, None, 0),
423
        };
424

            
425
417
        if let Some(entry) = start_transaction {
426
411
            if self.ids.contains(&entry.id) && !(self.callback)(entry) {
427
11
                return Ok(());
428
400
            }
429
6
        }
430

            
431
        // Continue scanning from this location forward, starting at the next page boundary after the starting transaction
432
406
        let mut next_scan_start = next_page_start(start_location + start_length);
433
        while let ScanResult::Found {
434
14780
            entry,
435
14780
            position,
436
14780
            length,
437
15180
        } = scan_for_transaction(log, &mut scratch, next_scan_start, true, self.vault)?
438
        {
439
14780
            if self.ids.contains(&entry.id) && !(self.callback)(entry) {
440
6
                break;
441
14774
            }
442
14774
            next_scan_start = next_page_start(position + length);
443
        }
444

            
445
406
        Ok(())
446
417
    }
447
}
448

            
449
15180
const fn next_page_start(position: u64) -> u64 {
450
15180
    let page_size = PAGE_SIZE as u64;
451
15180
    (position + page_size - 1) / page_size * page_size
452
15180
}
453

            
454
struct LogWriter {
455
    state: State,
456
    transactions: Vec<LogEntry<'static>>,
457
    vault: Option<Arc<dyn AnyVault>>,
458
}
459

            
460
impl FileOp<Result<(), Error>> for LogWriter {
461
53059
    fn execute(mut self, log: &mut dyn File) -> Result<(), Error> {
462
53059
        let mut log_position = self.state.lock_for_write();
463
53059
        let mut scratch = [0_u8; PAGE_SIZE];
464
53059
        let mut completed_transactions = Vec::with_capacity(self.transactions.len());
465
105421
        for transaction in self.transactions.drain(..) {
466
105421
            if transaction.id > log_position.last_written_transaction {
467
105415
                log_position.last_written_transaction = transaction.id;
468
105415
            } else {
469
6
                return Err(Error::from(ErrorKind::TransactionPushedOutOfOrder));
470
            }
471
105415
            completed_transactions.push((transaction.id, Some(log_position.file_offset)));
472
105415
            let mut bytes = transaction.serialize()?;
473
105415
            if let Some(vault) = &self.vault {
474
12004
                bytes = vault.encrypt(&bytes)?;
475
93411
            }
476
            // Write out the transaction in pages.
477
105415
            let total_length = bytes.len() + 3;
478
105415
            let mut offset = 0;
479
224426
            while offset < bytes.len() {
480
                // Write the page header
481
119011
                let header_len = if offset == 0 {
482
                    // The first page has the length of the payload as the next 3 bytes.
483
105415
                    let length = u32::try_from(bytes.len())
484
105415
                        .map_err(|_| Error::from("transaction too large"))?;
485
105415
                    if length & 0xFF00_0000 != 0 {
486
                        return Err(Error::from("transaction too large"));
487
105415
                    }
488
105415
                    scratch[0] = 1;
489
105415
                    #[allow(clippy::cast_possible_truncation)]
490
105415
                    {
491
105415
                        scratch[1] = (length >> 16) as u8;
492
105415
                        scratch[2] = (length >> 8) as u8;
493
105415
                        scratch[3] = (length & 0xFF) as u8;
494
105415
                    }
495
105415
                    4
496
                } else {
497
                    // Set page_header to have a 0 byte for future pages written.
498
13596
                    scratch[0] = 0;
499
13596
                    1
500
                };
501

            
502
                // Write up to PAGE_SIZE - header_len bytes
503
119011
                let total_bytes_left = total_length - (offset + 3);
504
119011
                let bytes_to_write = total_bytes_left.min(PAGE_SIZE - header_len);
505
119011
                scratch[header_len..bytes_to_write + header_len]
506
119011
                    .copy_from_slice(&bytes[offset..offset + bytes_to_write]);
507
119011
                log.write_all(&scratch)?;
508
119011
                offset += bytes_to_write;
509
119011
                log_position.file_offset += PAGE_SIZE as u64;
510
            }
511
        }
512

            
513
53053
        drop(log_position);
514
53053

            
515
53053
        log.synchronize()?;
516

            
517
53053
        self.state
518
53053
            .note_transaction_ids_completed(&completed_transactions);
519
53053

            
520
53053
        Ok(())
521
53059
    }
522
}
523

            
524
/// An entry in a transaction log.
525
9
#[derive(Eq, PartialEq, Debug)]
526
pub struct LogEntry<'a> {
527
    /// The unique id of this entry.
528
    pub id: TransactionId,
529
    pub(crate) data: Option<ArcBytes<'a>>,
530
}
531

            
532
impl<'a> LogEntry<'a> {
533
    /// Convert this entry into a `'static` lifetime.
534
    #[must_use]
535
443758
    pub fn into_owned(self) -> LogEntry<'static> {
536
443758
        LogEntry {
537
443758
            id: self.id,
538
443758
            data: self.data.map(ArcBytes::into_owned),
539
443758
        }
540
443758
    }
541
}
542

            
543
impl<'a> LogEntry<'a> {
544
    /// Returns the associated data, if any.
545
    #[must_use]
546
12000
    pub const fn data(&self) -> Option<&ArcBytes<'a>> {
547
12000
        self.data.as_ref()
548
12000
    }
549

            
550
    /// Sets the associated data that will be stored in the transaction log.
551
    /// Limited to a length 16,777,208 (2^24 - 8) bytes -- just shy of 16MB.
552
6003
    pub fn set_data(&mut self, data: impl Into<ArcBytes<'a>>) -> Result<(), Error> {
553
6003
        let data = data.into();
554
6003
        if data.len() <= 2_usize.pow(24) - 8 {
555
6001
            self.data = Some(data);
556
6001
            Ok(())
557
        } else {
558
2
            Err(Error::from(ErrorKind::ValueTooLarge))
559
        }
560
6003
    }
561

            
562
105418
    pub(crate) fn serialize(&self) -> Result<Vec<u8>, Error> {
563
105418
        let mut buffer = Vec::with_capacity(8 + self.data.as_ref().map_or(0, |data| data.len()));
564
105418
        // Transaction ID
565
105418
        buffer.write_u64::<BigEndian>(self.id.0)?;
566
105418
        if let Some(data) = &self.data {
567
            // The rest of the entry is the data. Since the header of the log entry
568
            // contains the length, we don't need to waste space encoding it again.
569
6002
            buffer.write_all(data)?;
570
99416
        }
571

            
572
105418
        Ok(buffer)
573
105418
    }
574

            
575
443761
    pub(crate) fn deserialize(mut buffer: &'a [u8]) -> Result<Self, Error> {
576
443761
        let id = TransactionId(buffer.read_u64::<BigEndian>()?);
577
443761
        let data = if buffer.is_empty() {
578
216459
            None
579
        } else {
580
227302
            Some(ArcBytes::from(buffer))
581
        };
582
443761
        Ok(Self { id, data })
583
443761
    }
584
}
585

            
586
1
/// Round-trips entries through `serialize`/`deserialize` and checks the data
/// length limit enforced by `set_data`.
#[test]
fn serialization_tests() {
    // An entry with data.
    let transaction = LogEntry {
        id: TransactionId(1),
        data: Some(ArcBytes::from(b"hello")),
    };
    let serialized = transaction.serialize().unwrap();
    let deserialized = LogEntry::deserialize(&serialized).unwrap();
    assert_eq!(transaction, deserialized);

    // An entry without data, using the largest possible id.
    let transaction = LogEntry {
        id: TransactionId(u64::MAX),
        data: None,
    };
    let serialized = transaction.serialize().unwrap();
    let deserialized = LogEntry::deserialize(&serialized).unwrap();
    assert_eq!(transaction, deserialized);

    // Test the data length limits
    let mut transaction = LogEntry {
        id: TransactionId(0),
        data: None,
    };
    let mut big_data = ArcBytes::from(vec![0_u8; 2_usize.pow(24)]);
    assert!(matches!(
        transaction.set_data(big_data.clone()),
        Err(Error {
            kind: ErrorKind::ValueTooLarge,
            ..
        })
    ));

    // Remove 8 bytes (the transaction id length) and try again.
    let big_data = big_data.read_bytes(big_data.len() - 8).unwrap();
    transaction.set_data(big_data).unwrap();
    let serialized = transaction.serialize().unwrap();
    let deserialized = LogEntry::deserialize(&serialized).unwrap();
    assert_eq!(transaction, deserialized);
}
627

            
628
#[allow(
629
    clippy::cast_precision_loss,
630
    clippy::cast_possible_truncation,
631
    clippy::cast_sign_loss
632
)]
633
423280
fn guess_page(
634
423280
    looking_for: TransactionId,
635
423280
    lower_location: Option<u64>,
636
423280
    lower_id: Option<TransactionId>,
637
423280
    upper_location: u64,
638
423280
    upper_id: TransactionId,
639
423280
) -> Option<u64> {
640
423280
    debug_assert_ne!(looking_for, upper_id);
641
423280
    let total_pages = upper_location / PAGE_SIZE as u64;
642

            
643
423280
    if let (Some(lower_location), Some(lower_id)) = (lower_location, lower_id) {
644
        // Estimate inbetween lower and upper
645
191627
        let current_page = lower_location / PAGE_SIZE as u64;
646
191627
        let delta_from_current = looking_for.0 - lower_id.0;
647
191627
        let local_avg_per_page =
648
191627
            (upper_id.0 - lower_id.0) as f64 / (total_pages - current_page) as f64;
649
191627
        let delta_estimated_pages = (delta_from_current as f64 * local_avg_per_page).floor() as u64;
650
191627
        let guess = lower_location + delta_estimated_pages.max(1) * PAGE_SIZE as u64;
651
191627
        // If our estimate is that the location is beyond or equal to the upper,
652
191627
        // we'll guess the page before it.
653
191627
        if guess >= upper_location {
654
176843
            let capped_guess = upper_location - PAGE_SIZE as u64;
655
176843
            // If the page before the upper is the lower, we can't find the
656
176843
            // transaction in question.
657
176843
            if capped_guess > lower_location {
658
176540
                Some(capped_guess)
659
            } else {
660
303
                None
661
            }
662
        } else {
663
14784
            Some(guess)
664
        }
665
231653
    } else if upper_id > looking_for {
666
        // Go backwards from upper
667
231653
        let avg_per_page = upper_id.0 as f64 / total_pages as f64;
668
231653
        let id_delta = upper_id.0 - looking_for.0;
669
231653
        let delta_estimated_pages = (id_delta as f64 * avg_per_page).ceil() as u64;
670
231653
        let delta_bytes = delta_estimated_pages.saturating_mul(PAGE_SIZE as u64);
671
231653
        Some(upper_location.saturating_sub(delta_bytes))
672
    } else {
673
        None
674
    }
675
423280
}
676

            
677
#[cfg(test)]
678
#[allow(clippy::semicolon_if_nothing_returned, clippy::future_not_send)]
679
mod tests {
680

            
681
    use std::collections::{BTreeSet, HashSet};
682

            
683
    use nanorand::{Pcg64, Rng};
684
    use tempfile::tempdir;
685

            
686
    use super::*;
687
    use crate::{
688
        io::{
689
            any::AnyFileManager,
690
            fs::{StdFile, StdFileManager},
691
            memory::MemoryFileManager,
692
        },
693
        test_util::RotatorVault,
694
        transaction::TransactionManager,
695
        ChunkCache,
696
    };
697

            
698
1
    /// Runs the shared log-file scenario on `StdFileManager`, first without a
    /// vault and then with a `RotatorVault`.
    #[test]
    fn file_log_file_tests() {
        let plain_manager = StdFileManager::default();
        log_file_tests("file_log_file", plain_manager, None, None);

        let encrypted_manager = StdFileManager::default();
        let vault: Arc<dyn AnyVault> = Arc::new(RotatorVault::new(13));
        log_file_tests(
            "file_log_file_encrypted",
            encrypted_manager,
            Some(vault),
            None,
        );
    }
708

            
709
1
    /// Runs the shared log-file scenario on `MemoryFileManager`, first without
    /// a vault and then with a `RotatorVault`.
    #[test]
    fn memory_log_file_tests() {
        let plain_manager = MemoryFileManager::default();
        log_file_tests("memory_log_file", plain_manager, None, None);

        let encrypted_manager = MemoryFileManager::default();
        let vault: Arc<dyn AnyVault> = Arc::new(RotatorVault::new(13));
        log_file_tests("memory_log_file", encrypted_manager, Some(vault), None);
    }
719

            
720
1
    /// Runs the shared log-file scenario through `AnyFileManager`, once for
    /// its `std` variant and once for its `memory` variant.
    #[test]
    fn any_log_file_tests() {
        let on_disk = AnyFileManager::std();
        log_file_tests("any_file_log_file", on_disk, None, None);

        let in_memory = AnyFileManager::memory();
        log_file_tests("any_memory_log_file", in_memory, None, None);
    }
725

            
726
    #[allow(clippy::too_many_lines)]
727
6
    fn log_file_tests<Manager: FileManager>(
728
6
        file_name: &str,
729
6
        file_manager: Manager,
730
6
        vault: Option<Arc<dyn AnyVault>>,
731
6
        cache: Option<ChunkCache>,
732
6
    ) {
733
6
        let temp_dir = crate::test_util::TestDirectory::new(file_name);
734
6
        let context = Context {
735
6
            file_manager,
736
6
            vault,
737
6
            cache,
738
6
        };
739
6
        std::fs::create_dir(&temp_dir).unwrap();
740
6
        let log_path = {
741
6
            let directory: &Path = &temp_dir;
742
6
            directory.join("_transactions")
743
6
        };
744
6

            
745
6
        let mut rng = Pcg64::new_seed(1);
746
6
        let data = (0..PAGE_SIZE * 10)
747
61440
            .map(|_| rng.generate())
748
6
            .collect::<Vec<u8>>();
749

            
750
6006
        for id in 1..=1_000 {
751
6000
            let state = State::from_path(&log_path);
752
6000
            TransactionLog::<Manager::File>::initialize_state(&state, &context).unwrap();
753
6000
            let mut transactions =
754
6000
                TransactionLog::<Manager::File>::open(&log_path, state, context.clone()).unwrap();
755
6000
            assert_eq!(
756
6000
                transactions.state().next_transaction_id(),
757
6000
                TransactionId(id)
758
6000
            );
759
6000
            let mut tx = transactions.new_transaction([&b"hello"[..]]);
760
6000

            
761
6000
            tx.transaction.data = Some(ArcBytes::from(id.to_be_bytes()));
762
6000
            // We want to have varying sizes to try to test the scan algorithm thoroughly.
763
6000
            #[allow(clippy::cast_possible_truncation)]
764
6000
            if id % 2 == 0 {
765
                // Larger than PAGE_SIZE
766
3000
                if id % 3 == 0 {
767
996
                    tx.set_data(data[0..PAGE_SIZE * (id as usize % 10).max(3)].to_vec())
768
996
                        .unwrap();
769
2004
                } else {
770
2004
                    tx.set_data(data[0..PAGE_SIZE * (id as usize % 10).max(2)].to_vec())
771
2004
                        .unwrap();
772
2004
                }
773
3000
            } else {
774
3000
                tx.set_data(data[0..id as usize].to_vec()).unwrap();
775
3000
            }
776

            
777
6000
            assert!(tx.data.as_ref().unwrap().len() > 0);
778

            
779
6000
            transactions.push(vec![tx.transaction]).unwrap();
780
6000
            transactions.close().unwrap();
781
        }
782

            
783
6
        let state = State::from_path(&log_path);
784
6
        if context.vault.is_none() {
785
            // Test that we can't open it with encryption. Without
786
            // https://github.com/khonsulabs/nebari/issues/35, the inverse isn't
787
            // able to be tested.
788
4
            assert!(TransactionLog::<Manager::File>::initialize_state(
789
4
                &state,
790
4
                &Context {
791
4
                    file_manager: context.file_manager.clone(),
792
4
                    vault: Some(Arc::new(RotatorVault::new(13))),
793
4
                    cache: None
794
4
                }
795
4
            )
796
4
            .is_err());
797
2
        }
798

            
799
6
        TransactionLog::<Manager::File>::initialize_state(&state, &context).unwrap();
800
6
        let mut transactions =
801
6
            TransactionLog::<Manager::File>::open(&log_path, state, context).unwrap();
802
6

            
803
6
        let out_of_order = transactions.new_transaction([&b"test"[..]]);
804
6
        transactions
805
6
            .push(vec![
806
6
                transactions.new_transaction([&b"test2"[..]]).transaction,
807
6
            ])
808
6
            .unwrap();
809
6
        assert!(matches!(
810
6
            transactions
811
6
                .push(vec![out_of_order.transaction])
812
6
                .unwrap_err()
813
6
                .kind,
814
            ErrorKind::TransactionPushedOutOfOrder
815
        ));
816

            
817
6
        assert!(transactions.get(TransactionId(0)).unwrap().is_none());
818
6006
        for id in 1..=1_000 {
819
6000
            let transaction = transactions.get(TransactionId(id)).unwrap();
820
6000
            match transaction {
821
6000
                Some(transaction) => {
822
6000
                    assert_eq!(transaction.id, TransactionId(id));
823
6000
                    assert_eq!(
824
6000
                        &data[..transaction.data().unwrap().len()],
825
6000
                        transaction.data().unwrap().as_slice()
826
6000
                    );
827
                }
828
                None => {
829
                    unreachable!("failed to fetch transaction {}", id)
830
                }
831
            }
832
        }
833
6
        assert!(transactions.get(TransactionId(1001)).unwrap().is_none());
834

            
835
        // Test scanning
836
6
        let mut first_ten = Vec::new();
837
6
        transactions
838
60
            .scan(.., |entry| {
839
60
                first_ten.push(entry);
840
60
                first_ten.len() < 10
841
60
            })
842
6
            .unwrap();
843
6
        assert_eq!(first_ten.len(), 10);
844
6
        let mut after_first = None;
845
6
        transactions
846
6
            .scan(TransactionId(first_ten[0].id.0 + 1).., |entry| {
847
6
                after_first = Some(entry);
848
6
                false
849
6
            })
850
6
            .unwrap();
851
6
        assert_eq!(after_first.as_ref(), first_ten.get(1));
852
6
    }
853

            
854
1
    /// Writes a log in which some transaction ids are allocated but never
    /// pushed, then checks that lookups succeed exactly for the pushed ids.
    #[test]
    fn discontiguous_log_file_tests() {
        let temp_dir = tempdir().unwrap();
        let context = Context {
            file_manager: StdFileManager::default(),
            vault: None,
            cache: None,
        };
        let log_path = temp_dir.path().join("_transactions");
        let mut rng = Pcg64::new_seed(1);

        let state = State::from_path(&log_path);
        TransactionLog::<StdFile>::initialize_state(&state, &context).unwrap();
        let mut transactions = TransactionLog::<StdFile>::open(&log_path, state, context).unwrap();

        let mut valid_ids = HashSet::new();
        for id in 1..=10_000 {
            let next_id = transactions.state().next_transaction_id();
            assert_eq!(next_id, TransactionId(id));
            let tx = transactions.new_transaction([&b"hello"[..]]);
            // skip a few ids.
            let skipped = rng.generate::<u8>() < 8;
            if !skipped {
                valid_ids.insert(tx.id);

                transactions.push(vec![tx.transaction]).unwrap();
            }
        }

        for id in 1..=10_000 {
            if let Some(entry) = transactions.get(TransactionId(id)).unwrap() {
                assert_eq!(entry.id, TransactionId(id));
            } else {
                assert!(!valid_ids.contains(&TransactionId(id)));
            }
        }
    }
896

            
897
1
    /// Runs the shared transaction-manager scenario on `StdFileManager`.
    #[test]
    fn file_log_manager_tests() {
        let file_manager = StdFileManager::default();
        log_manager_tests("file_log_manager", file_manager, None, None);
    }
901

            
902
1
    /// Runs the shared transaction-manager scenario on `MemoryFileManager`.
    #[test]
    fn memory_log_manager_tests() {
        let file_manager = MemoryFileManager::default();
        log_manager_tests("memory_log_manager", file_manager, None, None);
    }
911

            
912
1
    /// Runs the shared transaction-manager scenario through `AnyFileManager`,
    /// once for its `std` variant and once for its `memory` variant.
    #[test]
    fn any_log_manager_tests() {
        let on_disk = AnyFileManager::std();
        log_manager_tests("any_log_manager", on_disk, None, None);

        let in_memory = AnyFileManager::memory();
        log_manager_tests("any_log_manager", in_memory, None, None);
    }
917

            
918
1
    /// Runs the shared transaction-manager scenario on a file-backed manager
    /// with a `RotatorVault` configured.
    ///
    /// Uses `StdFileManager` so the test matches its name and directory label;
    /// it previously built a `MemoryFileManager`, which left the encrypted
    /// on-disk manager path untested.
    #[test]
    fn file_encrypted_log_manager_tests() {
        log_manager_tests(
            "encrypted_file_log_manager",
            StdFileManager::default(),
            Some(Arc::new(RotatorVault::new(13))),
            None,
        );
    }
927

            
928
5
    fn log_manager_tests<Manager: FileManager>(
929
5
        file_name: &str,
930
5
        file_manager: Manager,
931
5
        vault: Option<Arc<dyn AnyVault>>,
932
5
        cache: Option<ChunkCache>,
933
5
    ) {
934
5
        let temp_dir = crate::test_util::TestDirectory::new(file_name);
935
5
        std::fs::create_dir(&temp_dir).unwrap();
936
5
        let context = Context {
937
5
            file_manager,
938
5
            vault,
939
5
            cache,
940
5
        };
941
5
        let manager = TransactionManager::spawn(&temp_dir, context).unwrap();
942
5
        assert_eq!(manager.current_transaction_id(), None);
943
5
        assert_eq!(manager.len(), 0);
944
5
        assert!(manager.is_empty());
945

            
946
5
        let mut handles = Vec::new();
947
55
        for _ in 0..10 {
948
50
            let manager = manager.clone();
949
50
            handles.push(std::thread::spawn(move || {
950
50050
                for id in 0_u32..1_000 {
951
50000
                    let tx = manager.new_transaction([&id.to_be_bytes()[..]]);
952
50000
                    tx.commit().unwrap();
953
50000
                }
954
50
            }));
955
50
        }
956

            
957
55
        for handle in handles {
958
50
            handle.join().unwrap();
959
50
        }
960

            
961
5
        assert_eq!(
962
5
            manager.current_transaction_id(),
963
5
            Some(TransactionId(10_000))
964
5
        );
965
5
        assert_eq!(manager.next_transaction_id(), TransactionId(10_001));
966

            
967
5
        assert!(manager
968
5
            .transaction_was_successful(manager.current_transaction_id().unwrap())
969
5
            .unwrap());
970

            
971
5
        assert!(!manager
972
5
            .transaction_was_successful(manager.next_transaction_id())
973
5
            .unwrap());
974

            
975
5
        let mut ten = None;
976
5
        manager
977
5
            .scan(TransactionId(10).., |entry| {
978
5
                ten = Some(entry);
979
5
                false
980
5
            })
981
5
            .unwrap();
982
5
        assert_eq!(ten.unwrap().id, TransactionId(10));
983
5
    }
984

            
985
1
    /// Runs the shared out-of-order commit scenario on `StdFileManager`.
    #[test]
    fn file_out_of_order_log_manager_tests() {
        let file_manager = StdFileManager::default();
        out_of_order_log_manager_tests("file_out_of_order_log_manager", file_manager, None, None);
    }
994

            
995
1
    /// Runs the shared out-of-order commit scenario on `MemoryFileManager`.
    #[test]
    fn memory_out_of_order_log_manager_tests() {
        let file_manager = MemoryFileManager::default();
        out_of_order_log_manager_tests(
            "memory_out_of_order_log_manager",
            file_manager,
            None,
            None,
        );
    }

            
1
    /// Runs the shared out-of-order commit scenario through `AnyFileManager`,
    /// once for its `std` variant and once for its `memory` variant.
    #[test]
    fn any_out_of_order_log_manager_tests() {
        let on_disk = AnyFileManager::std();
        out_of_order_log_manager_tests("any_out_of_order_log_manager", on_disk, None, None);

        let in_memory = AnyFileManager::memory();
        out_of_order_log_manager_tests("any_out_of_order_log_manager", in_memory, None, None);
    }

            
4
    fn out_of_order_log_manager_tests<Manager: FileManager>(
4
        file_name: &str,
4
        file_manager: Manager,
4
        vault: Option<Arc<dyn AnyVault>>,
4
        cache: Option<ChunkCache>,
4
    ) {
4
        let temp_dir = crate::test_util::TestDirectory::new(file_name);
4
        std::fs::create_dir(&temp_dir).unwrap();
4
        let context = Context {
4
            file_manager,
4
            vault,
4
            cache,
4
        };
4
        let manager = TransactionManager::spawn(&temp_dir, context).unwrap();
4
        let mut rng = Pcg64::new_seed(1);

            
404
        for batch in 1..=100_u8 {
400
            println!("New batch");
400
            // Generate a bunch of transactions.
400
            let mut handles = Vec::new();
20200
            for tree in 1..=batch {
20200
                handles.push(manager.new_transaction([&tree.to_be_bytes()[..]]));
20200
            }
400
            rng.shuffle(&mut handles);
400
            let (handle_sender, handle_receiver) = flume::unbounded();
400
            let mut should_commit_handles = Vec::new();
400
            let mut expected_ids = BTreeSet::new();
20200
            for (index, handle) in handles.into_iter().enumerate() {
20200
                let should_commit_handle = rng.generate::<f32>() > 0.25 || expected_ids.is_empty();
20200
                if should_commit_handle {
15120
                    expected_ids.insert(handle.id);
15120
                }
20200
                should_commit_handles.push(should_commit_handle);
20200
                handle_sender.send((index, handle)).unwrap();
            }
400
            let should_commit_handles = Arc::new(should_commit_handles);
400
            let mut threads = Vec::new();
20200
            for _ in 1..=batch {
20200
                let handle_receiver = handle_receiver.clone();
20200
                let should_commit_handles = should_commit_handles.clone();
20200
                threads.push(std::thread::spawn(move || {
20200
                    let (handle_index, handle) = handle_receiver.recv().unwrap();
20200
                    if should_commit_handles[handle_index] {
15120
                        println!("Committing handle {}", handle.id);
15120
                        handle.commit().unwrap();
15120
                    } else {
5080
                        println!("Dropping handle {}", handle.id);
5080
                        handle.rollback();
5080
                    }
20200
                }));
20200
            }
20600
            for thread in threads {
20200
                thread.join().unwrap();
20200
            }
400
            manager
15120
                .scan(dbg!(*expected_ids.iter().next().unwrap()).., |tx| {
15120
                    expected_ids.remove(&tx.id);
15120
                    true
15120
                })
400
                .unwrap();
400
            assert!(expected_ids.is_empty(), "{:?}", expected_ids);
        }
4
    }
}