1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//
// Wildland Project
//
// Copyright © 2022 Golem Foundation
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 3 as published by
// the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

use std::fmt::Display;

use s_macro::s;
use wildland_corex::{LocalSecureStorage, LssError, LssResult};

/// [`LocalSecureStorage`] implementation that keeps its entries in a sled database.
#[derive(Clone, Debug)]
pub struct SledLss {
    /// Handle to the sled database every storage operation goes through.
    db: sled::Db,
}

impl SledLss {
    /// Opens the sled database at `path` and wraps it in a [`SledLss`].
    ///
    /// # Panics
    ///
    /// Panics if sled fails to open the database. Use [`SledLss::try_new`] to
    /// handle that failure instead of aborting.
    pub fn new(path: String) -> Self {
        Self::try_new(path).expect("failed to open sled database")
    }

    /// Fallible counterpart of [`SledLss::new`].
    ///
    /// # Errors
    ///
    /// Returns the [`sled::Error`] reported by `sled::open`.
    pub fn try_new(path: impl AsRef<std::path::Path>) -> sled::Result<Self> {
        sled::open(path).map(|db| Self { db })
    }

    /// Returns a boxed clone of this storage handle.
    pub fn boxed(&self) -> Box<Self> {
        Box::new(self.clone())
    }
}

fn to_lsserror<T: Display>(error: T) -> LssError {
    LssError::Error(s!(error))
}

impl LocalSecureStorage for SledLss {
    #[tracing::instrument(level = "trace", err(Debug), skip(self))]
    fn insert(&self, key: String, value: String) -> LssResult<Option<String>> {
        let value = self.db.insert(key, value.as_bytes()).map_err(to_lsserror)?;
        match value {
            Some(value) => Ok(Some(
                String::from_utf8(value.to_vec()).map_err(to_lsserror)?,
            )),
            None => Ok(None),
        }
    }

    #[tracing::instrument(level = "trace", err(Debug), err(Debug), skip(self))]
    fn get(&self, key: String) -> LssResult<Option<String>> {
        let value = self.db.get(key).map_err(to_lsserror)?;
        match value {
            Some(value) => Ok(Some(
                String::from_utf8(value.to_vec()).map_err(to_lsserror)?,
            )),
            None => Ok(None),
        }
    }

    #[tracing::instrument(level = "trace", err(Debug), skip(self))]
    fn contains_key(&self, key: String) -> LssResult<bool> {
        self.db.contains_key(key).map_err(to_lsserror)
    }

    #[tracing::instrument(level = "trace", skip(self))]
    fn is_empty(&self) -> LssResult<bool> {
        // esavier:devnote:have no idea how to handle this otherwise
        Ok(self.db.is_empty())
    }

    #[tracing::instrument(level = "trace", err(Debug), skip(self))]
    fn keys(&self) -> LssResult<Vec<String>> {
        let keys = self
            .db
            .iter()
            .map(|item| String::from_utf8(item.unwrap().0.to_vec()).unwrap())
            .collect();
        Ok(keys)
    }

    #[tracing::instrument(level = "trace", err(Debug), skip(self))]
    fn remove(&self, key: String) -> LssResult<Option<String>> {
        let value = self.db.remove(key).map_err(to_lsserror)?;
        match value {
            Some(value) => Ok(Some(
                String::from_utf8(value.to_vec()).map_err(to_lsserror)?,
            )),
            None => Ok(None),
        }
    }

    #[tracing::instrument(level = "trace", err(Debug), skip(self))]
    fn len(&self) -> LssResult<usize> {
        Ok(self.db.len())
    }

    #[tracing::instrument(level = "trace", err(Debug), skip(self))]
    fn keys_starting_with(&self, prefix: String) -> LssResult<Vec<String>> {
        let keys = self
            .db
            .iter()
            .map(|item| String::from_utf8(item.unwrap().0.to_vec()).unwrap())
            .filter(|x| x.starts_with(&prefix))
            .collect();
        Ok(keys)
    }
}