openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_redis.rs1use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use async_trait::async_trait;
12use redis::aio::ConnectionManager;
13use redis::AsyncCommands;
14use std::fmt;
15use std::sync::Arc;
16use tracing::debug;
17
/// Fixed key segment placed between the configurable key prefix and the
/// relayer/address identifiers of every counter key.
const COUNTER_PREFIX: &str = "transaction_counter";
19
20#[derive(Clone)]
21pub struct RedisTransactionCounter {
22 pub client: Arc<ConnectionManager>,
23 pub key_prefix: String,
24}
25
26impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 f.debug_struct("RedisTransactionCounter")
31 .field("key_prefix", &self.key_prefix)
32 .finish()
33 }
34}
35
36impl RedisTransactionCounter {
37 pub fn new(
38 connection_manager: Arc<ConnectionManager>,
39 key_prefix: String,
40 ) -> Result<Self, RepositoryError> {
41 if key_prefix.is_empty() {
42 return Err(RepositoryError::InvalidData(
43 "Redis key prefix cannot be empty".to_string(),
44 ));
45 }
46
47 Ok(Self {
48 client: connection_manager,
49 key_prefix,
50 })
51 }
52
53 fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55 format!(
56 "{}:{}:{}:{}",
57 self.key_prefix, COUNTER_PREFIX, relayer_id, address
58 )
59 }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65 if relayer_id.is_empty() {
66 return Err(RepositoryError::InvalidData(
67 "Relayer ID cannot be empty".to_string(),
68 ));
69 }
70
71 if address.is_empty() {
72 return Err(RepositoryError::InvalidData(
73 "Address cannot be empty".to_string(),
74 ));
75 }
76
77 let key = self.counter_key(relayer_id, address);
78 debug!(relayer_id = %relayer_id, address = %address, "getting counter for relayer and address");
79
80 let mut conn = self.client.as_ref().clone();
81
82 let value: Option<u64> = conn
83 .get(&key)
84 .await
85 .map_err(|e| self.map_redis_error(e, "get_counter"))?;
86
87 debug!(value = ?value, "retrieved counter value");
88 Ok(value)
89 }
90
91 async fn get_and_increment(
92 &self,
93 relayer_id: &str,
94 address: &str,
95 ) -> Result<u64, RepositoryError> {
96 if relayer_id.is_empty() {
97 return Err(RepositoryError::InvalidData(
98 "Relayer ID cannot be empty".to_string(),
99 ));
100 }
101
102 if address.is_empty() {
103 return Err(RepositoryError::InvalidData(
104 "Address cannot be empty".to_string(),
105 ));
106 }
107
108 let key = self.counter_key(relayer_id, address);
109 debug!(relayer_id = %relayer_id, address = %address, "getting and incrementing counter for relayer and address");
110
111 let mut conn = self.client.as_ref().clone();
112
113 let new_value: u64 = conn
115 .incr(&key, 1)
116 .await
117 .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
118
119 let current = new_value.saturating_sub(1);
120
121 debug!(from = %current, to = %(current + 1), "counter incremented");
122 Ok(current)
123 }
124
125 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
126 if relayer_id.is_empty() {
127 return Err(RepositoryError::InvalidData(
128 "Relayer ID cannot be empty".to_string(),
129 ));
130 }
131
132 if address.is_empty() {
133 return Err(RepositoryError::InvalidData(
134 "Address cannot be empty".to_string(),
135 ));
136 }
137
138 let key = self.counter_key(relayer_id, address);
139 debug!(relayer_id = %relayer_id, address = %address, "decrementing counter for relayer and address");
140
141 let mut conn = self.client.as_ref().clone();
142
143 let exists: bool = conn
145 .exists(&key)
146 .await
147 .map_err(|e| self.map_redis_error(e, "check_counter_exists"))?;
148
149 if !exists {
150 return Err(RepositoryError::NotFound(format!(
151 "Counter not found for relayer {} and address {}",
152 relayer_id, address
153 )));
154 }
155
156 let new_value: i64 = conn
158 .decr(&key, 1)
159 .await
160 .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
161
162 let new_value = if new_value < 0 {
163 let _: () = conn
165 .set(&key, 0)
166 .await
167 .map_err(|e| self.map_redis_error(e, "correct_negative_counter"))?;
168 0u64
169 } else {
170 new_value as u64
171 };
172
173 debug!(new_value = %new_value, "counter decremented");
174 Ok(new_value)
175 }
176
177 async fn set(
178 &self,
179 relayer_id: &str,
180 address: &str,
181 value: u64,
182 ) -> Result<(), RepositoryError> {
183 if relayer_id.is_empty() {
184 return Err(RepositoryError::InvalidData(
185 "Relayer ID cannot be empty".to_string(),
186 ));
187 }
188
189 if address.is_empty() {
190 return Err(RepositoryError::InvalidData(
191 "Address cannot be empty".to_string(),
192 ));
193 }
194
195 let key = self.counter_key(relayer_id, address);
196 debug!(relayer_id = %relayer_id, address = %address, value = %value, "setting counter for relayer and address");
197
198 let mut conn = self.client.as_ref().clone();
199
200 let _: () = conn
201 .set(&key, value)
202 .await
203 .map_err(|e| self.map_redis_error(e, "set_counter"))?;
204
205 debug!(value = %value, "counter set");
206 Ok(())
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use redis::aio::ConnectionManager;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Connects to Redis at `REDIS_URL`, falling back to `redis://127.0.0.1:6379`.
    async fn connect() -> Arc<ConnectionManager> {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let client = redis::Client::open(redis_url).expect("Failed to create Redis client");
        let connection_manager = ConnectionManager::new(client)
            .await
            .expect("Failed to create Redis connection manager");
        Arc::new(connection_manager)
    }

    async fn setup_test_repo() -> RedisTransactionCounter {
        RedisTransactionCounter::new(connect().await, "test_counter".to_string())
            .expect("Failed to create Redis transaction counter")
    }

    /// Random identifier so tests (and repeated runs) never share Redis keys.
    fn unique_id() -> String {
        Uuid::new_v4().to_string()
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_new_rejects_empty_prefix() {
        let result = RedisTransactionCounter::new(connect().await, String::new());
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let result = repo.get(&unique_id(), "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();
        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 5).await.unwrap();

        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_floors_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 0).await.unwrap();

        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        let result = repo.decrement(&relayer_id, &address).await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));

        // A failed decrement must not create the counter as a side effect.
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        let result = repo.get("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.get("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.get_and_increment("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.get_and_increment("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.decrement("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.decrement("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.set("", "0x1234", 1).await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.set("relayer", "", 1).await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(
                    async move { repo.get_and_increment(&relayer_id, &address).await.unwrap() },
                )
            })
            .collect();

        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        results.sort();

        let expected: Vec<u64> = (100..110).collect();
        assert_eq!(results, expected);

        let final_value = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(final_value, Some(110));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_decrement_floors_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 5).await.unwrap();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(async move { repo.decrement(&relayer_id, &address).await.unwrap() })
            })
            .collect();

        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        results.sort();

        assert_eq!(results, vec![0, 0, 0, 0, 0, 0, 1, 2, 3, 4]);
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), Some(0));
    }
}