Add --search-existing support: seed the duplicate-hash list from files already on disk.

Summary of the changes carried by this patch:
  * RedditDownloader.__init__ extends master_hash_list with the result of
    scan_existing_files(download_directory) when args.search_existing is set.
  * _download_submission returns early when the download filter rejects the URL,
    which removes one nesting level from the rest of the method.
  * scan_existing_files becomes a @staticmethod that takes the directory to walk.
  * Tests follow the new signature, and a CLI test for --search-existing is added.

Review notes:
  * The skipif guard on the new test_cli_download_search_existing used
    "Path('test_config.cfg') is False", which is always False (a Path object is never
    the False singleton), so the test could never be skipped. It now checks
    "not Path('test_config.cfg').exists()". The identical guard on the pre-existing
    test is a context line here and is left untouched; it deserves the same fix in a
    follow-up. Because of this edit the post-image blob id on the test_integration.py
    "index" line is informational only.
  * NOTE(review): line breaks and leading whitespace were reconstructed from a
    whitespace-collapsed copy of this patch. Line counts match every @@ header, but
    confirm the indentation of context and removed lines against the target tree
    (or apply with --ignore-whitespace) before merging.

diff --git a/bulkredditdownloader/downloader.py b/bulkredditdownloader/downloader.py
index 2e00c42..bf26bbd 100644
--- a/bulkredditdownloader/downloader.py
+++ b/bulkredditdownloader/downloader.py
@@ -74,6 +74,8 @@ class RedditDownloader:
         self._resolve_user_name()
 
         self.master_hash_list = []
+        if self.args.search_existing:
+            self.master_hash_list.extend(self.scan_existing_files(self.download_directory))
         self.authenticator = self._create_authenticator()
         logger.log(9, 'Created site authenticator')
 
@@ -302,37 +304,39 @@ class RedditDownloader:
                 self._download_submission(submission)
 
     def _download_submission(self, submission: praw.models.Submission):
-        if self.download_filter.check_url(submission.url):
+        if not self.download_filter.check_url(submission.url):
+            logger.debug(f'Download filter remove submission {submission.id} with URL {submission.url}')
+            return
+        try:
+            downloader_class = DownloadFactory.pull_lever(submission.url)
+            downloader = downloader_class(submission)
+        except errors.NotADownloadableLinkError as e:
+            logger.error(f'Could not download submission {submission.name}: {e}')
+            return
 
-            try:
-                downloader_class = DownloadFactory.pull_lever(submission.url)
-                downloader = downloader_class(submission)
-            except errors.NotADownloadableLinkError as e:
-                logger.error(f'Could not download submission {submission.name}: {e}')
-                return
-
-            content = downloader.find_resources(self.authenticator)
-            for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
-                if destination.exists():
-                    logger.warning(f'File already exists: {destination}')
+        content = downloader.find_resources(self.authenticator)
+        for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
+            if destination.exists():
+                logger.warning(f'File already exists: {destination}')
+            else:
+                res.download()
+                if res.hash.hexdigest() in self.master_hash_list and self.args.no_dupes:
+                    logger.warning(
+                        f'Resource from "{res.url}" and hash "{res.hash.hexdigest()}" downloaded elsewhere')
                 else:
-                    res.download()
-                    if res.hash.hexdigest() in self.master_hash_list and self.args.no_dupes:
-                        logger.warning(
-                            f'Resource from "{res.url}" and hash "{res.hash.hexdigest()}" downloaded elsewhere')
-                    else:
-                        # TODO: consider making a hard link/symlink here
-                        destination.parent.mkdir(parents=True, exist_ok=True)
-                        with open(destination, 'wb') as file:
-                            file.write(res.content)
-                        logger.debug(f'Written file to {destination}')
-                        self.master_hash_list.append(res.hash.hexdigest())
-                        logger.debug(f'Hash added to master list: {res.hash.hexdigest()}')
-            logger.info(f'Downloaded submission {submission.name}')
+                    # TODO: consider making a hard link/symlink here
+                    destination.parent.mkdir(parents=True, exist_ok=True)
+                    with open(destination, 'wb') as file:
+                        file.write(res.content)
+                    logger.debug(f'Written file to {destination}')
+                    self.master_hash_list.append(res.hash.hexdigest())
+                    logger.debug(f'Hash added to master list: {res.hash.hexdigest()}')
+        logger.info(f'Downloaded submission {submission.name}')
 
-    def scan_existing_files(self) -> list[str]:
+    @staticmethod
+    def scan_existing_files(directory: Path) -> list[str]:
         files = []
-        for (dirpath, dirnames, filenames) in os.walk(self.download_directory):
+        for (dirpath, dirnames, filenames) in os.walk(directory):
             files.extend([Path(dirpath, file) for file in filenames])
         logger.info(f'Calculating hashes for {len(files)} files')
         hash_list = []
diff --git a/bulkredditdownloader/tests/test_downloader.py b/bulkredditdownloader/tests/test_downloader.py
index 292da53..4721ede 100644
--- a/bulkredditdownloader/tests/test_downloader.py
+++ b/bulkredditdownloader/tests/test_downloader.py
@@ -388,8 +388,7 @@ def test_sanitise_subreddit_name(test_name: str, expected: str):
     assert result == expected
 
 
-def test_search_existing_files(downloader_mock: MagicMock):
-    downloader_mock.download_directory = Path('.').resolve().expanduser()
-    results = RedditDownloader.scan_existing_files(downloader_mock)
+def test_search_existing_files():
+    results = RedditDownloader.scan_existing_files(Path('.'))
     assert all([isinstance(result, str) for result in results])
     assert len(results) >= 40
diff --git a/bulkredditdownloader/tests/test_integration.py b/bulkredditdownloader/tests/test_integration.py
index beb9be1..17067a1 100644
--- a/bulkredditdownloader/tests/test_integration.py
+++ b/bulkredditdownloader/tests/test_integration.py
@@ -73,14 +73,15 @@ def test_cli_download_multireddit(test_args: list[str], tmp_path: Path):
 @pytest.mark.reddit
 @pytest.mark.skipif(Path('test_config.cfg') is False, reason='A test config file is required for integration tests')
 @pytest.mark.parametrize('test_args', (
-    ['--user', 'helen_darten', '-m', 'xxyyzzqwertty', '-L', 10],
+    ['--user', 'helen_darten', '-m', 'xxyyzzqwerty', '-L', 10],
 ))
 def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Path):
     runner = CliRunner()
     test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
-    assert 'Failed to get submissions for multireddit xxyyzzqwerty' in result.output
+    assert 'Failed to get submissions for multireddit' in result.output
+    assert 'received 404 HTTP response' in result.output
 
 
 @pytest.mark.online
@@ -117,3 +118,19 @@ def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp
     result = runner.invoke(cli, test_args)
     assert result.exit_code == 0
     assert 'To use "me" as a user, an authenticated Reddit instance must be used' in result.output
+
+
+@pytest.mark.online
+@pytest.mark.reddit
+@pytest.mark.authenticated
+@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
+@pytest.mark.parametrize('test_args', (
+    ['--subreddit', 'python', '-L', 10, '--search-existing'],
+))
+def test_cli_download_search_existing(test_args: list[str], tmp_path: Path):
+    Path(tmp_path, 'test.txt').touch()
+    runner = CliRunner()
+    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
+    result = runner.invoke(cli, test_args)
+    assert result.exit_code == 0
+    assert 'Calculating hashes for' in result.output