1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//
// Wildland Project
//
// Copyright © 2022 Golem Foundation
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 3 as published by
// the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

use wildland_corex::dfs::interface::stream::{IStream, OStream, StreamErr};
use wildland_corex::dfs::interface::{IStreamResult, OStreamResult};

use super::{Core, CIPHER_DECRYPTION_SHRINKAGE_RATE, CIPHER_ENCRYPTION_EXPANSION_RATE};
use crate::encryption::interface::{EncryptionModule, EncryptionModuleError};

/// Input-stream adapter: every chunk pulled from `inner` is passed through
/// [`Core`]'s `encode_data` before being returned to the caller.
pub struct EncryptingIStream {
    /// Backend that transforms the raw chunks read from `inner`.
    encryption_module: Core,
    /// Underlying stream that supplies the raw (not yet encoded) bytes.
    inner: Box<dyn IStream>,
}

impl EncryptingIStream {
    /// Wraps `inner` so that all data read through the returned stream is
    /// encoded by `encryption_module`.
    pub fn new(inner: Box<dyn IStream>, encryption_module: Core) -> Self {
        Self {
            encryption_module,
            inner,
        }
    }
}

impl IStream for EncryptingIStream {
    /// Reads raw bytes from `inner`, encodes them, and returns the encoded chunk.
    ///
    /// `bytes_count` is the caller's desired amount of *encoded* output. It is
    /// scaled by `CIPHER_DECRYPTION_SHRINKAGE_RATE` to get a raw-input budget,
    /// which is then aligned to whole 3-byte groups. The inner stream is
    /// topped up until the raw buffer is a multiple of 3 bytes or `inner`
    /// returns an empty read. The returned chunk may therefore be somewhat
    /// larger than `bytes_count` for very small requests.
    /// NOTE(review): the 3-byte input / 4-byte output (see the OStream side)
    /// alignment suggests a base64-like codec in `Core` — confirm.
    ///
    /// # Errors
    /// Propagates errors from `inner.read` and from `encode_data` (the latter
    /// converted into `StreamErr`).
    fn read(&mut self, bytes_count: usize) -> IStreamResult {
        let raw_budget =
            (bytes_count as f64 * CIPHER_DECRYPTION_SHRINKAGE_RATE).floor() as usize;
        // Round down to whole 3-byte groups. Budgets of 1 or 2 are kept as-is
        // (rounding would make them 0); the top-up loop below completes the group.
        let raw_budget = match raw_budget {
            0..=3 => raw_budget,
            n => n - n % 3,
        };

        let mut plain = self.inner.read(raw_budget)?;

        // Short reads may leave a partial group; keep asking for exactly the
        // missing bytes until aligned or the inner stream yields nothing more.
        while plain.len() % 3 != 0 {
            let missing = 3 - plain.len() % 3;
            let chunk = self.inner.read(missing)?;
            if chunk.is_empty() {
                break;
            }
            plain.extend(chunk);
        }

        let encoded = self.encryption_module.encode_data(&plain)?;
        Ok(encoded)
    }

    /// Total encoded size: the inner stream's size scaled by
    /// `CIPHER_ENCRYPTION_EXPANSION_RATE`, rounded up.
    fn total_size(&self) -> usize {
        let raw_size = self.inner.total_size() as f64;
        (raw_size * CIPHER_ENCRYPTION_EXPANSION_RATE).ceil() as usize
    }
}

/// Output-stream adapter: data written to it is passed through [`Core`]'s
/// `decode_data` in whole 4-byte groups before being forwarded to `inner`.
pub struct EncryptingOStream {
    /// Backend that transforms the buffered input before forwarding.
    encryption_module: Core,
    /// Destination stream for the transformed bytes.
    inner: Box<dyn OStream>,
    /// Carry-over ("remainder", name kept for compatibility): trailing bytes
    /// that did not fill a whole 4-byte group. After each `write` it holds
    /// fewer than 4 bytes; it is processed on the next `write` or on `flush`.
    reminder: Vec<u8>,
}

impl EncryptingOStream {
    /// Wraps `inner` so that everything written through the returned stream is
    /// run through `encryption_module` first. The carry-over buffer starts empty.
    pub fn new(inner: Box<dyn OStream>, encryption_module: Core) -> Self {
        Self {
            reminder: Vec::new(),
            encryption_module,
            inner,
        }
    }
}

impl OStream for EncryptingOStream {
    /// Buffers `data` behind any leftover bytes from the previous call, then
    /// decodes and forwards the largest prefix that is a multiple of 4 bytes.
    /// The trailing 0-3 bytes are kept for the next `write` or for `flush`.
    ///
    /// # Errors
    /// Propagates errors from `decode_data` (converted into `StreamErr`) and
    /// from `inner.write`. The carry-over is updated before decoding, so on a
    /// decode error the aligned prefix of this call is not retried.
    fn write(&mut self, data: Vec<u8>) -> OStreamResult {
        // Reuse the caller's buffer when there is no carry-over; otherwise append
        // to the (small) carry-over instead of building a third buffer via concat.
        let mut pending = if self.reminder.is_empty() {
            data
        } else {
            let mut joined = std::mem::take(&mut self.reminder);
            joined.extend_from_slice(&data);
            joined
        };

        // Split off the incomplete trailing group (< 4 bytes); only that tail is copied.
        let split_point = pending.len() / 4 * 4;
        self.reminder = pending.split_off(split_point);

        let out = self.encryption_module.decode_data(&pending)?;
        self.inner.write(out)
    }

    /// Decodes and forwards any carried-over bytes, then flushes `inner`.
    ///
    /// `flush` consumes this adapter together with the wrapped stream, so the
    /// inner stream must be flushed here; otherwise it would be dropped without
    /// ever being flushed.
    ///
    /// # Errors
    /// Propagates errors from `decode_data`, `inner.write` and `inner.flush`.
    fn flush(mut self: Box<Self>) -> OStreamResult {
        if !self.reminder.is_empty() {
            let out = self.encryption_module.decode_data(&self.reminder)?;
            self.inner.write(out)?;
        }
        // Moves `inner` out of the box; allowed because `self` is a `Box`.
        self.inner.flush()
    }
}

impl From<EncryptionModuleError> for StreamErr {
    /// Wraps an encryption-module failure as a stream error, carrying the
    /// error's `Display` text as the message.
    ///
    /// NOTE(review): `code` is always `0`; no distinct code is assigned to
    /// encryption errors here — confirm this matches `StreamErr` conventions.
    fn from(e: EncryptionModuleError) -> Self {
        Self {
            code: 0,
            msg: e.to_string(),
        }
    }
}