Add some more tests for RedditDownloader
This commit is contained in:
@@ -8,13 +8,15 @@ import socket
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
import appdirs
|
import appdirs
|
||||||
import praw
|
import praw
|
||||||
import praw.models
|
import praw.models
|
||||||
|
import prawcore
|
||||||
|
|
||||||
|
import bulkredditdownloader.errors as errors
|
||||||
from bulkredditdownloader.download_filter import DownloadFilter
|
from bulkredditdownloader.download_filter import DownloadFilter
|
||||||
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
|
|
||||||
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
||||||
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
||||||
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
|
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
|
||||||
@@ -54,6 +56,7 @@ class RedditDownloader:
|
|||||||
self.sort_filter = self._create_sort_filter()
|
self.sort_filter = self._create_sort_filter()
|
||||||
self.file_name_formatter = self._create_file_name_formatter()
|
self.file_name_formatter = self._create_file_name_formatter()
|
||||||
self.authenticator = self._create_authenticator()
|
self.authenticator = self._create_authenticator()
|
||||||
|
self._resolve_user_name()
|
||||||
self._determine_directories()
|
self._determine_directories()
|
||||||
self._create_file_logger()
|
self._create_file_logger()
|
||||||
self.master_hash_list = []
|
self.master_hash_list = []
|
||||||
@@ -118,6 +121,10 @@ class RedditDownloader:
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _resolve_user_name(self):
|
||||||
|
if self.args.user == 'me':
|
||||||
|
self.args.user = self.reddit_instance.user.me()
|
||||||
|
|
||||||
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
||||||
supplied_submissions = []
|
supplied_submissions = []
|
||||||
for sub_id in self.args.link:
|
for sub_id in self.args.link:
|
||||||
@@ -135,35 +142,52 @@ class RedditDownloader:
|
|||||||
sort_function = praw.models.Subreddit.hot
|
sort_function = praw.models.Subreddit.hot
|
||||||
return sort_function
|
return sort_function
|
||||||
|
|
||||||
def _get_multireddits(self) -> list[praw.models.ListingGenerator]:
|
def _get_multireddits(self) -> list[Iterator]:
|
||||||
if self.args.multireddit:
|
if self.args.multireddit:
|
||||||
if self.authenticated:
|
if self.authenticated:
|
||||||
return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in self.args.multireddit]
|
if self.args.user:
|
||||||
|
sort_function = self._determine_sort_function()
|
||||||
|
return [
|
||||||
|
sort_function(self.reddit_instance.multireddit(
|
||||||
|
self.args.user,
|
||||||
|
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
|
||||||
|
else:
|
||||||
|
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
|
||||||
else:
|
else:
|
||||||
raise RedditAuthenticationError('Accessing multireddits requires authentication')
|
raise errors.RedditAuthenticationError('Accessing multireddits requires authentication')
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _get_user_data(self) -> list[praw.models.ListingGenerator]:
|
def _get_user_data(self) -> list[Iterator]:
|
||||||
if any((self.args.upvoted, self.args.submitted, self.args.saved)):
|
if self.args.user:
|
||||||
if self.authenticated:
|
if not self._check_user_existence(self.args.user):
|
||||||
generators = []
|
raise errors.RedditUserError(f'User {self.args.user} does not exist')
|
||||||
sort_function = self._determine_sort_function()
|
generators = []
|
||||||
|
sort_function = self._determine_sort_function()
|
||||||
|
if self.args.submitted:
|
||||||
|
generators.append(
|
||||||
|
sort_function(
|
||||||
|
self.reddit_instance.redditor(self.args.user).submissions,
|
||||||
|
limit=self.args.limit))
|
||||||
|
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
|
||||||
|
raise errors.RedditAuthenticationError('Accessing user lists requires authentication')
|
||||||
|
else:
|
||||||
if self.args.upvoted:
|
if self.args.upvoted:
|
||||||
generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
|
generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
|
||||||
if self.args.submitted:
|
|
||||||
generators.append(
|
|
||||||
sort_function(
|
|
||||||
self.reddit_instance.redditor(self.args.user).submissions,
|
|
||||||
limit=self.args.limit))
|
|
||||||
if self.args.saved:
|
if self.args.saved:
|
||||||
generators.append(self.reddit_instance.redditor(self.args.user).saved)
|
generators.append(self.reddit_instance.redditor(self.args.user).saved)
|
||||||
|
return generators
|
||||||
return generators
|
|
||||||
else:
|
|
||||||
raise RedditAuthenticationError('Accessing user lists requires authentication')
|
|
||||||
else:
|
else:
|
||||||
return []
|
raise errors.BulkDownloaderException('A user must be supplied to download user data')
|
||||||
|
|
||||||
|
def _check_user_existence(self, name: str) -> bool:
|
||||||
|
user = self.reddit_instance.redditor(name=name)
|
||||||
|
try:
|
||||||
|
if not user.id:
|
||||||
|
return False
|
||||||
|
except prawcore.exceptions.NotFound:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _create_file_name_formatter(self) -> FileNameFormatter:
|
def _create_file_name_formatter(self) -> FileNameFormatter:
|
||||||
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
|
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
|
||||||
@@ -198,10 +222,10 @@ class RedditDownloader:
|
|||||||
try:
|
try:
|
||||||
downloader_class = DownloadFactory.pull_lever(submission.url)
|
downloader_class = DownloadFactory.pull_lever(submission.url)
|
||||||
downloader = downloader_class(submission)
|
downloader = downloader_class(submission)
|
||||||
except NotADownloadableLinkError as e:
|
except errors.NotADownloadableLinkError as e:
|
||||||
logger.error('Could not download submission {}: {}'.format(submission.name, e))
|
logger.error('Could not download submission {}: {}'.format(submission.name, e))
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.args.no_download:
|
if self.args.no_download:
|
||||||
logger.info('Skipping download for submission {}'.format(submission.id))
|
logger.info('Skipping download for submission {}'.format(submission.id))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Iterator
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import praw
|
import praw
|
||||||
@@ -11,7 +12,7 @@ import pytest
|
|||||||
|
|
||||||
from bulkredditdownloader.download_filter import DownloadFilter
|
from bulkredditdownloader.download_filter import DownloadFilter
|
||||||
from bulkredditdownloader.downloader import RedditDownloader, RedditTypes
|
from bulkredditdownloader.downloader import RedditDownloader, RedditTypes
|
||||||
from bulkredditdownloader.errors import BulkDownloaderException
|
from bulkredditdownloader.errors import BulkDownloaderException, RedditAuthenticationError, RedditUserError
|
||||||
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
||||||
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
||||||
|
|
||||||
@@ -25,6 +26,7 @@ def args() -> argparse.Namespace:
|
|||||||
args.link = []
|
args.link = []
|
||||||
args.submitted = False
|
args.submitted = False
|
||||||
args.upvoted = False
|
args.upvoted = False
|
||||||
|
args.saved = False
|
||||||
args.subreddit = []
|
args.subreddit = []
|
||||||
args.multireddit = []
|
args.multireddit = []
|
||||||
args.user = None
|
args.user = None
|
||||||
@@ -48,6 +50,14 @@ def downloader_mock(args: argparse.Namespace):
|
|||||||
return mock_downloader
|
return mock_downloader
|
||||||
|
|
||||||
|
|
||||||
|
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
    """Flatten ``results`` and assert that every item is a ``praw.models.Submission``.

    Args:
        result_limit: Exact number of submissions expected across all
            iterables, or ``None`` to skip the count check.
        results: Iterables of submissions, one per queried source.

    Returns:
        The flattened list of submissions, so callers can make further assertions.
    """
    submissions = [submission for listing in results for submission in listing]
    assert all(isinstance(item, praw.models.Submission) for item in submissions), \
        'every result must be a praw.models.Submission'
    if result_limit is not None:
        assert len(submissions) == result_limit, \
            f'expected {result_limit} submissions, got {len(submissions)}'
    return submissions
|
||||||
|
|
||||||
|
|
||||||
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
|
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
|
||||||
downloader_mock.args.directory = tmp_path / 'test'
|
downloader_mock.args.directory = tmp_path / 'test'
|
||||||
RedditDownloader._determine_directories(downloader_mock)
|
RedditDownloader._determine_directories(downloader_mock)
|
||||||
@@ -172,17 +182,15 @@ def test_get_subreddit_normal(
|
|||||||
limit: int,
|
limit: int,
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit):
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
|
||||||
downloader_mock.args.limit = limit
|
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock.args.limit = limit
|
||||||
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
results = [sub for res in results for sub in res]
|
results = assert_all_results_are_submissions(
|
||||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
(limit * len(test_subreddits)) if limit else None, results)
|
||||||
assert all([res.subreddit.display_name for res in results])
|
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||||
if limit is not None:
|
|
||||||
assert len(results) == (limit * len(test_subreddits))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@@ -190,6 +198,7 @@ def test_get_subreddit_normal(
|
|||||||
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
|
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
|
||||||
(('Python',), 'scraper', 10),
|
(('Python',), 'scraper', 10),
|
||||||
(('Python',), '', 10),
|
(('Python',), '', 10),
|
||||||
|
(('Python',), 'djsdsgewef', 0),
|
||||||
))
|
))
|
||||||
def test_get_subreddit_search(
|
def test_get_subreddit_search(
|
||||||
test_subreddits: list[str],
|
test_subreddits: list[str],
|
||||||
@@ -197,39 +206,99 @@ def test_get_subreddit_search(
|
|||||||
limit: int,
|
limit: int,
|
||||||
downloader_mock: MagicMock,
|
downloader_mock: MagicMock,
|
||||||
reddit_instance: praw.Reddit):
|
reddit_instance: praw.Reddit):
|
||||||
downloader_mock.reddit_instance = reddit_instance
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock.args.limit = limit
|
||||||
|
downloader_mock.args.search = search_term
|
||||||
downloader_mock.args.subreddit = test_subreddits
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
|
results = assert_all_results_are_submissions(
|
||||||
|
(limit * len(test_subreddits)) if limit else None, results)
|
||||||
|
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
|
||||||
|
('helen_darten', ('cuteanimalpics',), 10),
|
||||||
|
('korfor', ('chess',), 100),
|
||||||
|
))
|
||||||
|
# Good sources at https://www.reddit.com/r/multihub/
|
||||||
|
def test_get_multireddits_public(
|
||||||
|
test_user: str,
|
||||||
|
test_multireddits: list[str],
|
||||||
|
limit: int,
|
||||||
|
reddit_instance: praw.Reddit,
|
||||||
|
downloader_mock: MagicMock):
|
||||||
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
|
downloader_mock.args.limit = limit
|
||||||
|
downloader_mock.args.multireddit = test_multireddits
|
||||||
|
downloader_mock.args.user = test_user
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
results = RedditDownloader._get_multireddits(downloader_mock)
|
||||||
|
assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
def test_get_multireddits_no_user(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
||||||
|
downloader_mock.args.multireddit = ['test']
|
||||||
|
with pytest.raises(BulkDownloaderException):
|
||||||
|
RedditDownloader._get_multireddits(downloader_mock)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
def test_get_multireddits_not_authenticated(downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
||||||
|
downloader_mock.args.multireddit = ['test']
|
||||||
|
downloader_mock.authenticated = False
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
with pytest.raises(RedditAuthenticationError):
|
||||||
|
RedditDownloader._get_multireddits(downloader_mock)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.parametrize(('test_user', 'limit'), (
|
||||||
|
('danigirl3694', 10),
|
||||||
|
('danigirl3694', 50),
|
||||||
|
('CapitanHam', None),
|
||||||
|
))
|
||||||
|
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
||||||
downloader_mock.args.limit = limit
|
downloader_mock.args.limit = limit
|
||||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
downloader_mock.args.search = search_term
|
downloader_mock.args.submitted = True
|
||||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
downloader_mock.args.user = test_user
|
||||||
results = [sub for res in results for sub in res]
|
downloader_mock.authenticated = False
|
||||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
assert all([res.subreddit.display_name for res in results])
|
results = RedditDownloader._get_user_data(downloader_mock)
|
||||||
if limit is not None:
|
results = assert_all_results_are_submissions(limit, results)
|
||||||
assert len(results) == (limit * len(test_subreddits))
|
assert all([res.author.name == test_user for res in results])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.skip
|
def test_get_user_no_user(downloader_mock: MagicMock):
|
||||||
def test_get_subreddits_search_bad():
|
with pytest.raises(BulkDownloaderException):
|
||||||
raise NotImplementedError
|
RedditDownloader._get_user_data(downloader_mock)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.skip
|
@pytest.mark.parametrize('test_user', (
|
||||||
def test_get_multireddits():
|
'rockcanopicjartheme',
|
||||||
raise NotImplementedError
|
'exceptionalcatfishracecarbatter',
|
||||||
|
))
|
||||||
|
def test_get_user_nonexistent_user(test_user: str, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
|
||||||
@pytest.mark.online
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
@pytest.mark.reddit
|
downloader_mock.args.user = test_user
|
||||||
@pytest.mark.skip
|
downloader_mock._check_user_existence.return_value = RedditDownloader._check_user_existence(
|
||||||
def test_get_user_submissions():
|
downloader_mock, test_user)
|
||||||
raise NotImplementedError
|
with pytest.raises(RedditUserError):
|
||||||
|
RedditDownloader._get_user_data(downloader_mock)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@@ -239,6 +308,13 @@ def test_get_user_upvoted():
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_get_user_upvoted_unauthenticated():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@@ -246,6 +322,13 @@ def test_get_user_saved():
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_get_user_saved_unauthenticated():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.online
|
@pytest.mark.online
|
||||||
@pytest.mark.reddit
|
@pytest.mark.reddit
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
|
|||||||
Reference in New Issue
Block a user