openzeppelin_relayer/repositories/api_key/api_key_in_memory.rs

//! This module provides an in-memory implementation of the API key repository.
//!
//! The `InMemoryApiKeyRepository` struct is used to store and retrieve API keys
//! and their permissions.
5use crate::{
6    models::{ApiKeyRepoModel, PaginationQuery},
7    repositories::{ApiKeyRepositoryTrait, PaginatedResult, RepositoryError},
8};
9
10use async_trait::async_trait;
11
12use std::collections::HashMap;
13use tokio::sync::{Mutex, MutexGuard};
14
15#[derive(Debug)]
16pub struct InMemoryApiKeyRepository {
17    store: Mutex<HashMap<String, ApiKeyRepoModel>>,
18}
19
20impl Clone for InMemoryApiKeyRepository {
21    fn clone(&self) -> Self {
22        // Try to get the current data, or use empty HashMap if lock fails
23        let data = self
24            .store
25            .try_lock()
26            .map(|guard| guard.clone())
27            .unwrap_or_else(|_| HashMap::new());
28
29        Self {
30            store: Mutex::new(data),
31        }
32    }
33}
34
35impl InMemoryApiKeyRepository {
36    pub fn new() -> Self {
37        Self {
38            store: Mutex::new(HashMap::new()),
39        }
40    }
41
42    async fn acquire_lock<T>(lock: &Mutex<T>) -> Result<MutexGuard<T>, RepositoryError> {
43        Ok(lock.lock().await)
44    }
45}
46
47impl Default for InMemoryApiKeyRepository {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53#[async_trait]
54impl ApiKeyRepositoryTrait for InMemoryApiKeyRepository {
55    async fn create(&self, api_key: ApiKeyRepoModel) -> Result<ApiKeyRepoModel, RepositoryError> {
56        let mut store = Self::acquire_lock(&self.store).await?;
57        store.insert(api_key.id.clone(), api_key.clone());
58        Ok(api_key)
59    }
60
61    async fn get_by_id(&self, id: &str) -> Result<Option<ApiKeyRepoModel>, RepositoryError> {
62        let store = Self::acquire_lock(&self.store).await?;
63        Ok(store.get(id).cloned())
64    }
65
66    async fn list_paginated(
67        &self,
68        query: PaginationQuery,
69    ) -> Result<PaginatedResult<ApiKeyRepoModel>, RepositoryError> {
70        let total = self.count().await?;
71        let start = ((query.page - 1) * query.per_page) as usize;
72
73        let items = self
74            .store
75            .lock()
76            .await
77            .values()
78            .skip(start)
79            .take(query.per_page as usize)
80            .cloned()
81            .collect();
82
83        Ok(PaginatedResult {
84            items,
85            total: total as u64,
86            page: query.page,
87            per_page: query.per_page,
88        })
89    }
90
91    async fn count(&self) -> Result<usize, RepositoryError> {
92        let store = self.store.lock().await;
93        Ok(store.len())
94    }
95
96    async fn list_permissions(&self, api_key_id: &str) -> Result<Vec<String>, RepositoryError> {
97        let store = self.store.lock().await;
98        let api_key = store
99            .get(api_key_id)
100            .ok_or(RepositoryError::NotFound(format!(
101                "Api key with id {} not found",
102                api_key_id
103            )))?;
104        Ok(api_key.permissions.clone())
105    }
106
107    async fn delete_by_id(&self, api_key_id: &str) -> Result<(), RepositoryError> {
108        let mut store = self.store.lock().await;
109        store.remove(api_key_id);
110        Ok(())
111    }
112
113    async fn has_entries(&self) -> Result<bool, RepositoryError> {
114        let store = Self::acquire_lock(&self.store).await?;
115        Ok(!store.is_empty())
116    }
117
118    async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
119        let mut store = Self::acquire_lock(&self.store).await?;
120        store.clear();
121        Ok(())
122    }
123}
124
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use chrono::Utc;

    use crate::models::SecretString;

    use super::*;

    /// Builds an API key with the given `id` and `permissions`; every other field
    /// gets a fixed test value.
    fn make_api_key(id: &str, permissions: &[&str]) -> ApiKeyRepoModel {
        ApiKeyRepoModel {
            id: id.to_string(),
            value: SecretString::new("test-value"),
            name: "test-name".to_string(),
            allowed_origins: vec!["*".to_string()],
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
            created_at: Utc::now().to_string(),
        }
    }

    /// Creates a repository pre-populated with one key per id in `ids`.
    async fn repo_with_keys(ids: &[&str]) -> InMemoryApiKeyRepository {
        let repo = InMemoryApiKeyRepository::new();
        for id in ids {
            repo.create(make_api_key(id, &["relayer:all:execute"]))
                .await
                .unwrap();
        }
        repo
    }

    #[tokio::test]
    async fn test_in_memory_api_key_repository() {
        let repo = InMemoryApiKeyRepository::new();
        let api_key = make_api_key("test-api-key", &["relayer:all:execute"]);

        // `create` hands back the stored model unchanged.
        let created = repo.create(api_key.clone()).await.unwrap();
        assert_eq!(created, api_key);
        assert_eq!(repo.get_by_id("test-api-key").await.unwrap(), Some(api_key));
    }

    #[tokio::test]
    async fn test_create_overwrites_existing_id() {
        let repo = InMemoryApiKeyRepository::new();
        repo.create(make_api_key("test-api-key", &["relayer:all:execute"]))
            .await
            .unwrap();
        let replacement = make_api_key("test-api-key", &["relayer:all:read"]);
        repo.create(replacement.clone()).await.unwrap();

        assert_eq!(repo.count().await.unwrap(), 1);
        assert_eq!(
            repo.get_by_id("test-api-key").await.unwrap(),
            Some(replacement)
        );
    }

    #[tokio::test]
    async fn test_get_nonexistent_api_key() {
        let repo = InMemoryApiKeyRepository::new();

        let result = repo.get_by_id("test-api-key").await;
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn test_get_by_id() {
        let repo = InMemoryApiKeyRepository::new();
        let first = make_api_key("test-api-key1", &["relayer:all:execute"]);
        let second = make_api_key("test-api-key2", &["relayer:all:read"]);
        repo.create(first.clone()).await.unwrap();
        repo.create(second.clone()).await.unwrap();

        assert_eq!(repo.get_by_id("test-api-key1").await.unwrap(), Some(first));
        assert_eq!(repo.get_by_id("test-api-key2").await.unwrap(), Some(second));
    }

    #[tokio::test]
    async fn test_list_paginated_api_keys() {
        let repo = repo_with_keys(&["test-api-key1", "test-api-key2", "test-api-key3"]).await;

        let first_page = repo
            .list_paginated(PaginationQuery {
                page: 1,
                per_page: 2,
            })
            .await
            .unwrap();
        assert_eq!(first_page.items.len(), 2);
        assert_eq!(first_page.total, 3);
        assert_eq!(first_page.page, 1);
        assert_eq!(first_page.per_page, 2);

        let second_page = repo
            .list_paginated(PaginationQuery {
                page: 2,
                per_page: 2,
            })
            .await
            .unwrap();
        assert_eq!(second_page.items.len(), 1);
        assert_eq!(second_page.total, 3);

        // The two pages are disjoint and together cover every stored key.
        let seen: HashSet<String> = first_page
            .items
            .iter()
            .chain(second_page.items.iter())
            .map(|key| key.id.clone())
            .collect();
        assert_eq!(seen.len(), 3);

        let past_end = repo
            .list_paginated(PaginationQuery {
                page: 3,
                per_page: 2,
            })
            .await
            .unwrap();
        assert!(past_end.items.is_empty());
        assert_eq!(past_end.total, 3);
    }

    #[tokio::test]
    async fn test_count() {
        let empty = InMemoryApiKeyRepository::new();
        assert_eq!(empty.count().await.unwrap(), 0);

        let repo = repo_with_keys(&["test-api-key1", "test-api-key2"]).await;
        assert_eq!(repo.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_has_entries() {
        let repo = InMemoryApiKeyRepository::new();
        assert!(!repo.has_entries().await.unwrap());

        repo.create(make_api_key("test-api-key", &["relayer:all:execute"]))
            .await
            .unwrap();
        assert!(repo.has_entries().await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_by_id_api_key() {
        let repo = repo_with_keys(&["test-api-key"]).await;
        assert!(repo.has_entries().await.unwrap());

        repo.delete_by_id("test-api-key").await.unwrap();
        assert!(!repo.has_entries().await.unwrap());

        // Deleting a key that is not stored is not an error.
        repo.delete_by_id("test-api-key").await.unwrap();
    }

    #[tokio::test]
    async fn test_list_permissions_api_key() {
        let repo = InMemoryApiKeyRepository::new();
        repo.create(make_api_key(
            "test-api-key",
            &["relayer:all:execute", "relayer:all:read"],
        ))
        .await
        .unwrap();

        let permissions = repo.list_permissions("test-api-key").await.unwrap();
        assert_eq!(permissions, vec!["relayer:all:execute", "relayer:all:read"]);
    }

    #[tokio::test]
    async fn test_list_permissions_not_found() {
        let repo = InMemoryApiKeyRepository::new();

        let result = repo.list_permissions("missing-api-key").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_drop_all_entries() {
        let repo = repo_with_keys(&["test-api-key1", "test-api-key2"]).await;
        assert!(repo.has_entries().await.unwrap());

        repo.drop_all_entries().await.unwrap();
        assert!(!repo.has_entries().await.unwrap());
        assert_eq!(repo.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_clone_snapshots_entries() {
        let repo = repo_with_keys(&["test-api-key"]).await;
        let cloned = repo.clone();
        assert!(cloned.get_by_id("test-api-key").await.unwrap().is_some());

        // The clone owns its own store: clearing it leaves the original intact.
        cloned.drop_all_entries().await.unwrap();
        assert!(!cloned.has_entries().await.unwrap());
        assert!(repo.has_entries().await.unwrap());
    }
}