Add some more tests for RedditDownloader

This commit is contained in:
Serene-Arc
2021-03-05 13:31:40 +10:00
committed by Ali Parlakci
parent 6f86dbd552
commit b705c31630
2 changed files with 158 additions and 51 deletions

View File

@@ -8,13 +8,15 @@ import socket
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Iterator
import appdirs
import praw
import praw.models
import prawcore
import bulkredditdownloader.errors as errors
from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_authenticator import SiteAuthenticator
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
@@ -54,6 +56,7 @@ class RedditDownloader:
self.sort_filter = self._create_sort_filter()
self.file_name_formatter = self._create_file_name_formatter()
self.authenticator = self._create_authenticator()
self._resolve_user_name()
self._determine_directories()
self._create_file_logger()
self.master_hash_list = []
@@ -118,6 +121,10 @@ class RedditDownloader:
else:
return []
def _resolve_user_name(self):
if self.args.user == 'me':
self.args.user = self.reddit_instance.user.me()
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
@@ -135,35 +142,52 @@ class RedditDownloader:
sort_function = praw.models.Subreddit.hot
return sort_function
def _get_multireddits(self) -> list[praw.models.ListingGenerator]:
def _get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
if self.authenticated:
return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in self.args.multireddit]
if self.args.user:
sort_function = self._determine_sort_function()
return [
sort_function(self.reddit_instance.multireddit(
self.args.user,
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
else:
raise RedditAuthenticationError('Accessing multireddits requires authentication')
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
else:
raise errors.RedditAuthenticationError('Accessing multireddits requires authentication')
else:
return []
def _get_user_data(self) -> list[praw.models.ListingGenerator]:
if any((self.args.upvoted, self.args.submitted, self.args.saved)):
if self.authenticated:
def _get_user_data(self) -> list[Iterator]:
if self.args.user:
if not self._check_user_existence(self.args.user):
raise errors.RedditUserError(f'User {self.args.user} does not exist')
generators = []
sort_function = self._determine_sort_function()
if self.args.upvoted:
generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
if self.args.submitted:
generators.append(
sort_function(
self.reddit_instance.redditor(self.args.user).submissions,
limit=self.args.limit))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
raise errors.RedditAuthenticationError('Accessing user lists requires authentication')
else:
if self.args.upvoted:
generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
if self.args.saved:
generators.append(self.reddit_instance.redditor(self.args.user).saved)
return generators
else:
raise RedditAuthenticationError('Accessing user lists requires authentication')
else:
return []
raise errors.BulkDownloaderException('A user must be supplied to download user data')
def _check_user_existence(self, name: str) -> bool:
user = self.reddit_instance.redditor(name=name)
try:
if not user.id:
return False
except prawcore.exceptions.NotFound:
return False
return True
def _create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
@@ -198,7 +222,7 @@ class RedditDownloader:
try:
downloader_class = DownloadFactory.pull_lever(submission.url)
downloader = downloader_class(submission)
except NotADownloadableLinkError as e:
except errors.NotADownloadableLinkError as e:
logger.error('Could not download submission {}: {}'.format(submission.name, e))
return

View File

@@ -3,6 +3,7 @@
import argparse
from pathlib import Path
from typing import Iterator
from unittest.mock import MagicMock
import praw
@@ -11,7 +12,7 @@ import pytest
from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.downloader import RedditDownloader, RedditTypes
from bulkredditdownloader.errors import BulkDownloaderException
from bulkredditdownloader.errors import BulkDownloaderException, RedditAuthenticationError, RedditUserError
from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_authenticator import SiteAuthenticator
@@ -25,6 +26,7 @@ def args() -> argparse.Namespace:
args.link = []
args.submitted = False
args.upvoted = False
args.saved = False
args.subreddit = []
args.multireddit = []
args.user = None
@@ -48,6 +50,14 @@ def downloader_mock(args: argparse.Namespace):
return mock_downloader
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
if result_limit is not None:
assert len(results) == result_limit
return results
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
downloader_mock.args.directory = tmp_path / 'test'
RedditDownloader._determine_directories(downloader_mock)
@@ -172,17 +182,15 @@ def test_get_subreddit_normal(
limit: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.args.subreddit = test_subreddits
downloader_mock.args.limit = limit
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert all([res.subreddit.display_name for res in results])
if limit is not None:
assert len(results) == (limit * len(test_subreddits))
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results])
@pytest.mark.online
@@ -190,6 +198,7 @@ def test_get_subreddit_normal(
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
(('Python',), 'scraper', 10),
(('Python',), '', 10),
(('Python',), 'djsdsgewef', 0),
))
def test_get_subreddit_search(
test_subreddits: list[str],
@@ -197,39 +206,99 @@ def test_get_subreddit_search(
limit: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock.reddit_instance = reddit_instance
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
('helen_darten', ('cuteanimalpics',), 10),
('korfor', ('chess',), 100),
))
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
test_user: str,
test_multireddits: list[str],
limit: int,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = test_user
downloader_mock.reddit_instance = reddit_instance
results = RedditDownloader._get_multireddits(downloader_mock)
assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)
@pytest.mark.online
@pytest.mark.reddit
def test_get_multireddits_no_user(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.multireddit = ['test']
with pytest.raises(BulkDownloaderException):
RedditDownloader._get_multireddits(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
def test_get_multireddits_not_authenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.multireddit = ['test']
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
with pytest.raises(RedditAuthenticationError):
RedditDownloader._get_multireddits(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
('danigirl3694', 10),
('danigirl3694', 50),
('CapitanHam', None),
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.limit = limit
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.search = search_term
results = RedditDownloader._get_subreddits(downloader_mock)
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
assert all([res.subreddit.display_name for res in results])
if limit is not None:
assert len(results) == (limit * len(test_subreddits))
downloader_mock.args.submitted = True
downloader_mock.args.user = test_user
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
results = RedditDownloader._get_user_data(downloader_mock)
results = assert_all_results_are_submissions(limit, results)
assert all([res.author.name == test_user for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_subreddits_search_bad():
raise NotImplementedError
def test_get_user_no_user(downloader_mock: MagicMock):
with pytest.raises(BulkDownloaderException):
RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_multireddits():
raise NotImplementedError
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_submissions():
raise NotImplementedError
@pytest.mark.parametrize('test_user', (
'rockcanopicjartheme',
'exceptionalcatfishracecarbatter',
))
def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.args.user = test_user
downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence(
downloader_mock, test_user)
with pytest.raises(RedditUserError):
RedditDownloader._get_user_data(downloader_mock)
@pytest.mark.online
@@ -239,6 +308,13 @@ def test_get_user_upvoted():
raise NotImplementedError
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_upvoted_unauthenticated():
raise NotImplementedError
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
@@ -246,6 +322,13 @@ def test_get_user_saved():
raise NotImplementedError
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip
def test_get_user_saved_unauthenticated():
raise NotImplementedError
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skip