diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..57c2a87 Binary files /dev/null and b/.DS_Store differ diff --git a/scripts/download.py b/scripts/download.py index a5a0bfd..4ad474d 100644 --- a/scripts/download.py +++ b/scripts/download.py @@ -179,11 +179,12 @@ def download(download_list: list, output_dir: str, is_clean_cache: bool): for item in tqdm(download_list, desc='Downloading'): repo = item['repo'] rel_path = item['rel_path'] - + + zip_path = os.path.join(output_dir, rel_path) output_path = os.path.join(output_dir, rel_path) output_path = output_path.replace('.zip', '') # skip if already exists locally - if os.path.exists(output_path): + if os.path.exists(zip_path) or os.path.exists(output_path): succ_count += 1 continue succ = hf_download_path(repo, rel_path, output_dir) @@ -194,13 +195,23 @@ def download(download_list: list, output_dir: str, is_clean_cache: bool): if is_clean_cache: clean_huggingface_cache(output_dir, repo) - # unzip the file - if rel_path.endswith('.zip'): - zip_file = join(output_dir, rel_path) - with zipfile.ZipFile(zip_file, 'r') as zip_ref: - ofile = join(output_dir, os.path.dirname(rel_path)) - zip_ref.extractall(ofile) - os.remove(zip_file) + # # unzip the file + # if rel_path.endswith('.zip'): + # zip_file = join(output_dir, rel_path) + # with zipfile.ZipFile(zip_file, 'r') as zip_ref: + # ofile = join(output_dir, os.path.dirname(rel_path)) + # zip_ref.extractall(ofile) + # os.remove(zip_file) + try: + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(output_path) + os.remove(zip_path) + except zipfile.BadZipFile: + print(f"ERROR: {zip_path} is not a zip file. Download failed.") + continue + except Exception as e: + print(f"An error occurred while unzipping the file: {e}") + continue else: print(f'Download {rel_path} failed') @@ -224,7 +235,7 @@ def download_dataset(args): os.makedirs(output_dir, exist_ok=True) download_list = get_download_list(subset_opt, hash_name, reso_opt, file_type, output_dir) - return download(download_list, output_dir, is_clean_cache) + return download(download_list[:500], output_dir, is_clean_cache) if __name__ == '__main__':