diff options
Diffstat (limited to 'src/net/packet.rs')
| -rw-r--r-- | src/net/packet.rs | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/src/net/packet.rs b/src/net/packet.rs new file mode 100644 index 0000000..2d97504 --- /dev/null +++ b/src/net/packet.rs @@ -0,0 +1,73 @@ +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use std::io; +use std::mem; + +#[derive(Debug, Deserialize, Serialize)] +pub(super) enum Packet<P: 'static + Send + Debug> { + Cargo(P), + Disconnect, +} + +#[derive(Debug, thiserror::Error)] +pub enum PacketRwError { + #[error("packet could not be properly deserialised: {0}")] + DeserialiseError(bincode::Error), + #[error("packet could not be properly serialised: {0}")] + SerialiseError(bincode::Error), + #[error("unable to read packet from stream: {0}")] + IOError(io::Error), + #[error("connection was closed from the remote end")] + Closed, +} + +impl<P: 'static + Send + Debug + DeserializeOwned + Serialize> Packet<P> { + pub fn write_to_stream(&self, stream: &mut impl io::Write) -> Result<(), PacketRwError> { + let data: Vec<u8> = + bincode::serialize(&self).map_err(|err| PacketRwError::SerialiseError(err))?; + + // Write head with packet length + assert!(data.len() as u64 <= u32::MAX as u64); + let len = data.len() as u32; + let len = bincode::serialize(&len).map_err(|err| PacketRwError::SerialiseError(err))?; + stream + .write_all(&len) + .map_err(|err| PacketRwError::IOError(err))?; + + // Write the data of the packet and pray all errors are caught. + Ok(stream + .write_all(&data) + .map_err(|err| PacketRwError::IOError(err))?) + } + + pub fn read_from_stream(stream: &mut impl io::Read) -> Result<Self, PacketRwError> { + // Read packet head which informs us of the length. + let mut len = vec![0; mem::size_of::<u32>()]; + stream.read_exact(&mut len).map_err(|err| { + if err.kind() == io::ErrorKind::UnexpectedEof { + PacketRwError::Closed + } else { + PacketRwError::IOError(err) + } + })?; + let len: u32 = bincode::deserialize(&len) + .expect("Unable to deserialise length of packet. Stream is corrupted."); + + // Read all data from the packet according to the length. + let mut data = vec![0; len as usize]; + match stream.read_exact(&mut data) { + Ok(()) => { + let res: Result<Self, bincode::Error> = bincode::deserialize(&data); + Ok(res.map_err(|err| PacketRwError::DeserialiseError(err))?) + } + Err(err) => { + if err.kind() == io::ErrorKind::UnexpectedEof { + Err(PacketRwError::Closed) + } else { + Err(PacketRwError::IOError(err)) + } + } + } + } +} |
