1use async_trait::async_trait;
24use core::fmt;
25use once_cell::sync::Lazy;
26use serde::Serialize;
27use std::collections::HashMap;
28use std::hash::Hash;
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31use thiserror::Error;
32use tokio::sync::RwLock;
33use tracing::debug;
34use vaultrs::{
35 auth::approle::login,
36 client::{VaultClient, VaultClientSettingsBuilder},
37 kv2, transit,
38};
39use zeroize::{Zeroize, ZeroizeOnDrop};
40
/// Errors produced by the Vault integration in this module.
///
/// Every variant carries a free-form message; the `thiserror`-derived `Display` prefixes it with
/// the category given in the variant's `#[error]` attribute. `Serialize` is derived as well.
#[derive(Error, Debug, Serialize)]
pub enum VaultError {
    /// A Vault request failed after authentication (used for KV v2 read failures in this module).
    #[error("Vault client error: {0}")]
    ClientError(String),

    /// The secret was read but its `value` field is missing or not a string.
    #[error("Secret not found: {0}")]
    SecretNotFound(String),

    /// The AppRole login request failed.
    #[error("Authentication failed: {0}")]
    AuthenticationFailed(String),

    /// Vault client settings could not be built, or a client could not be created from them.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// The transit sign request failed.
    #[error("Signing error: {0}")]
    SigningError(String),
}
58
/// Identity under which an authenticated client is stored in `TOKEN_CACHE`.
///
/// Only address, role id and namespace participate; `secret_id`, `mount_path` and `token_ttl`
/// are not part of the key, so configurations differing only in those share one cached client.
/// `ZeroizeOnDrop` wipes the field contents when a key is dropped.
///
/// NOTE(review): the derived `Debug` prints `role_id` in clear text — avoid logging this key if
/// role ids are treated as sensitive.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Zeroize, ZeroizeOnDrop)]
struct VaultCacheKey {
    /// Vault server address.
    address: String,
    /// AppRole role id used for the login.
    role_id: String,
    /// Optional Vault namespace.
    namespace: Option<String>,
}
66
67impl fmt::Display for VaultCacheKey {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 write!(
70 f,
71 "{}|{}|{}",
72 self.address,
73 self.role_id,
74 self.namespace.as_deref().unwrap_or("")
75 )
76 }
77}
78
/// A cached authenticated client together with the instant after which it is considered stale.
struct TokenCache {
    /// Client configured with the token obtained from the AppRole login.
    client: Arc<VaultClient>,
    /// `get_client` reuses `client` only while `Instant::now()` is before this point.
    expiry: Instant,
}
83
84static TOKEN_CACHE: Lazy<RwLock<HashMap<VaultCacheKey, TokenCache>>> =
86 Lazy::new(|| RwLock::new(HashMap::new()));
87
88#[cfg(test)]
89use mockall::automock;
90
91use crate::models::SecretString;
92use crate::utils::base64_encode;
93
/// Connection and AppRole credentials for a Vault instance.
#[derive(Clone)]
pub struct VaultConfig {
    /// Base URL of the Vault server.
    pub address: String,
    /// Optional Vault namespace applied to every request.
    pub namespace: Option<String>,
    /// AppRole role id used for login (also part of the client cache key).
    pub role_id: SecretString,
    /// AppRole secret id used for login (not part of the client cache key).
    pub secret_id: SecretString,
    /// Mount path used for both KV v2 reads and transit signing.
    pub mount_path: String,
    /// How long, in seconds, an authenticated client may be reused from the cache before a new
    /// login; `None` means 45 minutes.
    pub token_ttl: Option<u64>,
}
104
105impl VaultConfig {
106 pub fn new(
107 address: String,
108 role_id: SecretString,
109 secret_id: SecretString,
110 namespace: Option<String>,
111 mount_path: String,
112 token_ttl: Option<u64>,
113 ) -> Self {
114 Self {
115 address,
116 role_id,
117 secret_id,
118 namespace,
119 mount_path,
120 token_ttl,
121 }
122 }
123
124 fn cache_key(&self) -> VaultCacheKey {
125 VaultCacheKey {
126 address: self.address.clone(),
127 role_id: self.role_id.to_str().to_string(),
128 namespace: self.namespace.clone(),
129 }
130 }
131}
132
/// Vault operations needed by the rest of the application.
///
/// Under `cfg(test)`, `mockall::automock` also generates a mock implementation of this trait.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait VaultServiceTrait: Send + Sync {
    /// Reads the secret stored under `key_name` and returns its string `value` field.
    async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError>;
    /// Signs `message` with the key `key_name` and returns the signature string from Vault.
    async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError>;
}
139
/// `VaultServiceTrait` implementation backed by a real Vault server.
///
/// Holds only configuration; authenticated clients live in the process-wide `TOKEN_CACHE`.
#[derive(Clone)]
pub struct VaultService {
    /// Connection, credentials and mount settings for this service.
    pub config: VaultConfig,
}
144
145impl VaultService {
146 pub fn new(config: VaultConfig) -> Self {
147 Self { config }
148 }
149
150 async fn get_client(&self) -> Result<Arc<VaultClient>, VaultError> {
152 let cache_key = self.config.cache_key();
153
154 {
156 let cache = TOKEN_CACHE.read().await;
157 if let Some(cached) = cache.get(&cache_key) {
158 if Instant::now() < cached.expiry {
159 return Ok(Arc::clone(&cached.client));
160 }
161 }
162 }
163
164 let mut cache = TOKEN_CACHE.write().await;
166 if let Some(cached) = cache.get(&cache_key) {
168 if Instant::now() < cached.expiry {
169 return Ok(Arc::clone(&cached.client));
170 }
171 }
172
173 let client = self.create_authenticated_client().await?;
175
176 let ttl = Duration::from_secs(self.config.token_ttl.unwrap_or(45 * 60));
178
179 cache.insert(
181 cache_key,
182 TokenCache {
183 client: client.clone(),
184 expiry: Instant::now() + ttl,
185 },
186 );
187
188 Ok(client)
189 }
190
191 async fn create_authenticated_client(&self) -> Result<Arc<VaultClient>, VaultError> {
193 let mut auth_settings_builder = VaultClientSettingsBuilder::default();
194 let address = &self.config.address;
195 auth_settings_builder.address(address).verify(true);
196
197 if let Some(namespace) = &self.config.namespace {
198 auth_settings_builder.namespace(Some(namespace.clone()));
199 }
200
201 let auth_settings = auth_settings_builder.build().map_err(|e| {
202 VaultError::ConfigError(format!("Failed to build Vault client settings: {}", e))
203 })?;
204
205 let client = VaultClient::new(auth_settings).map_err(|e| {
206 VaultError::ConfigError(format!("Failed to create Vault client: {}", e))
207 })?;
208
209 let token = login(
210 &client,
211 "approle",
212 &self.config.role_id.to_str(),
213 &self.config.secret_id.to_str(),
214 )
215 .await
216 .map_err(|e| VaultError::AuthenticationFailed(e.to_string()))?;
217
218 let mut transit_settings_builder = VaultClientSettingsBuilder::default();
219
220 transit_settings_builder
221 .address(self.config.address.clone())
222 .token(token.client_token.clone())
223 .verify(true);
224
225 if let Some(namespace) = &self.config.namespace {
226 transit_settings_builder.namespace(Some(namespace.clone()));
227 }
228
229 let transit_settings = transit_settings_builder.build().map_err(|e| {
230 VaultError::ConfigError(format!("Failed to build Vault client settings: {}", e))
231 })?;
232
233 let client = Arc::new(VaultClient::new(transit_settings).map_err(|e| {
234 VaultError::ConfigError(format!(
235 "Failed to create authenticated Vault client: {}",
236 e
237 ))
238 })?);
239
240 Ok(client)
241 }
242}
243
244#[async_trait]
245impl VaultServiceTrait for VaultService {
246 async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError> {
247 let client = self.get_client().await?;
248
249 let secret: serde_json::Value = kv2::read(&*client, &self.config.mount_path, key_name)
250 .await
251 .map_err(|e| VaultError::ClientError(e.to_string()))?;
252
253 let value = secret["value"]
254 .as_str()
255 .ok_or_else(|| {
256 VaultError::SecretNotFound(format!("Secret value invalid for key: {}", key_name))
257 })?
258 .to_string();
259
260 Ok(value)
261 }
262
263 async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError> {
264 let client = self.get_client().await?;
265
266 let vault_signature = transit::data::sign(
267 &*client,
268 &self.config.mount_path,
269 key_name,
270 &base64_encode(message),
271 None,
272 )
273 .await
274 .map_err(|e| VaultError::SigningError(format!("Failed to sign with Vault: {}", e)))?;
275
276 let vault_signature_str = &vault_signature.signature;
277
278 debug!(vault_signature_str = %vault_signature_str, "vault signature string");
279
280 Ok(vault_signature_str.clone())
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Credentials accepted by the login mock and the token it hands back.
    const ROLE_ID: &str = "test-role-id";
    const SECRET_ID: &str = "test-secret-id";
    const TOKEN: &str = "test-token";

    /// Builds a service against `url` with the given AppRole credentials, no namespace,
    /// mount `test-mount` and a 60-second cache TTL — the shape every service-level test uses.
    fn service_for(url: String, role_id: &str, secret_id: &str) -> VaultService {
        VaultService::new(VaultConfig::new(
            url,
            SecretString::new(role_id),
            SecretString::new(secret_id),
            None,
            "test-mount".to_string(),
            Some(60),
        ))
    }

    /// JSON body of a successful KV v2 read whose secret has `value = "super-secret-value"`.
    fn kv_secret_body() -> String {
        let body = json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": {
                "data": {
                    "value": "super-secret-value"
                },
                "metadata": {
                    "created_time": "2024-01-01T00:00:00Z",
                    "deletion_time": "",
                    "destroyed": false,
                    "version": 1
                }
            },
            "wrap_info": null,
            "warnings": null,
            "auth": null
        });
        serde_json::to_string(&body).unwrap()
    }

    /// JSON body of a Vault error response carrying a single message.
    fn vault_error_body(message: &str) -> String {
        serde_json::to_string(&json!({ "errors": [message] })).unwrap()
    }

    #[test]
    fn test_vault_config_new() {
        let config = VaultConfig::new(
            "https://vault.example.com".to_string(),
            SecretString::new("test-role-id"),
            SecretString::new("test-secret-id"),
            Some("test-namespace".to_string()),
            "test-mount-path".to_string(),
            Some(60),
        );

        assert_eq!(config.address, "https://vault.example.com");
        assert_eq!(config.role_id.to_str().as_str(), "test-role-id");
        assert_eq!(config.secret_id.to_str().as_str(), "test-secret-id");
        assert_eq!(config.namespace.as_deref(), Some("test-namespace"));
        assert_eq!(config.mount_path, "test-mount-path");
        assert_eq!(config.token_ttl, Some(60));
    }

    #[test]
    fn test_vault_cache_key() {
        let config_at = |address: &str, mount_path: &str| VaultConfig {
            address: address.to_string(),
            namespace: Some("namespace1".to_string()),
            role_id: SecretString::new("role1"),
            secret_id: SecretString::new("secret1"),
            mount_path: mount_path.to_string(),
            token_ttl: None,
        };

        let base = config_at("https://vault1.example.com", "transit");
        let other_mount = config_at("https://vault1.example.com", "different-mount");
        let other_address = config_at("https://vault2.example.com", "transit");

        // Deterministic, ignores mount_path, but distinguishes server addresses.
        assert_eq!(base.cache_key(), base.cache_key());
        assert_eq!(base.cache_key(), other_mount.cache_key());
        assert_ne!(base.cache_key(), other_address.cache_key());
    }

    #[test]
    fn test_vault_cache_key_display() {
        let key_for = |namespace: Option<&str>| VaultCacheKey {
            address: "https://vault.example.com".to_string(),
            role_id: "role-123".to_string(),
            namespace: namespace.map(str::to_string),
        };

        assert_eq!(
            key_for(Some("my-namespace")).to_string(),
            "https://vault.example.com|role-123|my-namespace"
        );
        assert_eq!(
            key_for(None).to_string(),
            "https://vault.example.com|role-123|"
        );
    }

    /// Registers an AppRole login mock that only accepts `role_id`/`secret_id` and returns
    /// `token` as the client token.
    async fn setup_mock_approle_login(
        mock_server: &mut mockito::ServerGuard,
        role_id: &str,
        secret_id: &str,
        token: &str,
    ) -> mockito::Mock {
        let expected_request = json!({
            "role_id": role_id,
            "secret_id": secret_id
        });
        let response = json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": null,
            "wrap_info": null,
            "warnings": null,
            "auth": {
                "client_token": token,
                "accessor": "test-accessor",
                "policies": ["default"],
                "token_policies": ["default"],
                "metadata": {
                    "role_name": "test-role"
                },
                "lease_duration": 3600,
                "renewable": true,
                "entity_id": "test-entity-id",
                "token_type": "service",
                "orphan": true
            }
        });

        mock_server
            .mock("POST", "/v1/auth/approle/login")
            .match_body(mockito::Matcher::Json(expected_request))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(serde_json::to_string(&response).unwrap())
            .create_async()
            .await
    }

    #[tokio::test]
    async fn test_vault_service_auth_failure() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock =
            setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;

        let _secret_mock = mock_server
            .mock("GET", "/v1/test-mount/data/my-secret")
            .match_header("X-Vault-Token", TOKEN)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(kv_secret_body())
            .create_async()
            .await;

        // Credentials that the login mock does not accept.
        let vault_service = service_for(
            mock_server.url(),
            "test-role-id-fake",
            "test-secret-id-fake",
        );

        let err = vault_service
            .retrieve_secret("my-secret")
            .await
            .expect_err("login with unknown credentials must fail");
        assert!(matches!(err, VaultError::AuthenticationFailed(_)));
        assert!(err.to_string().contains("An error occurred with the request"));
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_success() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock =
            setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;

        let _secret_mock = mock_server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v1/test-mount/data/my-secret.*".to_string()),
            )
            .match_header("X-Vault-Token", TOKEN)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(kv_secret_body())
            .create_async()
            .await;

        let vault_service = service_for(mock_server.url(), ROLE_ID, SECRET_ID);

        let secret = vault_service.retrieve_secret("my-secret").await.unwrap();
        assert_eq!(secret, "super-secret-value");
    }

    #[tokio::test]
    async fn test_vault_service_sign_success() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock =
            setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;

        let message = b"hello world";
        let encoded_message = base64_encode(message);
        let response = json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": {
                "signature": "vault:v1:fake-signature",
                "key_version": 1
            },
            "wrap_info": null,
            "warnings": null,
            "auth": null
        });

        let _sign_mock = mock_server
            .mock("POST", "/v1/test-mount/sign/my-signing-key")
            .match_header("X-Vault-Token", TOKEN)
            .match_body(mockito::Matcher::Json(json!({
                "input": encoded_message
            })))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(serde_json::to_string(&response).unwrap())
            .create_async()
            .await;

        let vault_service = service_for(mock_server.url(), ROLE_ID, SECRET_ID);

        let signature = vault_service.sign("my-signing-key", message).await.unwrap();
        assert_eq!(signature, "vault:v1:fake-signature");
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_failure() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock =
            setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;

        let _secret_mock = mock_server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v1/test-mount/data/my-secret.*".to_string()),
            )
            .match_header("X-Vault-Token", TOKEN)
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(vault_error_body("secret not found:"))
            .create_async()
            .await;

        let vault_service = service_for(mock_server.url(), ROLE_ID, SECRET_ID);

        let err = vault_service
            .retrieve_secret("my-secret")
            .await
            .expect_err("a 404 from KV must surface as an error");
        assert!(matches!(err, VaultError::ClientError(_)));
        assert!(err
            .to_string()
            .contains("The Vault server returned an error (status code 404)"));
    }

    #[tokio::test]
    async fn test_vault_service_sign_failure() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock =
            setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;

        let message = b"hello world";
        let encoded_message = base64_encode(message);

        let _sign_mock = mock_server
            .mock("POST", "/v1/test-mount/sign/my-signing-key")
            .match_header("X-Vault-Token", TOKEN)
            .match_body(mockito::Matcher::Json(json!({
                "input": encoded_message
            })))
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(vault_error_body("1 error occurred:\n\t* signing key not found"))
            .create_async()
            .await;

        let vault_service = service_for(mock_server.url(), ROLE_ID, SECRET_ID);

        let err = vault_service
            .sign("my-signing-key", message)
            .await
            .expect_err("a 400 from transit must surface as an error");
        assert!(matches!(err, VaultError::SigningError(_)));
        assert!(err.to_string().contains("Failed to sign with Vault"));
    }
}