Merge pull request #20 from bartvdbraak/refactor/opt-error-handling

Control concurrency and improve error handling
This commit is contained in:
Bart van der Braak 2023-11-10 01:20:00 +01:00 committed by GitHub
commit 78b1216915
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 75 deletions

1
Cargo.lock generated
View file

@ -854,6 +854,7 @@ dependencies = [
name = "keyweave" name = "keyweave"
version = "0.2.2" version = "0.2.2"
dependencies = [ dependencies = [
"anyhow",
"azure_identity", "azure_identity",
"azure_security_keyvault", "azure_security_keyvault",
"clap", "clap",

View file

@ -10,6 +10,7 @@ documentation = "https://docs.rs/keyweave"
repository = "https://github.com/bartvdbraak/keyweave/" repository = "https://github.com/bartvdbraak/keyweave/"
[dependencies] [dependencies]
anyhow = "1.0.75"
azure_identity = "0.17.0" azure_identity = "0.17.0"
azure_security_keyvault = "0.17.0" azure_security_keyvault = "0.17.0"
clap = { version = "4.4.7", features = ["derive"] } clap = { version = "4.4.7", features = ["derive"] }
@ -17,4 +18,4 @@ futures = "0.3.29"
tokio = {version = "1.34.0", features = ["full"]} tokio = {version = "1.34.0", features = ["full"]}
[target.'cfg(all(target_os = "linux", any(target_env = "musl", target_arch = "arm", target_arch = "aarch64")))'.dependencies] [target.'cfg(all(target_os = "linux", any(target_env = "musl", target_arch = "arm", target_arch = "aarch64")))'.dependencies]
openssl = { version = "0.10", features = ["vendored"] } openssl = { version = "0.10", features = ["vendored"] }

View file

@ -1,121 +1,134 @@
use anyhow::{Context, Result};
use azure_identity::DefaultAzureCredential; use azure_identity::DefaultAzureCredential;
use azure_security_keyvault::prelude::KeyVaultGetSecretsResponse;
use azure_security_keyvault::KeyvaultClient; use azure_security_keyvault::KeyvaultClient;
use clap::Parser; use clap::Parser;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::Semaphore;
#[derive(Parser)] #[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)] #[clap(author, version, about, long_about = None)]
struct Opts { struct Opts {
#[clap( /// Sets the name of the Azure Key Vault
short, #[clap(short, long, value_name = "VAULT_NAME")]
long,
value_name = "VAULT_NAME",
help = "Sets the name of the Azure Key Vault"
)]
vault_name: String, vault_name: String,
#[clap( /// Sets the name of the output file
short, #[clap(short, long, value_name = "FILE", default_value = ".env")]
long,
value_name = "FILE",
default_value = ".env",
help = "Sets the name of the output file"
)]
output: String, output: String,
#[clap( /// Filters the secrets to be retrieved by name
short, #[clap(short, long, value_name = "FILTER")]
long,
value_name = "FILTER",
help = "Filters the secrets to be retrieved by name"
)]
filter: Option<String>, filter: Option<String>,
} }
async fn fetch_secrets_from_key_vault( async fn fetch_secrets_from_key_vault(
vault_url: &str, client: &KeyvaultClient,
filter: Option<&str>, filter: Option<&str>,
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> { ) -> Result<Vec<(String, String)>> {
let credential = DefaultAzureCredential::default();
let client = KeyvaultClient::new(vault_url, std::sync::Arc::new(credential))?.secret_client();
let mut secret_values = Vec::new(); let mut secret_values = Vec::new();
let mut secret_pages = client.list_secrets().into_stream(); let mut secret_pages = client.secret_client().list_secrets().into_stream();
while let Some(page) = secret_pages.next().await { while let Some(page) = secret_pages.next().await {
let page = page?; let page = page.context("Failed to fetch secrets page")?;
let (tx, mut rx) = mpsc::channel(32); // Channel for concurrent secret retrieval secret_values
.extend(fetch_secrets_from_page(&client.secret_client(), &page, filter).await?);
for secret in &page.value {
if let Some(filter) = filter {
if !secret.id.contains(filter) {
continue;
}
}
let tx = tx.clone();
// Clone necessary data before moving into the spawned task
let secret_id = secret.id.clone();
let client_clone = client.clone();
tokio::spawn(async move {
let secret_name = secret_id.split('/').last().unwrap_or_default();
let secret_bundle = client_clone.get(secret_name).await;
// Handle the result and send it through the channel
match secret_bundle {
Ok(bundle) => {
tx.send((secret_id, bundle.value))
.await
.expect("Send error");
}
Err(err) => {
eprintln!("Error fetching secret: {}", err);
// You can decide to continue or not in case of an error.
}
}
});
}
drop(tx); // Drop the sender to signal the end of sending tasks
while let Some(result) = rx.recv().await {
let (key, value) = result;
secret_values.push((key, value));
}
} }
Ok(secret_values) Ok(secret_values)
} }
fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> std::io::Result<()> { async fn fetch_secrets_from_page(
let mut file = File::create(output_file)?; client: &azure_security_keyvault::SecretClient,
page: &KeyVaultGetSecretsResponse,
filter: Option<&str>,
) -> Result<Vec<(String, String)>> {
let (tx, mut rx) = mpsc::channel(32);
let semaphore = Arc::new(Semaphore::new(10));
let mut handles = Vec::new();
for secret in &page.value {
if let Some(filter) = filter {
if !secret.id.contains(filter) {
continue;
}
}
let permit = semaphore.clone().acquire_owned().await.unwrap();
let tx = tx.clone();
let secret_id = secret.id.clone();
let client_clone = client.clone();
handles.push(tokio::spawn(async move {
let _permit = permit;
fetch_and_send_secret(client_clone, secret_id, tx).await
}));
}
drop(tx);
let mut secrets = Vec::new();
for handle in handles {
if let Ok(result) = handle.await {
secrets.push(result);
}
}
while let Some(result) = rx.recv().await {
secrets.push(result);
}
Ok(secrets)
}
async fn fetch_and_send_secret(
client: azure_security_keyvault::SecretClient,
secret_id: String,
tx: mpsc::Sender<(String, String)>,
) -> (String, String) {
let secret_name = secret_id.split('/').last().unwrap_or_default();
match client.get(secret_name).await {
Ok(bundle) => {
let _ = tx.send((secret_id.clone(), bundle.value.clone())).await;
(secret_id, bundle.value)
}
Err(err) => {
eprintln!("Error fetching secret: {}", err);
(secret_id, String::new())
}
}
}
fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> Result<()> {
let mut file = File::create(output_file).context("Failed to create output file")?;
for (key, value) in secrets { for (key, value) in secrets {
// Extract the secret name from the URL
if let Some(secret_name) = key.split('/').last() { if let Some(secret_name) = key.split('/').last() {
writeln!(file, "{}={}", secret_name, value)?; writeln!(file, "{}={}", secret_name, value)
.with_context(|| format!("Failed to write to output file: {}", output_file))?;
} }
} }
Ok(()) Ok(())
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<()> {
let opts: Opts = Opts::parse(); let opts: Opts = Opts::parse();
let vault_url = format!("https://{}.vault.azure.net", opts.vault_name); let vault_url = format!("https://{}.vault.azure.net", opts.vault_name);
println!("Fetching secrets from Key Vault: {}", opts.vault_name); println!("Fetching secrets from Key Vault: {}", opts.vault_name);
let secrets = fetch_secrets_from_key_vault(&vault_url, opts.filter.as_deref()).await?; let credential = DefaultAzureCredential::default();
let client = KeyvaultClient::new(&vault_url, std::sync::Arc::new(credential))
.context("Failed to create KeyvaultClient")?;
let secrets = fetch_secrets_from_key_vault(&client, opts.filter.as_deref()).await?;
println!("Creating output file: {}", opts.output); println!("Creating output file: {}", opts.output);
create_env_file(secrets, &opts.output)?; create_env_file(secrets, &opts.output)?;
println!("Process completed successfully!"); println!("Process completed successfully!");
Ok(()) Ok(())
} }