From 2b9dc1b96cfd02e04f35e7ed456d8557de78a3fd Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Thu, 11 Mar 2021 17:18:21 +1000 Subject: [PATCH] Allow subreddits and multireddits to fail individually --- bulkredditdownloader/downloader.py | 56 +++++++++++-------- bulkredditdownloader/tests/test_downloader.py | 18 ------ 2 files changed, 33 insertions(+), 41 deletions(-) diff --git a/bulkredditdownloader/downloader.py b/bulkredditdownloader/downloader.py index 83447a6..f90dbf8 100644 --- a/bulkredditdownloader/downloader.py +++ b/bulkredditdownloader/downloader.py @@ -14,6 +14,7 @@ from typing import Iterator import appdirs import praw +import praw.exceptions import praw.models import prawcore @@ -166,17 +167,26 @@ class RedditDownloader: def _get_subreddits(self) -> list[praw.models.ListingGenerator]: if self.args.subreddit: - subreddits = [self._sanitise_subreddit_name(subreddit) for subreddit in self.args.subreddit] - subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in subreddits] - if self.args.search: - return [ - reddit.search( - self.args.search, - sort=self.sort_filter.name.lower(), - limit=self.args.limit) for reddit in subreddits] - else: - sort_function = self._determine_sort_function() - return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits] + out = [] + sort_function = self._determine_sort_function() + for reddit in self.args.subreddit: + try: + reddit = self._sanitise_subreddit_name(reddit) + reddit = self.reddit_instance.subreddit(reddit) + if self.args.search: + out.append( + reddit.search( + self.args.search, + sort=self.sort_filter.name.lower(), + limit=self.args.limit)) + logger.debug( + f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"') + else: + out.append(sort_function(reddit, limit=self.args.limit)) + logger.debug(f'Added submissions from subreddit {reddit}') + except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: + logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') + return out else: return [] @@ -211,18 +221,18 @@ class RedditDownloader: def _get_multireddits(self) -> list[Iterator]: if self.args.multireddit: - if self.authenticated: - if self.args.user: - sort_function = self._determine_sort_function() - multireddits = [self._sanitise_subreddit_name(multi) for multi in self.args.multireddit] - return [ - sort_function(self.reddit_instance.multireddit( - self.args.user, - m_reddit_choice), limit=self.args.limit) for m_reddit_choice in multireddits] - else: - raise errors.BulkDownloaderException('A user must be provided to download a multireddit') - else: - raise errors.RedditAuthenticationError('Accessing multireddits requires authentication') + out = [] + sort_function = self._determine_sort_function() + for multi in self.args.multireddit: + try: + multi = self._sanitise_subreddit_name(multi) + out.append(sort_function( + self.reddit_instance.multireddit(self.args.user, multi), + limit=self.args.limit)) + logger.debug(f'Added submissions from multireddit {multi}') + except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: + logger.error(f'Failed to get submissions for multireddit {multi}: {e}') + return out else: return [] diff --git a/bulkredditdownloader/tests/test_downloader.py b/bulkredditdownloader/tests/test_downloader.py index 9aa32b9..292da53 100644 --- a/bulkredditdownloader/tests/test_downloader.py +++ b/bulkredditdownloader/tests/test_downloader.py @@ -217,24 +217,6 @@ def test_get_multireddits_public( assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results) -@pytest.mark.online -@pytest.mark.reddit -def test_get_multireddits_no_user(downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.args.multireddit = ['test'] - with pytest.raises(BulkDownloaderException): - RedditDownloader._get_multireddits(downloader_mock) - - -@pytest.mark.online -@pytest.mark.reddit -def test_get_multireddits_not_authenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit): - downloader_mock.args.multireddit = ['test'] - downloader_mock.authenticated = False - downloader_mock.reddit_instance = reddit_instance - with pytest.raises(RedditAuthenticationError): - RedditDownloader._get_multireddits(downloader_mock) - - @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('test_user', 'limit'), (