use std::{error::Error, path::PathBuf, sync::Arc, time::Duration}; use base64::{engine::general_purpose, Engine}; use colored::Colorize; use futures::{lock::Mutex, stream, StreamExt}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use md5::{Digest, Md5}; use reqwest::{header::USER_AGENT, Client}; use tokio::{fs::File, io::AsyncWriteExt, time::timeout}; use crate::UA; const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(30); const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(60); async fn write_file(path: &PathBuf, bytes: &[u8]) -> std::io::Result { let mut file = File::create(path).await?; let bw = file.write(bytes).await?; file.flush().await?; Ok(bw) } pub async fn concurrent_dl( images: Vec<(String, PathBuf, String)>, display_url: String, ) -> Result<(usize, usize), Box> { let multi = MultiProgress::new(); let bar = multi.add(ProgressBar::new(images.len() as u64)); bar.set_style( ProgressStyle::with_template("{spinner} {msg} [{wide_bar:.white/gray}] [{pos}/{len}]") .unwrap(), ); bar.set_message(display_url); let error_bar = multi.add(ProgressBar::new(0)); error_bar.set_style(ProgressStyle::with_template("{wide_msg}").unwrap()); let dl_count = Arc::new(Mutex::new(0)); let sk_count = Arc::new(Mutex::new(0)); let client = Client::builder() .pool_idle_timeout(KEEP_ALIVE_TIMEOUT) .build()?; let futures = stream::iter(images.iter().map(|data| async { let dl_count = Arc::clone(&dl_count); let _sk_count = Arc::clone(&sk_count); let client = client.clone(); let (url, path, _expct_md5) = data; let download_result = timeout(DOWNLOAD_TIMEOUT, async { let res = client.get(url).header(USER_AGENT, UA).send().await?; let bytes = res.bytes().await?; Ok::<_, reqwest::Error>(bytes) }) .await; match download_result { Ok(Ok(bytes)) => { let mut hasher = Md5::new(); hasher.update(&bytes); let result = hasher.finalize(); let _b64_md5 = general_purpose::STANDARD.encode(result); // TODO: Figure out how the MD5 should be converted before uncommenting the following filtering condition // if b64_md5 != *expct_md5 { // error_bar.set_message(format!("File skipped due to mismatched MD5 (expected {expct_md5}, got {b64_md5})").red().bold()); // let mut sk_count = sk_count.lock().await; // *sk_count += 1; // return; // } let _n = write_file(path, &bytes).await.unwrap(); let mut dl_count = dl_count.lock().await; *dl_count += 1; } Err(_) => { error_bar.set_message(format!( "{}", format!("Failed to convert request from {} to bytes", url) .red() .bold() )); } Ok(Err(_)) => { error_bar.set_message(format!( "{}", format!("Failed to request {}", url).red().bold() )); } } bar.inc(1); })) .buffer_unordered(100) .collect::>(); futures.await; bar.finish(); error_bar.finish(); let dl = *dl_count.lock().await; let sk = *sk_count.lock().await; Ok((dl, sk)) }