113 lines
3.5 KiB
Rust
113 lines
3.5 KiB
Rust
use std::{error::Error, path::PathBuf, sync::Arc, time::Duration};
|
|
|
|
use base64::{engine::general_purpose, Engine};
|
|
use colored::Colorize;
|
|
use futures::{lock::Mutex, stream, StreamExt};
|
|
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
|
use md5::{Digest, Md5};
|
|
use reqwest::{header::USER_AGENT, Client};
|
|
use tokio::{fs::File, io::AsyncWriteExt, time::timeout};
|
|
|
|
use crate::UA;
|
|
|
|
/// How long an idle pooled HTTP connection may stay open; used as the client's pool idle timeout.
const KEEP_ALIVE_TIMEOUT: Duration = Duration::new(30, 0);

/// Upper bound on fetching one file: sending the request and reading the whole body.
const DOWNLOAD_TIMEOUT: Duration = Duration::new(60, 0);
|
|
|
|
async fn write_file(path: &PathBuf, bytes: &[u8]) -> std::io::Result<usize> {
|
|
let mut file = File::create(path).await?;
|
|
let bw = file.write(bytes).await?;
|
|
file.flush().await?;
|
|
|
|
Ok(bw)
|
|
}
|
|
|
|
pub async fn concurrent_dl(
|
|
images: Vec<(String, PathBuf, String)>,
|
|
display_url: String,
|
|
) -> Result<(usize, usize), Box<dyn Error>> {
|
|
let multi = MultiProgress::new();
|
|
|
|
let bar = multi.add(ProgressBar::new(images.len() as u64));
|
|
bar.set_style(
|
|
ProgressStyle::with_template("{spinner} {msg} [{wide_bar:.white/gray}] [{pos}/{len}]")
|
|
.unwrap(),
|
|
);
|
|
bar.set_message(display_url);
|
|
|
|
let error_bar = multi.add(ProgressBar::new(0));
|
|
error_bar.set_style(ProgressStyle::with_template("{wide_msg}").unwrap());
|
|
|
|
let dl_count = Arc::new(Mutex::new(0));
|
|
let sk_count = Arc::new(Mutex::new(0));
|
|
|
|
let client = Client::builder()
|
|
.pool_idle_timeout(KEEP_ALIVE_TIMEOUT)
|
|
.build()?;
|
|
|
|
let futures = stream::iter(images.iter().map(|data| async {
|
|
let dl_count = Arc::clone(&dl_count);
|
|
let _sk_count = Arc::clone(&sk_count);
|
|
let client = client.clone();
|
|
|
|
let (url, path, _expct_md5) = data;
|
|
|
|
let download_result = timeout(DOWNLOAD_TIMEOUT, async {
|
|
let res = client.get(url).header(USER_AGENT, UA).send().await?;
|
|
let bytes = res.bytes().await?;
|
|
|
|
Ok::<_, reqwest::Error>(bytes)
|
|
})
|
|
.await;
|
|
|
|
match download_result {
|
|
Ok(Ok(bytes)) => {
|
|
let mut hasher = Md5::new();
|
|
hasher.update(&bytes);
|
|
let result = hasher.finalize();
|
|
let _b64_md5 = general_purpose::STANDARD.encode(result);
|
|
|
|
// TODO: Figure out how the MD5 should be converted before uncommenting the following filtering condition
|
|
// if b64_md5 != *expct_md5 {
|
|
// error_bar.set_message(format!("File skipped due to mismatched MD5 (expected {expct_md5}, got {b64_md5})").red().bold());
|
|
// let mut sk_count = sk_count.lock().await;
|
|
// *sk_count += 1;
|
|
|
|
// return;
|
|
// }
|
|
|
|
let _n = write_file(path, &bytes).await.unwrap();
|
|
let mut dl_count = dl_count.lock().await;
|
|
*dl_count += 1;
|
|
}
|
|
Err(_) => {
|
|
error_bar.set_message(format!(
|
|
"{}",
|
|
format!("Failed to convert request from {} to bytes", url)
|
|
.red()
|
|
.bold()
|
|
));
|
|
}
|
|
Ok(Err(_)) => {
|
|
error_bar.set_message(format!(
|
|
"{}",
|
|
format!("Failed to request {}", url).red().bold()
|
|
));
|
|
}
|
|
}
|
|
|
|
bar.inc(1);
|
|
}))
|
|
.buffer_unordered(100)
|
|
.collect::<Vec<()>>();
|
|
|
|
futures.await;
|
|
|
|
bar.finish();
|
|
error_bar.finish();
|
|
|
|
let dl = *dl_count.lock().await;
|
|
let sk = *sk_count.lock().await;
|
|
|
|
Ok((dl, sk))
|
|
}
|