Add some more tests for RedditDownloader

This commit is contained in:
Serene-Arc
2021-03-05 13:31:40 +10:00
committed by Ali Parlakci
parent 6f86dbd552
commit b705c31630
2 changed files with 158 additions and 51 deletions

View File

@@ -8,13 +8,15 @@ import socket
from datetime import datetime from datetime import datetime
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path from pathlib import Path
from typing import Iterator
import appdirs import appdirs
import praw import praw
import praw.models import praw.models
import prawcore
import bulkredditdownloader.errors as errors
from bulkredditdownloader.download_filter import DownloadFilter from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
from bulkredditdownloader.file_name_formatter import FileNameFormatter from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_authenticator import SiteAuthenticator from bulkredditdownloader.site_authenticator import SiteAuthenticator
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
@@ -54,6 +56,7 @@ class RedditDownloader:
self.sort_filter = self._create_sort_filter() self.sort_filter = self._create_sort_filter()
self.file_name_formatter = self._create_file_name_formatter() self.file_name_formatter = self._create_file_name_formatter()
self.authenticator = self._create_authenticator() self.authenticator = self._create_authenticator()
self._resolve_user_name()
self._determine_directories() self._determine_directories()
self._create_file_logger() self._create_file_logger()
self.master_hash_list = [] self.master_hash_list = []
@@ -118,6 +121,10 @@ class RedditDownloader:
else: else:
return [] return []
def _resolve_user_name(self):
if self.args.user == 'me':
self.args.user = self.reddit_instance.user.me()
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]: def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = [] supplied_submissions = []
for sub_id in self.args.link: for sub_id in self.args.link:
@@ -135,35 +142,52 @@ class RedditDownloader:
sort_function = praw.models.Subreddit.hot sort_function = praw.models.Subreddit.hot
return sort_function return sort_function
def _get_multireddits(self) -> list[praw.models.ListingGenerator]: def _get_multireddits(self) -> list[Iterator]:
if self.args.multireddit: if self.args.multireddit:
if self.authenticated: if self.authenticated:
return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in self.args.multireddit] if self.args.user:
sort_function = self._determine_sort_function()
return [
sort_function(self.reddit_instance.multireddit(
self.args.user,
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
else:
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
else: else:
raise RedditAuthenticationError('Accessing multireddits requires authentication') raise errors.RedditAuthenticationError('Accessing multireddits requires authentication')
else: else:
return [] return []
def _get_user_data(self) -> list[praw.models.ListingGenerator]: def _get_user_data(self) -> list[Iterator]:
if any((self.args.upvoted, self.args.submitted, self.args.saved)): if self.args.user:
if self.authenticated: if not self._check_user_existence(self.args.user):
generators = [] raise errors.RedditUserError(f'User {self.args.user} does not exist')
sort_function = self._determine_sort_function() generators = []
sort_function = self._determine_sort_function()
if self.args.submitted:
generators.append(
sort_function(
self.reddit_instance.redditor(self.args.user).submissions,
limit=self.args.limit))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
raise errors.RedditAuthenticationError('Accessing user lists requires authentication')
else:
if self.args.upvoted: if self.args.upvoted:
generators.append(self.reddit_instance.redditor(self.args.user).upvoted) generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
if self.args.submitted:
generators.append(
sort_function(
self.reddit_instance.redditor(self.args.user).submissions,
limit=self.args.limit))
if self.args.saved: if self.args.saved:
generators.append(self.reddit_instance.redditor(self.args.user).saved) generators.append(self.reddit_instance.redditor(self.args.user).saved)
return generators
return generators
else:
raise RedditAuthenticationError('Accessing user lists requires authentication')
else: else:
return [] raise errors.BulkDownloaderException('A user must be supplied to download user data')
def _check_user_existence(self, name: str) -> bool:
user = self.reddit_instance.redditor(name=name)
try:
if not user.id:
return False
except prawcore.exceptions.NotFound:
return False
return True
def _create_file_name_formatter(self) -> FileNameFormatter: def _create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme) return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
@@ -198,10 +222,10 @@ class RedditDownloader:
try: try:
downloader_class = DownloadFactory.pull_lever(submission.url) downloader_class = DownloadFactory.pull_lever(submission.url)
downloader = downloader_class(submission) downloader = downloader_class(submission)
except NotADownloadableLinkError as e: except errors.NotADownloadableLinkError as e:
logger.error('Could not download submission {}: {}'.format(submission.name, e)) logger.error('Could not download submission {}: {}'.format(submission.name, e))
return return
if self.args.no_download: if self.args.no_download:
logger.info('Skipping download for submission {}'.format(submission.id)) logger.info('Skipping download for submission {}'.format(submission.id))
else: else:

View File

@@ -3,6 +3,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Iterator
from unittest.mock import MagicMock from unittest.mock import MagicMock
import praw import praw
@@ -11,7 +12,7 @@ import pytest
from bulkredditdownloader.download_filter import DownloadFilter from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.downloader import RedditDownloader, RedditTypes from bulkredditdownloader.downloader import RedditDownloader, RedditTypes
from bulkredditdownloader.errors import BulkDownloaderException from bulkredditdownloader.errors import BulkDownloaderException, RedditAuthenticationError, RedditUserError
from bulkredditdownloader.file_name_formatter import FileNameFormatter from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_authenticator import SiteAuthenticator from bulkredditdownloader.site_authenticator import SiteAuthenticator
@@ -25,6 +26,7 @@ def args() -> argparse.Namespace:
args.link = [] args.link = []
args.submitted = False args.submitted = False
args.upvoted = False args.upvoted = False
args.saved = False
args.subreddit = [] args.subreddit = []
args.multireddit = [] args.multireddit = []
args.user = None args.user = None
@@ -48,6 +50,14 @@ def downloader_mock(args: argparse.Namespace):
return mock_downloader return mock_downloader
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
if result_limit is not None:
assert len(results) == result_limit
return results
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
downloader_mock.args.directory = tmp_path / 'test' downloader_mock.args.directory = tmp_path / 'test'
RedditDownloader._determine_directories(downloader_mock) RedditDownloader._determine_directories(downloader_mock)
@@ -172,17 +182,15 @@ def test_get_subreddit_normal(
limit: int, limit: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit): reddit_instance: praw.Reddit):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.args.subreddit = test_subreddits
downloader_mock.args.limit = limit
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock) results = RedditDownloader._get_subreddits(downloader_mock)
results = [sub for res in results for sub in res] results = assert_all_results_are_submissions(
assert all([isinstance(res, praw.models.Submission) for res in results]) (limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name for res in results]) assert all([res.subreddit.display_name in test_subreddits for res in results])
if limit is not None:
assert len(results) == (limit * len(test_subreddits))
@pytest.mark.online @pytest.mark.online
@@ -190,6 +198,7 @@ def test_get_subreddit_normal(
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), ( @pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
(('Python',), 'scraper', 10), (('Python',), 'scraper', 10),
(('Python',), '', 10), (('Python',), '', 10),
(('Python',), 'djsdsgewef', 0),
)) ))
def test_get_subreddit_search( def test_get_subreddit_search(
test_subreddits: list[str], test_subreddits: list[str],
@@ -197,39 +206,99 @@ def test_get_subreddit_search(
limit: int, limit: int,
downloader_mock: MagicMock, downloader_mock: MagicMock,
reddit_instance: praw.Reddit): reddit_instance: praw.Reddit):
downloader_mock.reddit_instance = reddit_instance downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
('helen_darten', ('cuteanimalpics',), 10),
('korfor', ('chess',), 100),
))
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
test_user: str,
test_multireddits: list[str],
limit: int,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = test_user
downloader_mock.reddit_instance = reddit_instance
results = RedditDownloader._get_multireddits(downloader_mock)
assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)
@pytest.mark.online
@pytest.mark.reddit
def test_get_multireddits_no_user(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.multireddit = ['test']
with pytest.raises(BulkDownloaderException):
RedditDownloader._get_multireddits(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
def test_get_multireddits_not_authenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.multireddit = ['test']
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
with pytest.raises(RedditAuthenticationError):
RedditDownloader._get_multireddits(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
('danigirl3694', 10),
('danigirl3694', 50),
('CapitanHam', None),
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.limit = limit downloader_mock.args.limit = limit
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.search = search_term downloader_mock.args.submitted = True
results = RedditDownloader._get_subreddits(downloader_mock) downloader_mock.args.user = test_user
results = [sub for res in results for sub in res] downloader_mock.authenticated = False
assert all([isinstance(res, praw.models.Submission) for res in results]) downloader_mock.reddit_instance = reddit_instance
assert all([res.subreddit.display_name for res in results]) results = RedditDownloader._get_user_data(downloader_mock)
if limit is not None: results = assert_all_results_are_submissions(limit, results)
assert len(results) == (limit * len(test_subreddits)) assert all([res.author.name == test_user for res in results])
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.skip def test_get_user_no_user(downloader_mock: MagicMock):
def test_get_subreddits_search_bad(): with pytest.raises(BulkDownloaderException):
raise NotImplementedError RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.skip @pytest.mark.parametrize('test_user', (
def test_get_multireddits(): 'rockcanopicjartheme',
raise NotImplementedError 'exceptionalcatfishracecarbatter',
))
def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
@pytest.mark.online downloader_mock.reddit_instance = reddit_instance
@pytest.mark.reddit downloader_mock.args.user = test_user
@pytest.mark.skip downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence(
def test_get_user_submissions(): downloader_mock, test_user)
raise NotImplementedError with pytest.raises(RedditUserError):
RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online @pytest.mark.online
@@ -239,6 +308,13 @@ def test_get_user_upvoted():
raise NotImplementedError raise NotImplementedError
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_upvoted_unauthenticated():
raise NotImplementedError
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.skip @pytest.mark.skip
@@ -246,6 +322,13 @@ def test_get_user_saved():
raise NotImplementedError raise NotImplementedError
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_saved_unauthenticated():
raise NotImplementedError
@pytest.mark.online @pytest.mark.online
@pytest.mark.reddit @pytest.mark.reddit
@pytest.mark.skip @pytest.mark.skip