Merge pull request #20 from bartvdbraak/refactor/opt-error-handling

Control concurrency and improve error handling
This commit is contained in:
Bart van der Braak 2023-11-10 01:20:00 +01:00 committed by GitHub
commit 78b1216915
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 75 deletions

1
Cargo.lock generated
View file

@ -854,6 +854,7 @@ dependencies = [
name = "keyweave"
version = "0.2.2"
dependencies = [
"anyhow",
"azure_identity",
"azure_security_keyvault",
"clap",

View file

@ -10,6 +10,7 @@ documentation = "https://docs.rs/keyweave"
repository = "https://github.com/bartvdbraak/keyweave/"
[dependencies]
anyhow = "1.0.75"
azure_identity = "0.17.0"
azure_security_keyvault = "0.17.0"
clap = { version = "4.4.7", features = ["derive"] }

View file

@ -1,53 +1,55 @@
use anyhow::{Context, Result};
use azure_identity::DefaultAzureCredential;
use azure_security_keyvault::prelude::KeyVaultGetSecretsResponse;
use azure_security_keyvault::KeyvaultClient;
use clap::Parser;
use futures::stream::StreamExt;
use std::fs::File;
use std::io::Write;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Semaphore;
#[derive(Parser)]
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Opts {
#[clap(
short,
long,
value_name = "VAULT_NAME",
help = "Sets the name of the Azure Key Vault"
)]
/// Sets the name of the Azure Key Vault
#[clap(short, long, value_name = "VAULT_NAME")]
vault_name: String,
#[clap(
short,
long,
value_name = "FILE",
default_value = ".env",
help = "Sets the name of the output file"
)]
/// Sets the name of the output file
#[clap(short, long, value_name = "FILE", default_value = ".env")]
output: String,
#[clap(
short,
long,
value_name = "FILTER",
help = "Filters the secrets to be retrieved by name"
)]
/// Filters the secrets to be retrieved by name
#[clap(short, long, value_name = "FILTER")]
filter: Option<String>,
}
async fn fetch_secrets_from_key_vault(
vault_url: &str,
client: &KeyvaultClient,
filter: Option<&str>,
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> {
let credential = DefaultAzureCredential::default();
let client = KeyvaultClient::new(vault_url, std::sync::Arc::new(credential))?.secret_client();
) -> Result<Vec<(String, String)>> {
let mut secret_values = Vec::new();
let mut secret_pages = client.list_secrets().into_stream();
let mut secret_pages = client.secret_client().list_secrets().into_stream();
while let Some(page) = secret_pages.next().await {
let page = page?;
let (tx, mut rx) = mpsc::channel(32); // Channel for concurrent secret retrieval
let page = page.context("Failed to fetch secrets page")?;
secret_values
.extend(fetch_secrets_from_page(&client.secret_client(), &page, filter).await?);
}
Ok(secret_values)
}
async fn fetch_secrets_from_page(
client: &azure_security_keyvault::SecretClient,
page: &KeyVaultGetSecretsResponse,
filter: Option<&str>,
) -> Result<Vec<(String, String)>> {
let (tx, mut rx) = mpsc::channel(32);
let semaphore = Arc::new(Semaphore::new(10));
let mut handles = Vec::new();
for secret in &page.value {
if let Some(filter) = filter {
@ -55,67 +57,78 @@ async fn fetch_secrets_from_key_vault(
continue;
}
}
let tx = tx.clone();
// Clone necessary data before moving into the spawned task
let permit = semaphore.clone().acquire_owned().await.unwrap();
let tx = tx.clone();
let secret_id = secret.id.clone();
let client_clone = client.clone();
tokio::spawn(async move {
let secret_name = secret_id.split('/').last().unwrap_or_default();
let secret_bundle = client_clone.get(secret_name).await;
handles.push(tokio::spawn(async move {
let _permit = permit;
fetch_and_send_secret(client_clone, secret_id, tx).await
}));
}
// Handle the result and send it through the channel
match secret_bundle {
drop(tx);
let mut secrets = Vec::new();
for handle in handles {
if let Ok(result) = handle.await {
secrets.push(result);
}
}
while let Some(result) = rx.recv().await {
secrets.push(result);
}
Ok(secrets)
}
async fn fetch_and_send_secret(
client: azure_security_keyvault::SecretClient,
secret_id: String,
tx: mpsc::Sender<(String, String)>,
) -> (String, String) {
let secret_name = secret_id.split('/').last().unwrap_or_default();
match client.get(secret_name).await {
Ok(bundle) => {
tx.send((secret_id, bundle.value))
.await
.expect("Send error");
let _ = tx.send((secret_id.clone(), bundle.value.clone())).await;
(secret_id, bundle.value)
}
Err(err) => {
eprintln!("Error fetching secret: {}", err);
// You can decide to continue or not in case of an error.
(secret_id, String::new())
}
}
});
}
drop(tx); // Drop the sender to signal the end of sending tasks
while let Some(result) = rx.recv().await {
let (key, value) = result;
secret_values.push((key, value));
}
}
Ok(secret_values)
}
fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> std::io::Result<()> {
let mut file = File::create(output_file)?;
fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> Result<()> {
let mut file = File::create(output_file).context("Failed to create output file")?;
for (key, value) in secrets {
// Extract the secret name from the URL
if let Some(secret_name) = key.split('/').last() {
writeln!(file, "{}={}", secret_name, value)?;
writeln!(file, "{}={}", secret_name, value)
.with_context(|| format!("Failed to write to output file: {}", output_file))?;
}
}
Ok(())
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
async fn main() -> Result<()> {
let opts: Opts = Opts::parse();
let vault_url = format!("https://{}.vault.azure.net", opts.vault_name);
println!("Fetching secrets from Key Vault: {}", opts.vault_name);
let secrets = fetch_secrets_from_key_vault(&vault_url, opts.filter.as_deref()).await?;
let credential = DefaultAzureCredential::default();
let client = KeyvaultClient::new(&vault_url, std::sync::Arc::new(credential))
.context("Failed to create KeyvaultClient")?;
let secrets = fetch_secrets_from_key_vault(&client, opts.filter.as_deref()).await?;
println!("Creating output file: {}", opts.output);
create_env_file(secrets, &opts.output)?;
println!("Process completed successfully!");
Ok(())
}