aboutsummaryrefslogtreecommitdiffstats
path: root/src/proto/codec.rs
blob: d8f9b1095150a1597e58fdad9a328acdbcced5fa (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
use bytes::{BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};

use crate::proto::error::CodecError;
use crate::proto::message::IrcMessage;
use crate::proto::parser::parse;
use crate::proto::serializer::serialize;

/// Default upper bound, in bytes, on the length of a single wire line.
///
/// `IrcCodec::new` installs this value as the codec's limit, and the encoder
/// counts the trailing `\r\n` against it. 512 is the classic IRC message size
/// limit (RFC 1459 / RFC 2812, terminator included).
const MAX_LINE_LENGTH: usize = 512;

/// Line-oriented codec that turns a byte stream into [`IrcMessage`] values
/// and back.
///
/// Derives `Debug` so the codec can appear in diagnostics of types that embed
/// it, and `Clone` because it only carries plain configuration.
#[derive(Debug, Clone)]
pub struct IrcCodec {
    /// Upper bound, in bytes, used by the decoder and encoder to reject
    /// over-long lines.
    max_line_length: usize,
}

impl IrcCodec {
    pub fn new() -> Self {
        Self {
            max_line_length: MAX_LINE_LENGTH,
        }
    }

    pub fn with_max_length(max_line_length: usize) -> Self {
        Self { max_line_length }
    }
}

impl Default for IrcCodec {
    fn default() -> Self {
        Self::new()
    }
}

impl Decoder for IrcCodec {
    type Item = IrcMessage;
    type Error = CodecError;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        loop {
            let newline_pos = src.iter().position(|&b| b == b'\n');

            match newline_pos {
                None => {
                    if src.len() > self.max_line_length {
                        return Err(CodecError::Parse(
                            crate::proto::error::ParseError::MessageTooLong,
                        ));
                    }
                    return Ok(None);
                }
                Some(pos) => {
                    let line_bytes = src.split_to(pos + 1);

                    let line = &line_bytes[..line_bytes.len() - 1]; // strip \n
                    let line = if line.last() == Some(&b'\r') {
                        &line[..line.len() - 1] // strip \r
                    } else {
                        line
                    };

                    // Skip empty lines silently
                    if line.is_empty() {
                        continue;
                    }

                    let line_str = std::str::from_utf8(line).map_err(|_| {
                        CodecError::Io(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "IRC message is not valid UTF-8",
                        ))
                    })?;

                    let msg = parse(line_str)?;
                    return Ok(Some(msg));
                }
            }
        }
    }
}

impl Encoder<IrcMessage> for IrcCodec {
    type Error = CodecError;

    fn encode(&mut self, msg: IrcMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let line = serialize(&msg);

        // +2 for \r\n
        if line.len() + 2 > self.max_line_length {
            return Err(CodecError::Parse(
                crate::proto::error::ParseError::MessageTooLong,
            ));
        }

        dst.reserve(line.len() + 2);
        dst.put_slice(line.as_bytes());
        dst.put_slice(b"\r\n");
        Ok(())
    }
}