diff --git a/bulkredditdownloader/downloader.py b/bulkredditdownloader/downloader.py index d6a5e5f..56be776 100644 --- a/bulkredditdownloader/downloader.py +++ b/bulkredditdownloader/downloader.py @@ -13,10 +13,10 @@ import appdirs import praw import praw.models -from bulkredditdownloader.site_authenticator import SiteAuthenticator from bulkredditdownloader.download_filter import DownloadFilter from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError from bulkredditdownloader.file_name_formatter import FileNameFormatter +from bulkredditdownloader.site_authenticator import SiteAuthenticator from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory logger = logging.getLogger(__name__) @@ -107,7 +107,11 @@ class RedditDownloader: if self.args.subreddit: subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit] if self.args.search: - return [reddit.search(self.args.search, sort=self.sort_filter.name.lower()) for reddit in subreddits] + return [ + reddit.search( + self.args.search, + sort=self.sort_filter.name.lower(), + limit=self.args.limit) for reddit in subreddits] else: sort_function = self._determine_sort_function() return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits] @@ -116,8 +120,8 @@ class RedditDownloader: def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]: supplied_submissions = [] - for url in self.args.link: - supplied_submissions.append(self.reddit_instance.submission(url=url)) + for sub_id in self.args.link: + supplied_submissions.append(self.reddit_instance.submission(id=sub_id)) return [supplied_submissions] def _determine_sort_function(self): @@ -162,29 +166,22 @@ class RedditDownloader: return [] def _create_file_name_formatter(self) -> FileNameFormatter: - return FileNameFormatter(self.args.set_filename, self.args.set_folderpath) + return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme) def _create_time_filter(self) -> RedditTypes.TimeType: try: - return RedditTypes.TimeType[self.args.sort.upper()] + return RedditTypes.TimeType[self.args.time.upper()] except (KeyError, AttributeError): return RedditTypes.TimeType.ALL def _create_sort_filter(self) -> RedditTypes.SortType: try: - return RedditTypes.SortType[self.args.time.upper()] + return RedditTypes.SortType[self.args.sort.upper()] except (KeyError, AttributeError): return RedditTypes.SortType.HOT def _create_download_filter(self) -> DownloadFilter: - formats = { - "videos": [".mp4", ".webm"], - "images": [".jpg", ".jpeg", ".png", ".bmp"], - "gifs": [".gif"], - "self": [] - } - excluded_extensions = [extension for ext_type in self.args.skip for extension in formats.get(ext_type, ())] - return DownloadFilter(excluded_extensions, self.args.skip_domain) + return DownloadFilter(self.args.skip, self.args.skip_domain) def _create_authenticator(self) -> SiteAuthenticator: raise NotImplementedError diff --git a/bulkredditdownloader/tests/test_downloader.py b/bulkredditdownloader/tests/test_downloader.py new file mode 100644 index 0000000..c981ef7 --- /dev/null +++ b/bulkredditdownloader/tests/test_downloader.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +import argparse +from pathlib import Path +from unittest.mock import MagicMock + +import praw +import praw.models +import pytest + +from bulkredditdownloader.download_filter import DownloadFilter +from bulkredditdownloader.downloader import RedditDownloader, RedditTypes +from bulkredditdownloader.errors import BulkDownloaderException +from bulkredditdownloader.file_name_formatter import FileNameFormatter +from bulkredditdownloader.site_authenticator import SiteAuthenticator + + +@pytest.fixture() +def args() -> argparse.Namespace: + args = argparse.Namespace() + + args.directory = '.' + args.verbose = 0 + args.link = [] + args.submitted = False + args.upvoted = False + args.subreddit = [] + args.multireddit = [] + args.user = None + args.search = None + args.sort = 'hot' + args.limit = None + args.time = 'all' + args.skip = [] + args.skip_domain = [] + args.set_folder_scheme = '{SUBREDDIT}' + args.set_file_scheme = '{REDDITOR}_{TITLE}_{POSTID}' + args.no_dupes = False + + return args + + +@pytest.fixture() +def downloader_mock(args: argparse.Namespace): + mock_downloader = MagicMock() + mock_downloader.args = args + return mock_downloader + + +def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): + downloader_mock.args.directory = tmp_path / 'test' + RedditDownloader._determine_directories(downloader_mock) + + assert Path(tmp_path / 'test').exists() + assert downloader_mock.logfile_directory == Path(tmp_path / 'test' / 'LOG_FILES') + assert downloader_mock.logfile_directory.exists() + + +@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), ( + ([], []), + (['.test'], ['test.com']), +)) +def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock): + downloader_mock.args.skip = skip_extensions + downloader_mock.args.skip_domain = skip_domains + result = RedditDownloader._create_download_filter(downloader_mock) + + assert isinstance(result, DownloadFilter) + assert result.excluded_domains == skip_domains + assert result.excluded_extensions == skip_extensions + + +@pytest.mark.parametrize(('test_time', 'expected'), ( + ('all', 'all'), + ('hour', 'hour'), + ('day', 'day'), + ('week', 'week'), + ('random', 'all'), + ('', 'all'), +)) +def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock): + downloader_mock.args.time = test_time + result = RedditDownloader._create_time_filter(downloader_mock) + + assert isinstance(result, RedditTypes.TimeType) + assert result.name.lower() == expected + + +@pytest.mark.parametrize(('test_sort', 'expected'), ( + ('', 'hot'), + ('hot', 'hot'), + ('controversial', 'controversial'), + ('new', 'new'), +)) +def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock): + downloader_mock.args.sort = test_sort + result = RedditDownloader._create_sort_filter(downloader_mock) + + assert isinstance(result, RedditTypes.SortType) + assert result.name.lower() == expected + + +@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( + ('{POSTID}', '{SUBREDDIT}'), + ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'), +)) +def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): + downloader_mock.args.set_file_scheme = test_file_scheme + downloader_mock.args.set_folder_scheme = test_folder_scheme + result = RedditDownloader._create_file_name_formatter(downloader_mock) + + assert isinstance(result, FileNameFormatter) + assert result.file_format_string == test_file_scheme + assert result.directory_format_string == test_folder_scheme + + +@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( + ('', ''), + ('{POSTID}', ''), + ('', '{SUBREDDIT}'), + ('test', '{SUBREDDIT}'), + ('{POSTID}', 'test'), +)) +def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): + downloader_mock.args.set_file_scheme = test_file_scheme + downloader_mock.args.set_folder_scheme = test_folder_scheme + with pytest.raises(BulkDownloaderException): + RedditDownloader._create_file_name_formatter(downloader_mock) + + +@pytest.mark.skip +def test_create_authenticator(downloader_mock: MagicMock): + result = RedditDownloader._create_authenticator(downloader_mock) + assert isinstance(result, SiteAuthenticator) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize('test_submission_ids', ( + ('lvpf4l',), + ('lvpf4l', 'lvqnsn'), + ('lvpf4l', 'lvqnsn', 'lvl9kd'), +)) +def test_get_submissions_from_link( + test_submission_ids: list[str], + reddit_instance: praw.Reddit, + downloader_mock: MagicMock): + downloader_mock.args.link = test_submission_ids + downloader_mock.reddit_instance = reddit_instance + results = RedditDownloader._get_submissions_from_link(downloader_mock) + assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res]) + assert len(results[0]) == len(test_submission_ids) + + +@pytest.mark.skip +def test_load_config(downloader_mock: MagicMock): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_subreddits', 'limit'), ( + (('Futurology',), 10), + (('Futurology',), 20), + (('Futurology', 'Python'), 10), + (('Futurology',), 100), + (('Futurology',), 0), +)) +def test_get_subreddit_normal( + test_subreddits: list[str], + limit: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit): + downloader_mock.reddit_instance = reddit_instance + downloader_mock.args.subreddit = test_subreddits + downloader_mock.args.limit = limit + downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.sort_filter = RedditTypes.SortType.HOT + results = RedditDownloader._get_subreddits(downloader_mock) + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) + assert all([res.subreddit.display_name for res in results]) + if limit is not None: + assert len(results) == (limit * len(test_subreddits)) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), ( + (('Python',), 'scraper', 10), + (('Python',), '', 10), +)) +def test_get_subreddit_search( + test_subreddits: list[str], + search_term: str, + limit: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit): + downloader_mock.reddit_instance = reddit_instance + downloader_mock.args.subreddit = test_subreddits + downloader_mock.args.limit = limit + downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.sort_filter = RedditTypes.SortType.HOT + downloader_mock.args.search = search_term + results = RedditDownloader._get_subreddits(downloader_mock) + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) + assert all([res.subreddit.display_name for res in results]) + if limit is not None: + assert len(results) == (limit * len(test_subreddits)) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_get_subreddits_search_bad(): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_get_multireddits(): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_get_user_submissions(): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_get_user_upvoted(): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_get_user_saved(): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_download_submission(): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_download_submission_file_exists(): + raise NotImplementedError + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_download_submission_hash_exists(): + raise NotImplementedError