dlrs/src/http.rs

113 lines
3.5 KiB
Rust

use std::{error::Error, path::PathBuf, sync::Arc, time::Duration};
use base64::{engine::general_purpose, Engine};
use colored::Colorize;
use futures::{lock::Mutex, stream, StreamExt};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use md5::{Digest, Md5};
use reqwest::{header::USER_AGENT, Client};
use tokio::{fs::File, io::AsyncWriteExt, time::timeout};
use crate::UA;
/// How long an idle pooled HTTP connection is kept for reuse
/// (passed to the client builder's `pool_idle_timeout`).
const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Upper bound on a single download: sending the request and reading the
/// whole response body must finish within this window.
const DOWNLOAD_TIMEOUT: Duration = Duration::from_millis(60_000);
async fn write_file(path: &PathBuf, bytes: &[u8]) -> std::io::Result<usize> {
let mut file = File::create(path).await?;
let bw = file.write(bytes).await?;
file.flush().await?;
Ok(bw)
}
pub async fn concurrent_dl(
images: Vec<(String, PathBuf, String)>,
display_url: String,
) -> Result<(usize, usize), Box<dyn Error>> {
let multi = MultiProgress::new();
let bar = multi.add(ProgressBar::new(images.len() as u64));
bar.set_style(
ProgressStyle::with_template("{spinner} {msg} [{wide_bar:.white/gray}] [{pos}/{len}]")
.unwrap(),
);
bar.set_message(display_url);
let error_bar = multi.add(ProgressBar::new(0));
error_bar.set_style(ProgressStyle::with_template("{wide_msg}").unwrap());
let dl_count = Arc::new(Mutex::new(0));
let sk_count = Arc::new(Mutex::new(0));
let client = Client::builder()
.pool_idle_timeout(KEEP_ALIVE_TIMEOUT)
.build()?;
let futures = stream::iter(images.iter().map(|data| async {
let dl_count = Arc::clone(&dl_count);
let _sk_count = Arc::clone(&sk_count);
let client = client.clone();
let (url, path, _expct_md5) = data;
let download_result = timeout(DOWNLOAD_TIMEOUT, async {
let res = client.get(url).header(USER_AGENT, UA).send().await?;
let bytes = res.bytes().await?;
Ok::<_, reqwest::Error>(bytes)
})
.await;
match download_result {
Ok(Ok(bytes)) => {
let mut hasher = Md5::new();
hasher.update(&bytes);
let result = hasher.finalize();
let _b64_md5 = general_purpose::STANDARD.encode(result);
// TODO: Figure out how the MD5 should be converted before uncommenting the following filtering condition
// if b64_md5 != *expct_md5 {
// error_bar.set_message(format!("File skipped due to mismatched MD5 (expected {expct_md5}, got {b64_md5})").red().bold());
// let mut sk_count = sk_count.lock().await;
// *sk_count += 1;
// return;
// }
let _n = write_file(path, &bytes).await.unwrap();
let mut dl_count = dl_count.lock().await;
*dl_count += 1;
}
Err(_) => {
error_bar.set_message(format!(
"{}",
format!("Failed to convert request from {} to bytes", url)
.red()
.bold()
));
}
Ok(Err(_)) => {
error_bar.set_message(format!(
"{}",
format!("Failed to request {}", url).red().bold()
));
}
}
bar.inc(1);
}))
.buffer_unordered(100)
.collect::<Vec<()>>();
futures.await;
bar.finish();
error_bar.finish();
let dl = *dl_count.lock().await;
let sk = *sk_count.lock().await;
Ok((dl, sk))
}