From b9656bf118da7f054dbab26276ecfdb858a22aa6 Mon Sep 17 00:00:00 2001 From: Bart van der Braak Date: Fri, 10 Nov 2023 01:14:19 +0100 Subject: [PATCH] refactor: optimize threads and error handling --- Cargo.lock | 1 + Cargo.toml | 3 +- src/main.rs | 161 ++++++++++++++++++++++++++++------------------------ 3 files changed, 90 insertions(+), 75 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index de438f3..9aaee82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -854,6 +854,7 @@ dependencies = [ name = "keyweave" version = "0.2.2" dependencies = [ + "anyhow", "azure_identity", "azure_security_keyvault", "clap", diff --git a/Cargo.toml b/Cargo.toml index 6cdcb7f..ad8a2ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ documentation = "https://docs.rs/keyweave" repository = "https://github.com/bartvdbraak/keyweave/" [dependencies] +anyhow = "1.0.75" azure_identity = "0.17.0" azure_security_keyvault = "0.17.0" clap = { version = "4.4.7", features = ["derive"] } @@ -17,4 +18,4 @@ futures = "0.3.29" tokio = {version = "1.34.0", features = ["full"]} [target.'cfg(all(target_os = "linux", any(target_env = "musl", target_arch = "arm", target_arch = "aarch64")))'.dependencies] -openssl = { version = "0.10", features = ["vendored"] } \ No newline at end of file +openssl = { version = "0.10", features = ["vendored"] } diff --git a/src/main.rs b/src/main.rs index 6579715..c7268e7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,121 +1,134 @@ +use anyhow::{Context, Result}; use azure_identity::DefaultAzureCredential; +use azure_security_keyvault::prelude::KeyVaultGetSecretsResponse; use azure_security_keyvault::KeyvaultClient; use clap::Parser; use futures::stream::StreamExt; use std::fs::File; use std::io::Write; +use std::sync::Arc; use tokio::sync::mpsc; +use tokio::sync::Semaphore; -#[derive(Parser)] +#[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Opts { - #[clap( - short, - long, - value_name = "VAULT_NAME", - help = "Sets the name of the Azure Key Vault" - )] + /// Sets the name of the Azure Key Vault + #[clap(short, long, value_name = "VAULT_NAME")] vault_name: String, - #[clap( - short, - long, - value_name = "FILE", - default_value = ".env", - help = "Sets the name of the output file" - )] + /// Sets the name of the output file + #[clap(short, long, value_name = "FILE", default_value = ".env")] output: String, - #[clap( - short, - long, - value_name = "FILTER", - help = "Filters the secrets to be retrieved by name" - )] + /// Filters the secrets to be retrieved by name + #[clap(short, long, value_name = "FILTER")] filter: Option, } async fn fetch_secrets_from_key_vault( - vault_url: &str, + client: &KeyvaultClient, filter: Option<&str>, -) -> Result, Box> { - let credential = DefaultAzureCredential::default(); - let client = KeyvaultClient::new(vault_url, std::sync::Arc::new(credential))?.secret_client(); - +) -> Result> { let mut secret_values = Vec::new(); - let mut secret_pages = client.list_secrets().into_stream(); + let mut secret_pages = client.secret_client().list_secrets().into_stream(); while let Some(page) = secret_pages.next().await { - let page = page?; - let (tx, mut rx) = mpsc::channel(32); // Channel for concurrent secret retrieval - - for secret in &page.value { - if let Some(filter) = filter { - if !secret.id.contains(filter) { - continue; - } - } - let tx = tx.clone(); - - // Clone necessary data before moving into the spawned task - let secret_id = secret.id.clone(); - let client_clone = client.clone(); - - tokio::spawn(async move { - let secret_name = secret_id.split('/').last().unwrap_or_default(); - let secret_bundle = client_clone.get(secret_name).await; - - // Handle the result and send it through the channel - match secret_bundle { - Ok(bundle) => { - tx.send((secret_id, bundle.value)) - .await - .expect("Send error"); - } - Err(err) => { - eprintln!("Error fetching secret: {}", err); - // You can decide to continue or not in case of an error. - } - } - }); - } - - drop(tx); // Drop the sender to signal the end of sending tasks - - while let Some(result) = rx.recv().await { - let (key, value) = result; - secret_values.push((key, value)); - } + let page = page.context("Failed to fetch secrets page")?; + secret_values + .extend(fetch_secrets_from_page(&client.secret_client(), &page, filter).await?); } Ok(secret_values) } -fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> std::io::Result<()> { - let mut file = File::create(output_file)?; +async fn fetch_secrets_from_page( + client: &azure_security_keyvault::SecretClient, + page: &KeyVaultGetSecretsResponse, + filter: Option<&str>, +) -> Result> { + let (tx, mut rx) = mpsc::channel(32); + let semaphore = Arc::new(Semaphore::new(10)); + let mut handles = Vec::new(); + + for secret in &page.value { + if let Some(filter) = filter { + if !secret.id.contains(filter) { + continue; + } + } + + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let tx = tx.clone(); + let secret_id = secret.id.clone(); + let client_clone = client.clone(); + + handles.push(tokio::spawn(async move { + let _permit = permit; + fetch_and_send_secret(client_clone, secret_id, tx).await + })); + } + + drop(tx); + + let mut secrets = Vec::new(); + for handle in handles { + if let Ok(result) = handle.await { + secrets.push(result); + } + } + + while let Some(result) = rx.recv().await { + secrets.push(result); + } + + Ok(secrets) +} + +async fn fetch_and_send_secret( + client: azure_security_keyvault::SecretClient, + secret_id: String, + tx: mpsc::Sender<(String, String)>, +) -> (String, String) { + let secret_name = secret_id.split('/').last().unwrap_or_default(); + match client.get(secret_name).await { + Ok(bundle) => { + let _ = tx.send((secret_id.clone(), bundle.value.clone())).await; + (secret_id, bundle.value) + } + Err(err) => { + eprintln!("Error fetching secret: {}", err); + (secret_id, String::new()) + } + } +} + +fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> Result<()> { + let mut file = File::create(output_file).context("Failed to create output file")?; for (key, value) in secrets { - // Extract the secret name from the URL if let Some(secret_name) = key.split('/').last() { - writeln!(file, "{}={}", secret_name, value)?; + writeln!(file, "{}={}", secret_name, value) + .with_context(|| format!("Failed to write to output file: {}", output_file))?; } } Ok(()) } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<()> { let opts: Opts = Opts::parse(); let vault_url = format!("https://{}.vault.azure.net", opts.vault_name); - println!("Fetching secrets from Key Vault: {}", opts.vault_name); - let secrets = fetch_secrets_from_key_vault(&vault_url, opts.filter.as_deref()).await?; + let credential = DefaultAzureCredential::default(); + let client = KeyvaultClient::new(&vault_url, std::sync::Arc::new(credential)) + .context("Failed to create KeyvaultClient")?; + let secrets = fetch_secrets_from_key_vault(&client, opts.filter.as_deref()).await?; println!("Creating output file: {}", opts.output); create_env_file(secrets, &opts.output)?; println!("Process completed successfully!"); - Ok(()) }