openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_redis.rs

1//! Redis implementation of the transaction counter.
2//!
3//! This module provides a Redis-based implementation of the `TransactionCounterTrait`,
4//! allowing transaction counters to be stored and retrieved from a Redis database.
5//! The implementation includes comprehensive error handling, logging, and atomic operations
6//! to ensure consistency when incrementing and decrementing counters.
7
8use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use async_trait::async_trait;
12use redis::aio::ConnectionManager;
13use redis::AsyncCommands;
14use std::fmt;
15use std::sync::Arc;
16use tracing::debug;
17
/// Fixed namespace segment placed between the configured key prefix and the
/// relayer/address pair in every counter key (see `counter_key`).
const COUNTER_PREFIX: &str = "transaction_counter";
19
/// Transaction counter repository that keeps one integer per
/// (relayer, address) pair in Redis.
#[derive(Clone)]
pub struct RedisTransactionCounter {
    /// Shared Redis connection manager; each operation clones it to obtain a
    /// mutable handle for issuing commands.
    pub client: Arc<ConnectionManager>,
    /// First segment of every key written by this repository. `new` rejects
    /// an empty prefix.
    pub key_prefix: String,
}
25
// Empty impl: nothing is overridden here. The `map_redis_error` helper that
// the counter methods call on `self` is presumably a provided method of
// `RedisRepository` — NOTE(review): confirm against `redis_base`.
impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("RedisTransactionCounter")
31            .field("key_prefix", &self.key_prefix)
32            .finish()
33    }
34}
35
36impl RedisTransactionCounter {
37    pub fn new(
38        connection_manager: Arc<ConnectionManager>,
39        key_prefix: String,
40    ) -> Result<Self, RepositoryError> {
41        if key_prefix.is_empty() {
42            return Err(RepositoryError::InvalidData(
43                "Redis key prefix cannot be empty".to_string(),
44            ));
45        }
46
47        Ok(Self {
48            client: connection_manager,
49            key_prefix,
50        })
51    }
52
53    /// Generate key for transaction counter: {prefix}:transaction_counter:{relayer_id}:{address}
54    fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55        format!(
56            "{}:{}:{}:{}",
57            self.key_prefix, COUNTER_PREFIX, relayer_id, address
58        )
59    }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64    async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65        if relayer_id.is_empty() {
66            return Err(RepositoryError::InvalidData(
67                "Relayer ID cannot be empty".to_string(),
68            ));
69        }
70
71        if address.is_empty() {
72            return Err(RepositoryError::InvalidData(
73                "Address cannot be empty".to_string(),
74            ));
75        }
76
77        let key = self.counter_key(relayer_id, address);
78        debug!(relayer_id = %relayer_id, address = %address, "getting counter for relayer and address");
79
80        let mut conn = self.client.as_ref().clone();
81
82        let value: Option<u64> = conn
83            .get(&key)
84            .await
85            .map_err(|e| self.map_redis_error(e, "get_counter"))?;
86
87        debug!(value = ?value, "retrieved counter value");
88        Ok(value)
89    }
90
91    async fn get_and_increment(
92        &self,
93        relayer_id: &str,
94        address: &str,
95    ) -> Result<u64, RepositoryError> {
96        if relayer_id.is_empty() {
97            return Err(RepositoryError::InvalidData(
98                "Relayer ID cannot be empty".to_string(),
99            ));
100        }
101
102        if address.is_empty() {
103            return Err(RepositoryError::InvalidData(
104                "Address cannot be empty".to_string(),
105            ));
106        }
107
108        let key = self.counter_key(relayer_id, address);
109        debug!(relayer_id = %relayer_id, address = %address, "getting and incrementing counter for relayer and address");
110
111        let mut conn = self.client.as_ref().clone();
112
113        // Use Redis INCR for atomic increment
114        let new_value: u64 = conn
115            .incr(&key, 1)
116            .await
117            .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
118
119        let current = new_value.saturating_sub(1);
120
121        debug!(from = %current, to = %(current + 1), "counter incremented");
122        Ok(current)
123    }
124
125    async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
126        if relayer_id.is_empty() {
127            return Err(RepositoryError::InvalidData(
128                "Relayer ID cannot be empty".to_string(),
129            ));
130        }
131
132        if address.is_empty() {
133            return Err(RepositoryError::InvalidData(
134                "Address cannot be empty".to_string(),
135            ));
136        }
137
138        let key = self.counter_key(relayer_id, address);
139        debug!(relayer_id = %relayer_id, address = %address, "decrementing counter for relayer and address");
140
141        let mut conn = self.client.as_ref().clone();
142
143        // Check if counter exists first
144        let exists: bool = conn
145            .exists(&key)
146            .await
147            .map_err(|e| self.map_redis_error(e, "check_counter_exists"))?;
148
149        if !exists {
150            return Err(RepositoryError::NotFound(format!(
151                "Counter not found for relayer {} and address {}",
152                relayer_id, address
153            )));
154        }
155
156        // Use Redis DECR and correct if it goes below 0
157        let new_value: i64 = conn
158            .decr(&key, 1)
159            .await
160            .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
161
162        let new_value = if new_value < 0 {
163            // Correct negative values back to 0
164            let _: () = conn
165                .set(&key, 0)
166                .await
167                .map_err(|e| self.map_redis_error(e, "correct_negative_counter"))?;
168            0u64
169        } else {
170            new_value as u64
171        };
172
173        debug!(new_value = %new_value, "counter decremented");
174        Ok(new_value)
175    }
176
177    async fn set(
178        &self,
179        relayer_id: &str,
180        address: &str,
181        value: u64,
182    ) -> Result<(), RepositoryError> {
183        if relayer_id.is_empty() {
184            return Err(RepositoryError::InvalidData(
185                "Relayer ID cannot be empty".to_string(),
186            ));
187        }
188
189        if address.is_empty() {
190            return Err(RepositoryError::InvalidData(
191                "Address cannot be empty".to_string(),
192            ));
193        }
194
195        let key = self.counter_key(relayer_id, address);
196        debug!(relayer_id = %relayer_id, address = %address, value = %value, "setting counter for relayer and address");
197
198        let mut conn = self.client.as_ref().clone();
199
200        let _: () = conn
201            .set(&key, value)
202            .await
203            .map_err(|e| self.map_redis_error(e, "set_counter"))?;
204
205        debug!(value = %value, "counter set");
206        Ok(())
207    }
208}
209
#[cfg(test)]
mod tests {
    //! Integration tests for `RedisTransactionCounter`.
    //!
    //! Every test is `#[ignore]`d because it needs a reachable Redis server
    //! (taken from `REDIS_URL`, defaulting to `redis://127.0.0.1:6379`).
    //! Tests that write to Redis use random identifiers so they never share
    //! keys with each other or with leftovers from earlier runs.

    use super::*;
    use redis::aio::ConnectionManager;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Connects to Redis and builds a repository under the `test_counter` prefix.
    async fn setup_test_repo() -> RedisTransactionCounter {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let client = redis::Client::open(redis_url).expect("Failed to create Redis client");
        let connection_manager = ConnectionManager::new(client)
            .await
            .expect("Failed to create Redis connection manager");

        RedisTransactionCounter::new(Arc::new(connection_manager), "test_counter".to_string())
            .expect("Failed to create Redis transaction counter")
    }

    /// Returns a fresh random identifier for use as a relayer ID or address.
    fn unique_id() -> String {
        Uuid::new_v4().to_string()
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let random_id = unique_id();
        let result = repo.get(&random_id, "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();
        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // First increment should return 0 and set to 1
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        // Second increment should return 1 and set to 2
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // Set initial value
        repo.set(&relayer_id, &address, 5).await.unwrap();

        // Decrement should return 4
        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_clamps_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 0).await.unwrap();

        // Decrementing a zero counter must report 0 and leave 0 stored.
        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        // A random relayer ID guarantees the key is absent even on a shared Redis.
        let relayer_id = unique_id();
        let result = repo.decrement(&relayer_id, "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));

        // A failed decrement must not create the counter as a side effect.
        let current = repo.get(&relayer_id, "0x1234").await.unwrap();
        assert_eq!(current, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        // Test empty relayer_id
        let result = repo.get("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        // Test empty address
        let result = repo.get("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        // The mutating operations apply the same validation.
        let result = repo.get_and_increment("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.decrement("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.set("", "0x1234", 1).await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        // Set different values for different relayer/address combinations
        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        // Verify independent counters
        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        // Verify independent increments
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // Set initial value
        repo.set(&relayer_id, &address, 100).await.unwrap();

        // Create multiple concurrent tasks that increment the counter
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(
                    async move { repo.get_and_increment(&relayer_id, &address).await.unwrap() },
                )
            })
            .collect();

        // Wait for all tasks to complete and collect results
        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        // Sort results to check they are sequential
        results.sort_unstable();

        // Verify we get exactly the values 100-109 (no duplicates, no gaps)
        let expected: Vec<u64> = (100..110).collect();
        assert_eq!(results, expected);

        // Verify final value is 110
        let final_value = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(final_value, Some(110));
    }
}