use std::collections::HashMap;
use std::fmt::Display;
use redis::{Commands, ConnectionLike, ErrorKind, RedisError};
use crate::error::{ClientCreationError, ClientOperationError};
/// r2d2 connection pool whose connections are produced by a `redis::Client`.
pub(crate) type DbClient = r2d2::Pool<redis::Client>;
impl From<RedisError> for ClientOperationError {
    /// Classifies a redis-rs error into the operation-error taxonomy.
    ///
    /// - `IoError`: transport-level failures (`ErrorKind::IoError`).
    /// - `Server`: errors reported by the Redis server or cluster, including
    ///   redirections (`Moved`/`Ask`) and extension errors.
    /// - `Unknown`: everything else, e.g. authentication failures, type
    ///   conversion errors, invalid client config, client-side errors and
    ///   `ReadOnly`.
    ///
    /// NOTE(review): `ReadOnly` presumably originates from the server as well;
    /// it is deliberately kept in `Unknown` here — confirm that is intended.
    fn from(err: RedisError) -> Self {
        let kind = err.kind();
        if kind == ErrorKind::IoError {
            return Self::IoError(err.into());
        }

        let reported_by_server = matches!(
            kind,
            ErrorKind::ResponseError
                | ErrorKind::ExecAbortError
                | ErrorKind::BusyLoadingError
                | ErrorKind::NoScriptError
                | ErrorKind::Moved
                | ErrorKind::Ask
                | ErrorKind::TryAgain
                | ErrorKind::ClusterDown
                | ErrorKind::CrossSlot
                | ErrorKind::MasterDown
                | ErrorKind::ExtensionError
        );

        if reported_by_server {
            Self::Server(err.into())
        } else {
            Self::Unknown(err.into())
        }
    }
}
impl From<RedisError> for ClientCreationError {
    /// Converts a redis-rs error raised while creating a client into a
    /// `ClientCreationError`, carrying the error's display text.
    ///
    /// Only I/O and authentication failures get dedicated variants; every
    /// other error kind (server responses, cluster redirections, type errors,
    /// invalid config, client errors, read-only, ...) becomes `Unknown`.
    fn from(err: RedisError) -> Self {
        let message = err.to_string();
        match err.kind() {
            ErrorKind::IoError => Self::IoError(message),
            ErrorKind::AuthenticationFailed => Self::Authentication(message),
            _ => Self::Unknown(message),
        }
    }
}
impl From<r2d2::Error> for ClientCreationError {
fn from(err: r2d2::Error) -> Self {
Self::IoError(err.to_string())
}
}
impl From<r2d2::Error> for ClientOperationError {
fn from(err: r2d2::Error) -> Self {
Self::IoError(err.into())
}
}
/// Evaluates a Redis operation, re-evaluating the *whole* expression while it
/// fails with `ErrorKind::TryAgain`, for at most 3 evaluations in total.
///
/// The result of the last evaluation is returned: the first non-`TryAgain`
/// result (success or other error), or the third `TryAgain` error.
/// Because the expression is re-evaluated, any `?` inside it (e.g. pool
/// checkout) may return early from the enclosing function on any attempt.
/// No delay is inserted between attempts.
macro_rules! try_again_redis_op {
    ($operation:expr) => {{
        let mut attempts_left: u8 = 3;
        loop {
            attempts_left -= 1;
            match $operation {
                Err(err) if err.kind() == ErrorKind::TryAgain && attempts_left > 0 => continue,
                outcome => break outcome,
            }
        }
    }};
}
/// Pooled Redis client that namespaces every key under an optional prefix.
#[derive(Clone)]
pub struct RedisClient {
    /// Prefix prepended as `"<prefix>:"` to every key; an empty string means
    /// keys are used verbatim.
    key_prefix: String,
    /// Connection pool every command checks a connection out of.
    client: DbClient,
}
impl RedisClient {
    /// Creates a client for `redis_url` backed by an r2d2 connection pool.
    ///
    /// Pool settings: no minimum number of idle connections, idle connections
    /// are closed after 5 minutes, and checking out a connection waits at most
    /// 10 seconds.
    ///
    /// `key_prefix` namespaces every key this client touches as
    /// `"<prefix>:<key>"`; `None` means keys are used verbatim.
    ///
    /// # Errors
    /// Returns [`ClientCreationError`] if the URL cannot be turned into a
    /// `redis::Client` or if the pool cannot be built.
    pub fn new(
        redis_url: impl ToString,
        key_prefix: Option<String>,
    ) -> Result<Self, ClientCreationError> {
        let client = try_again_redis_op! {
            redis::Client::open(redis_url.to_string())
        }?;
        let client = r2d2::Pool::builder()
            .min_idle(Some(0))
            .idle_timeout(Some(std::time::Duration::from_secs(5 * 60)))
            .connection_timeout(std::time::Duration::from_secs(10))
            .build(client)?;
        Ok(RedisClient {
            client,
            key_prefix: key_prefix.unwrap_or_default(),
        })
    }

    /// Returns this client with `prefix_extension` appended to the key prefix,
    /// separated by `:`. The connection pool is shared with the original.
    ///
    /// NOTE(review): when the current prefix is empty this produces a prefix
    /// with a leading colon (`":ext"`), whereas `handle_key_prefix` omits the
    /// separator for an empty prefix. Changing this would move where keys are
    /// stored, so it is left as is. Confirm whether it is intended.
    pub fn with_prefix_ext(mut self, prefix_extension: impl Display) -> Self {
        self.key_prefix = format!("{}:{}", self.key_prefix, prefix_extension);
        self
    }

    /// Applies the configured prefix to `key` (`"<prefix>:<key>"`), or returns
    /// `key` unchanged when no prefix is configured.
    fn handle_key_prefix(&self, key: impl Display) -> String {
        if !self.key_prefix.is_empty() {
            format!("{}:{}", self.key_prefix, key)
        } else {
            key.to_string()
        }
    }

    /// Inverse of `handle_key_prefix`: removes a leading `"<prefix>:"` from
    /// `key`. Returns `key` unchanged if it does not carry that prefix or if no
    /// prefix is configured.
    fn strip_key_prefix<'a>(&self, key: &'a str) -> &'a str {
        if self.key_prefix.is_empty() {
            // Nothing was prepended, so nothing may be removed. In particular,
            // a leading ':' belongs to the key itself.
            return key;
        }
        key.strip_prefix(self.key_prefix.as_str())
            .and_then(|rest| rest.strip_prefix(':'))
            .unwrap_or(key)
    }

    /// Returns all keys matching the glob pattern `query` (with the key prefix
    /// applied). Returned keys still carry the prefix.
    ///
    /// NOTE(review): this uses the Redis `KEYS` command, which scans the whole
    /// keyspace in a single blocking call. Consider `SCAN` for large
    /// databases. The prefix is not glob-escaped, so a prefix containing `*`,
    /// `?` or `[` would widen the match.
    #[tracing::instrument(level = "debug", skip_all)]
    pub fn find_keys(&self, query: impl Display) -> Result<Vec<String>, ClientOperationError> {
        try_again_redis_op! {
            self.client
                .get()?
                .keys(self.handle_key_prefix(&query))
        }
        .map_err(Into::into)
    }

    /// Fetches every string value whose key matches the glob pattern `query`.
    ///
    /// The returned map is keyed by the key with the prefix stripped. A key
    /// whose value comes back nil is skipped instead of failing the whole
    /// call. This happens when the key disappeared between the key lookup and
    /// the value fetch, or when it holds a non-string type.
    #[tracing::instrument(level = "debug", skip_all)]
    pub fn query_get(
        &self,
        query: impl Display,
    ) -> Result<HashMap<String, String>, ClientOperationError> {
        let keys: Vec<String> = self.find_keys(query)?;
        if keys.is_empty() {
            return Ok(HashMap::new());
        }
        // Always MGET so the reply is positional (one entry per key) regardless
        // of how many keys matched, and decode nils as `None` so alignment with
        // `keys` is preserved.
        let values: Vec<Option<String>> = try_again_redis_op! {
            self.client
                .get()?
                .mget(&keys)
        }?;
        Ok(keys
            .iter()
            .zip(values)
            .filter_map(|(key, value)| {
                value.map(|value| (self.strip_key_prefix(key).to_string(), value))
            })
            .collect())
    }

    /// Returns the value stored at `key` (prefix applied), or `None` if the key
    /// does not exist.
    #[tracing::instrument(level = "debug", skip_all)]
    pub fn get(&self, key: impl Display) -> Result<Option<String>, ClientOperationError> {
        let key = self.handle_key_prefix(key);
        try_again_redis_op! {
            self.client
                .get()?
                .get(&key)
        }
        .map_err(Into::into)
    }

    /// Fetches the values of `keys` (prefix applied) in a single `MGET`.
    ///
    /// Missing keys are dropped from the result, so the output is not
    /// positionally aligned with `keys`. An empty `keys` slice returns an empty
    /// vector without contacting Redis, because `MGET` requires at least one
    /// key.
    #[tracing::instrument(level = "debug", skip_all)]
    pub fn get_many<T: Display>(&self, keys: &[T]) -> Result<Vec<String>, ClientOperationError> {
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let keys: Vec<String> = keys.iter().map(|val| self.handle_key_prefix(val)).collect();
        let data: Vec<Option<String>> = try_again_redis_op! {
            self.client
                .get()?
                .mget(&keys)
        }?;
        Ok(data.into_iter().flatten().collect())
    }

    /// Stores `data` under `key` (prefix applied) with `SET`, overwriting any
    /// existing value.
    #[tracing::instrument(level = "debug", skip_all)]
    pub fn set(&self, key: impl Display, data: &str) -> Result<(), ClientOperationError> {
        try_again_redis_op! {
            self.client
                .get()?
                .set(self.handle_key_prefix(&key), data)
        }
        .map_err(Into::into)
    }

    /// Deletes `key` (prefix applied) with `DEL`. The deleted-count reply is
    /// discarded, so deleting a missing key is not an error.
    #[tracing::instrument(level = "debug", skip_all)]
    pub fn delete(&self, key: impl Display) -> Result<(), ClientOperationError> {
        try_again_redis_op! {
            self.client
                .get()?
                .del(self.handle_key_prefix(&key))
        }
        .map_err(Into::into)
    }

    /// Checks out a pooled connection and reports `check_connection` on it.
    ///
    /// Failing to check out a connection yields `Err`, not `Ok(false)`.
    #[tracing::instrument(level = "debug", skip_all)]
    pub fn is_alive(&self) -> Result<bool, ClientOperationError> {
        Ok(self.client.get()?.check_connection())
    }
}