mirror of
https://github.com/bartvdbraak/keyweave.git
synced 2025-04-28 07:11:21 +00:00
Merge pull request #20 from bartvdbraak/refactor/opt-error-handling
Control concurrency and improve error handling
This commit is contained in:
commit
78b1216915
3 changed files with 90 additions and 75 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -854,6 +854,7 @@ dependencies = [
|
||||||
name = "keyweave"
|
name = "keyweave"
|
||||||
version = "0.2.2"
|
version = "0.2.2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"azure_identity",
|
"azure_identity",
|
||||||
"azure_security_keyvault",
|
"azure_security_keyvault",
|
||||||
"clap",
|
"clap",
|
||||||
|
|
|
@ -10,6 +10,7 @@ documentation = "https://docs.rs/keyweave"
|
||||||
repository = "https://github.com/bartvdbraak/keyweave/"
|
repository = "https://github.com/bartvdbraak/keyweave/"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1.0.75"
|
||||||
azure_identity = "0.17.0"
|
azure_identity = "0.17.0"
|
||||||
azure_security_keyvault = "0.17.0"
|
azure_security_keyvault = "0.17.0"
|
||||||
clap = { version = "4.4.7", features = ["derive"] }
|
clap = { version = "4.4.7", features = ["derive"] }
|
||||||
|
@ -17,4 +18,4 @@ futures = "0.3.29"
|
||||||
tokio = {version = "1.34.0", features = ["full"]}
|
tokio = {version = "1.34.0", features = ["full"]}
|
||||||
|
|
||||||
[target.'cfg(all(target_os = "linux", any(target_env = "musl", target_arch = "arm", target_arch = "aarch64")))'.dependencies]
|
[target.'cfg(all(target_os = "linux", any(target_env = "musl", target_arch = "arm", target_arch = "aarch64")))'.dependencies]
|
||||||
openssl = { version = "0.10", features = ["vendored"] }
|
openssl = { version = "0.10", features = ["vendored"] }
|
||||||
|
|
161
src/main.rs
161
src/main.rs
|
@ -1,121 +1,134 @@
|
||||||
|
use anyhow::{Context, Result};
|
||||||
use azure_identity::DefaultAzureCredential;
|
use azure_identity::DefaultAzureCredential;
|
||||||
|
use azure_security_keyvault::prelude::KeyVaultGetSecretsResponse;
|
||||||
use azure_security_keyvault::KeyvaultClient;
|
use azure_security_keyvault::KeyvaultClient;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use std::sync::Arc;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug)]
|
||||||
#[clap(author, version, about, long_about = None)]
|
#[clap(author, version, about, long_about = None)]
|
||||||
struct Opts {
|
struct Opts {
|
||||||
#[clap(
|
/// Sets the name of the Azure Key Vault
|
||||||
short,
|
#[clap(short, long, value_name = "VAULT_NAME")]
|
||||||
long,
|
|
||||||
value_name = "VAULT_NAME",
|
|
||||||
help = "Sets the name of the Azure Key Vault"
|
|
||||||
)]
|
|
||||||
vault_name: String,
|
vault_name: String,
|
||||||
|
|
||||||
#[clap(
|
/// Sets the name of the output file
|
||||||
short,
|
#[clap(short, long, value_name = "FILE", default_value = ".env")]
|
||||||
long,
|
|
||||||
value_name = "FILE",
|
|
||||||
default_value = ".env",
|
|
||||||
help = "Sets the name of the output file"
|
|
||||||
)]
|
|
||||||
output: String,
|
output: String,
|
||||||
|
|
||||||
#[clap(
|
/// Filters the secrets to be retrieved by name
|
||||||
short,
|
#[clap(short, long, value_name = "FILTER")]
|
||||||
long,
|
|
||||||
value_name = "FILTER",
|
|
||||||
help = "Filters the secrets to be retrieved by name"
|
|
||||||
)]
|
|
||||||
filter: Option<String>,
|
filter: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn fetch_secrets_from_key_vault(
|
async fn fetch_secrets_from_key_vault(
|
||||||
vault_url: &str,
|
client: &KeyvaultClient,
|
||||||
filter: Option<&str>,
|
filter: Option<&str>,
|
||||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> {
|
) -> Result<Vec<(String, String)>> {
|
||||||
let credential = DefaultAzureCredential::default();
|
|
||||||
let client = KeyvaultClient::new(vault_url, std::sync::Arc::new(credential))?.secret_client();
|
|
||||||
|
|
||||||
let mut secret_values = Vec::new();
|
let mut secret_values = Vec::new();
|
||||||
let mut secret_pages = client.list_secrets().into_stream();
|
let mut secret_pages = client.secret_client().list_secrets().into_stream();
|
||||||
|
|
||||||
while let Some(page) = secret_pages.next().await {
|
while let Some(page) = secret_pages.next().await {
|
||||||
let page = page?;
|
let page = page.context("Failed to fetch secrets page")?;
|
||||||
let (tx, mut rx) = mpsc::channel(32); // Channel for concurrent secret retrieval
|
secret_values
|
||||||
|
.extend(fetch_secrets_from_page(&client.secret_client(), &page, filter).await?);
|
||||||
for secret in &page.value {
|
|
||||||
if let Some(filter) = filter {
|
|
||||||
if !secret.id.contains(filter) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let tx = tx.clone();
|
|
||||||
|
|
||||||
// Clone necessary data before moving into the spawned task
|
|
||||||
let secret_id = secret.id.clone();
|
|
||||||
let client_clone = client.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let secret_name = secret_id.split('/').last().unwrap_or_default();
|
|
||||||
let secret_bundle = client_clone.get(secret_name).await;
|
|
||||||
|
|
||||||
// Handle the result and send it through the channel
|
|
||||||
match secret_bundle {
|
|
||||||
Ok(bundle) => {
|
|
||||||
tx.send((secret_id, bundle.value))
|
|
||||||
.await
|
|
||||||
.expect("Send error");
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
eprintln!("Error fetching secret: {}", err);
|
|
||||||
// You can decide to continue or not in case of an error.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
drop(tx); // Drop the sender to signal the end of sending tasks
|
|
||||||
|
|
||||||
while let Some(result) = rx.recv().await {
|
|
||||||
let (key, value) = result;
|
|
||||||
secret_values.push((key, value));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(secret_values)
|
Ok(secret_values)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> std::io::Result<()> {
|
async fn fetch_secrets_from_page(
|
||||||
let mut file = File::create(output_file)?;
|
client: &azure_security_keyvault::SecretClient,
|
||||||
|
page: &KeyVaultGetSecretsResponse,
|
||||||
|
filter: Option<&str>,
|
||||||
|
) -> Result<Vec<(String, String)>> {
|
||||||
|
let (tx, mut rx) = mpsc::channel(32);
|
||||||
|
let semaphore = Arc::new(Semaphore::new(10));
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
|
||||||
|
for secret in &page.value {
|
||||||
|
if let Some(filter) = filter {
|
||||||
|
if !secret.id.contains(filter) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let permit = semaphore.clone().acquire_owned().await.unwrap();
|
||||||
|
let tx = tx.clone();
|
||||||
|
let secret_id = secret.id.clone();
|
||||||
|
let client_clone = client.clone();
|
||||||
|
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
let _permit = permit;
|
||||||
|
fetch_and_send_secret(client_clone, secret_id, tx).await
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(tx);
|
||||||
|
|
||||||
|
let mut secrets = Vec::new();
|
||||||
|
for handle in handles {
|
||||||
|
if let Ok(result) = handle.await {
|
||||||
|
secrets.push(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(result) = rx.recv().await {
|
||||||
|
secrets.push(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(secrets)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_and_send_secret(
|
||||||
|
client: azure_security_keyvault::SecretClient,
|
||||||
|
secret_id: String,
|
||||||
|
tx: mpsc::Sender<(String, String)>,
|
||||||
|
) -> (String, String) {
|
||||||
|
let secret_name = secret_id.split('/').last().unwrap_or_default();
|
||||||
|
match client.get(secret_name).await {
|
||||||
|
Ok(bundle) => {
|
||||||
|
let _ = tx.send((secret_id.clone(), bundle.value.clone())).await;
|
||||||
|
(secret_id, bundle.value)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("Error fetching secret: {}", err);
|
||||||
|
(secret_id, String::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_env_file(secrets: Vec<(String, String)>, output_file: &str) -> Result<()> {
|
||||||
|
let mut file = File::create(output_file).context("Failed to create output file")?;
|
||||||
for (key, value) in secrets {
|
for (key, value) in secrets {
|
||||||
// Extract the secret name from the URL
|
|
||||||
if let Some(secret_name) = key.split('/').last() {
|
if let Some(secret_name) = key.split('/').last() {
|
||||||
writeln!(file, "{}={}", secret_name, value)?;
|
writeln!(file, "{}={}", secret_name, value)
|
||||||
|
.with_context(|| format!("Failed to write to output file: {}", output_file))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<()> {
|
||||||
let opts: Opts = Opts::parse();
|
let opts: Opts = Opts::parse();
|
||||||
|
|
||||||
let vault_url = format!("https://{}.vault.azure.net", opts.vault_name);
|
let vault_url = format!("https://{}.vault.azure.net", opts.vault_name);
|
||||||
|
|
||||||
println!("Fetching secrets from Key Vault: {}", opts.vault_name);
|
println!("Fetching secrets from Key Vault: {}", opts.vault_name);
|
||||||
|
|
||||||
let secrets = fetch_secrets_from_key_vault(&vault_url, opts.filter.as_deref()).await?;
|
let credential = DefaultAzureCredential::default();
|
||||||
|
let client = KeyvaultClient::new(&vault_url, std::sync::Arc::new(credential))
|
||||||
|
.context("Failed to create KeyvaultClient")?;
|
||||||
|
|
||||||
|
let secrets = fetch_secrets_from_key_vault(&client, opts.filter.as_deref()).await?;
|
||||||
println!("Creating output file: {}", opts.output);
|
println!("Creating output file: {}", opts.output);
|
||||||
create_env_file(secrets, &opts.output)?;
|
create_env_file(secrets, &opts.output)?;
|
||||||
|
|
||||||
println!("Process completed successfully!");
|
println!("Process completed successfully!");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue