diff --git a/bulkredditdownloader/downloader.py b/bulkredditdownloader/downloader.py index c9149a8..e083974 100644 --- a/bulkredditdownloader/downloader.py +++ b/bulkredditdownloader/downloader.py @@ -8,13 +8,15 @@ import socket from datetime import datetime from enum import Enum, auto from pathlib import Path +from typing import Iterator import appdirs import praw import praw.models +import prawcore +import bulkredditdownloader.errors as errors from bulkredditdownloader.download_filter import DownloadFilter -from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError from bulkredditdownloader.file_name_formatter import FileNameFormatter from bulkredditdownloader.site_authenticator import SiteAuthenticator from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory @@ -54,6 +56,7 @@ class RedditDownloader: self.sort_filter = self._create_sort_filter() self.file_name_formatter = self._create_file_name_formatter() self.authenticator = self._create_authenticator() + self._resolve_user_name() self._determine_directories() self._create_file_logger() self.master_hash_list = [] @@ -118,6 +121,10 @@ class RedditDownloader: else: return [] + def _resolve_user_name(self): + if self.args.user == 'me': + self.args.user = self.reddit_instance.user.me() + def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]: supplied_submissions = [] for sub_id in self.args.link: @@ -135,35 +142,52 @@ class RedditDownloader: sort_function = praw.models.Subreddit.hot return sort_function - def _get_multireddits(self) -> list[praw.models.ListingGenerator]: + def _get_multireddits(self) -> list[Iterator]: if self.args.multireddit: if self.authenticated: - return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in self.args.multireddit] + if self.args.user: + sort_function = self._determine_sort_function() + return [ + sort_function(self.reddit_instance.multireddit( + self.args.user, + m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit] + else: + raise errors.BulkDownloaderException('A user must be provided to download a multireddit') else: - raise RedditAuthenticationError('Accessing multireddits requires authentication') + raise errors.RedditAuthenticationError('Accessing multireddits requires authentication') else: return [] - def _get_user_data(self) -> list[praw.models.ListingGenerator]: - if any((self.args.upvoted, self.args.submitted, self.args.saved)): - if self.authenticated: - generators = [] - sort_function = self._determine_sort_function() + def _get_user_data(self) -> list[Iterator]: + if self.args.user: + if not self._check_user_existence(self.args.user): + raise errors.RedditUserError(f'User {self.args.user} does not exist') + generators = [] + sort_function = self._determine_sort_function() + if self.args.submitted: + generators.append( + sort_function( + self.reddit_instance.redditor(self.args.user).submissions, + limit=self.args.limit)) + if not self.authenticated and any((self.args.upvoted, self.args.saved)): + raise errors.RedditAuthenticationError('Accessing user lists requires authentication') + else: if self.args.upvoted: generators.append(self.reddit_instance.redditor(self.args.user).upvoted) - if self.args.submitted: - generators.append( - sort_function( - self.reddit_instance.redditor(self.args.user).submissions, - limit=self.args.limit)) if self.args.saved: generators.append(self.reddit_instance.redditor(self.args.user).saved) - - return generators - else: - raise RedditAuthenticationError('Accessing user lists requires authentication') + return generators else: - return [] + raise errors.BulkDownloaderException('A user must be supplied to download user data') + + def _check_user_existence(self, name: str) -> bool: + user = self.reddit_instance.redditor(name=name) + try: + if not user.id: + return False + except prawcore.exceptions.NotFound: + return False + return True def _create_file_name_formatter(self) -> FileNameFormatter: return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme) @@ -198,10 +222,10 @@ class RedditDownloader: try: downloader_class = DownloadFactory.pull_lever(submission.url) downloader = downloader_class(submission) - except NotADownloadableLinkError as e: + except errors.NotADownloadableLinkError as e: logger.error('Could not download submission {}: {}'.format(submission.name, e)) return - + if self.args.no_download: logger.info('Skipping download for submission {}'.format(submission.id)) else: diff --git a/bulkredditdownloader/tests/test_downloader.py b/bulkredditdownloader/tests/test_downloader.py index c981ef7..1aa58c6 100644 --- a/bulkredditdownloader/tests/test_downloader.py +++ b/bulkredditdownloader/tests/test_downloader.py @@ -3,6 +3,7 @@ import argparse from pathlib import Path +from typing import Iterator from unittest.mock import MagicMock import praw @@ -11,7 +12,7 @@ import pytest from bulkredditdownloader.download_filter import DownloadFilter from bulkredditdownloader.downloader import RedditDownloader, RedditTypes -from bulkredditdownloader.errors import BulkDownloaderException +from bulkredditdownloader.errors import BulkDownloaderException, RedditAuthenticationError, RedditUserError from bulkredditdownloader.file_name_formatter import FileNameFormatter from bulkredditdownloader.site_authenticator import SiteAuthenticator @@ -25,6 +26,7 @@ def args() -> argparse.Namespace: args.link = [] args.submitted = False args.upvoted = False + args.saved = False args.subreddit = [] args.multireddit = [] args.user = None @@ -48,6 +50,14 @@ def downloader_mock(args: argparse.Namespace): return mock_downloader +def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]): + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) + if result_limit is not None: + assert len(results) == result_limit + return results + + def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): downloader_mock.args.directory = tmp_path / 'test' RedditDownloader._determine_directories(downloader_mock) @@ -172,17 +182,15 @@ def test_get_subreddit_normal( limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.reddit_instance = reddit_instance - downloader_mock.args.subreddit = test_subreddits - downloader_mock.args.limit = limit downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.args.limit = limit + downloader_mock.args.subreddit = test_subreddits + downloader_mock.reddit_instance = reddit_instance downloader_mock.sort_filter = RedditTypes.SortType.HOT results = RedditDownloader._get_subreddits(downloader_mock) - results = [sub for res in results for sub in res] - assert all([isinstance(res, praw.models.Submission) for res in results]) - assert all([res.subreddit.display_name for res in results]) - if limit is not None: - assert len(results) == (limit * len(test_subreddits)) + results = assert_all_results_are_submissions( + (limit * len(test_subreddits)) if limit else None, results) + assert all([res.subreddit.display_name in test_subreddits for res in results]) @pytest.mark.online @@ -190,6 +198,7 @@ def test_get_subreddit_normal( @pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), ( (('Python',), 'scraper', 10), (('Python',), '', 10), + (('Python',), 'djsdsgewef', 0), )) def test_get_subreddit_search( test_subreddits: list[str], @@ -197,39 +206,99 @@ def test_get_subreddit_search( limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.reddit_instance = reddit_instance + downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.args.limit = limit + downloader_mock.args.search = search_term downloader_mock.args.subreddit = test_subreddits + downloader_mock.reddit_instance = reddit_instance + downloader_mock.sort_filter = RedditTypes.SortType.HOT + results = RedditDownloader._get_subreddits(downloader_mock) + results = assert_all_results_are_submissions( + (limit * len(test_subreddits)) if limit else None, results) + assert all([res.subreddit.display_name in test_subreddits for res in results]) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), ( + ('helen_darten', ('cuteanimalpics',), 10), + ('korfor', ('chess',), 100), +)) +# Good sources at https://www.reddit.com/r/multihub/ +def test_get_multireddits_public( + test_user: str, + test_multireddits: list[str], + limit: int, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock): + downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot + downloader_mock.sort_filter = RedditTypes.SortType.HOT + downloader_mock.args.limit = limit + downloader_mock.args.multireddit = test_multireddits + downloader_mock.args.user = test_user + downloader_mock.reddit_instance = reddit_instance + results = RedditDownloader._get_multireddits(downloader_mock) + assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results) + + +@pytest.mark.online +@pytest.mark.reddit +def test_get_multireddits_no_user(downloader_mock: MagicMock, reddit_instance: praw.Reddit): + downloader_mock.args.multireddit = ['test'] + with pytest.raises(BulkDownloaderException): + RedditDownloader._get_multireddits(downloader_mock) + + +@pytest.mark.online +@pytest.mark.reddit +def test_get_multireddits_not_authenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit): + downloader_mock.args.multireddit = ['test'] + downloader_mock.authenticated = False + downloader_mock.reddit_instance = reddit_instance + with pytest.raises(RedditAuthenticationError): + RedditDownloader._get_multireddits(downloader_mock) + + +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.parametrize(('test_user', 'limit'), ( + ('danigirl3694', 10), + ('danigirl3694', 50), + ('CapitanHam', None), +)) +def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): downloader_mock.args.limit = limit downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT - downloader_mock.args.search = search_term - results = RedditDownloader._get_subreddits(downloader_mock) - results = [sub for res in results for sub in res] - assert all([isinstance(res, praw.models.Submission) for res in results]) - assert all([res.subreddit.display_name for res in results]) - if limit is not None: - assert len(results) == (limit * len(test_subreddits)) + downloader_mock.args.submitted = True + downloader_mock.args.user = test_user + downloader_mock.authenticated = False + downloader_mock.reddit_instance = reddit_instance + results = RedditDownloader._get_user_data(downloader_mock) + results = assert_all_results_are_submissions(limit, results) + assert all([res.author.name == test_user for res in results]) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skip -def test_get_subreddits_search_bad(): - raise NotImplementedError +def test_get_user_no_user(downloader_mock: MagicMock): + with pytest.raises(BulkDownloaderException): + RedditDownloader._get_user_data(downloader_mock) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skip -def test_get_multireddits(): - raise NotImplementedError - - -@pytest.mark.online -@pytest.mark.reddit -@pytest.mark.skip -def test_get_user_submissions(): - raise NotImplementedError +@pytest.mark.parametrize('test_user', ( + 'rockcanopicjartheme', + 'exceptionalcatfishracecarbatter', +)) +def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit): + downloader_mock.reddit_instance = reddit_instance + downloader_mock.args.user = test_user + downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence( + downloader_mock, test_user) + with pytest.raises(RedditUserError): + RedditDownloader._get_user_data(downloader_mock) @pytest.mark.online @@ -239,6 +308,13 @@ def test_get_user_upvoted(): raise NotImplementedError +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_get_user_upvoted_unauthenticated(): + raise NotImplementedError + + @pytest.mark.online @pytest.mark.reddit @pytest.mark.skip @@ -246,6 +322,13 @@ def test_get_user_saved(): raise NotImplementedError +@pytest.mark.online +@pytest.mark.reddit +@pytest.mark.skip +def test_get_user_saved_unauthenticated(): + raise NotImplementedError + + @pytest.mark.online @pytest.mark.reddit @pytest.mark.skip