Add some tests for RedditDownloader
This commit is contained in:
@@ -13,10 +13,10 @@ import appdirs
|
|||||||
import praw
|
import praw
|
||||||
import praw.models
|
import praw.models
|
||||||
|
|
||||||
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
|
||||||
from bulkredditdownloader.download_filter import DownloadFilter
|
from bulkredditdownloader.download_filter import DownloadFilter
|
||||||
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
|
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
|
||||||
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
||||||
|
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
||||||
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
|
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -107,7 +107,11 @@ class RedditDownloader:
|
|||||||
if self.args.subreddit:
|
if self.args.subreddit:
|
||||||
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
|
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
|
||||||
if self.args.search:
|
if self.args.search:
|
||||||
return [reddit.search(self.args.search, sort=self.sort_filter.name.lower()) for reddit in subreddits]
|
return [
|
||||||
|
reddit.search(
|
||||||
|
self.args.search,
|
||||||
|
sort=self.sort_filter.name.lower(),
|
||||||
|
limit=self.args.limit) for reddit in subreddits]
|
||||||
else:
|
else:
|
||||||
sort_function = self._determine_sort_function()
|
sort_function = self._determine_sort_function()
|
||||||
return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits]
|
return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits]
|
||||||
@@ -116,8 +120,8 @@ class RedditDownloader:
|
|||||||
|
|
||||||
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
||||||
supplied_submissions = []
|
supplied_submissions = []
|
||||||
for url in self.args.link:
|
for sub_id in self.args.link:
|
||||||
supplied_submissions.append(self.reddit_instance.submission(url=url))
|
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
|
||||||
return [supplied_submissions]
|
return [supplied_submissions]
|
||||||
|
|
||||||
def _determine_sort_function(self):
|
def _determine_sort_function(self):
|
||||||
@@ -162,29 +166,22 @@ class RedditDownloader:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def _create_file_name_formatter(self) -> FileNameFormatter:
|
def _create_file_name_formatter(self) -> FileNameFormatter:
|
||||||
return FileNameFormatter(self.args.set_filename, self.args.set_folderpath)
|
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
|
||||||
|
|
||||||
def _create_time_filter(self) -> RedditTypes.TimeType:
|
def _create_time_filter(self) -> RedditTypes.TimeType:
|
||||||
try:
|
try:
|
||||||
return RedditTypes.TimeType[self.args.sort.upper()]
|
return RedditTypes.TimeType[self.args.time.upper()]
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
return RedditTypes.TimeType.ALL
|
return RedditTypes.TimeType.ALL
|
||||||
|
|
||||||
def _create_sort_filter(self) -> RedditTypes.SortType:
|
def _create_sort_filter(self) -> RedditTypes.SortType:
|
||||||
try:
|
try:
|
||||||
return RedditTypes.SortType[self.args.time.upper()]
|
return RedditTypes.SortType[self.args.sort.upper()]
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
return RedditTypes.SortType.HOT
|
return RedditTypes.SortType.HOT
|
||||||
|
|
||||||
def _create_download_filter(self) -> DownloadFilter:
|
def _create_download_filter(self) -> DownloadFilter:
|
||||||
formats = {
|
return DownloadFilter(self.args.skip, self.args.skip_domain)
|
||||||
"videos": [".mp4", ".webm"],
|
|
||||||
"images": [".jpg", ".jpeg", ".png", ".bmp"],
|
|
||||||
"gifs": [".gif"],
|
|
||||||
"self": []
|
|
||||||
}
|
|
||||||
excluded_extensions = [extension for ext_type in self.args.skip for extension in formats.get(ext_type, ())]
|
|
||||||
return DownloadFilter(excluded_extensions, self.args.skip_domain)
|
|
||||||
|
|
||||||
def _create_authenticator(self) -> SiteAuthenticator:
|
def _create_authenticator(self) -> SiteAuthenticator:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
267
bulkredditdownloader/tests/test_downloader.py
Normal file
267
bulkredditdownloader/tests/test_downloader.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import praw
|
||||||
|
import praw.models
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from bulkredditdownloader.download_filter import DownloadFilter
|
||||||
|
from bulkredditdownloader.downloader import RedditDownloader, RedditTypes
|
||||||
|
from bulkredditdownloader.errors import BulkDownloaderException
|
||||||
|
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
||||||
|
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def args() -> argparse.Namespace:
|
||||||
|
args = argparse.Namespace()
|
||||||
|
|
||||||
|
args.directory = '.'
|
||||||
|
args.verbose = 0
|
||||||
|
args.link = []
|
||||||
|
args.submitted = False
|
||||||
|
args.upvoted = False
|
||||||
|
args.subreddit = []
|
||||||
|
args.multireddit = []
|
||||||
|
args.user = None
|
||||||
|
args.search = None
|
||||||
|
args.sort = 'hot'
|
||||||
|
args.limit = None
|
||||||
|
args.time = 'all'
|
||||||
|
args.skip = []
|
||||||
|
args.skip_domain = []
|
||||||
|
args.set_folder_scheme = '{SUBREDDIT}'
|
||||||
|
args.set_file_scheme = '{REDDITOR}_{TITLE}_{POSTID}'
|
||||||
|
args.no_dupes = False
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def downloader_mock(args: argparse.Namespace):
|
||||||
|
mock_downloader = MagicMock()
|
||||||
|
mock_downloader.args = args
|
||||||
|
return mock_downloader
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
|
||||||
|
downloader_mock.args.directory = tmp_path / 'test'
|
||||||
|
RedditDownloader._determine_directories(downloader_mock)
|
||||||
|
|
||||||
|
assert Path(tmp_path / 'test').exists()
|
||||||
|
assert downloader_mock.logfile_directory == Path(tmp_path / 'test' / 'LOG_FILES')
|
||||||
|
assert downloader_mock.logfile_directory.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
|
||||||
|
([], []),
|
||||||
|
(['.test'], ['test.com']),
|
||||||
|
))
|
||||||
|
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
|
||||||
|
downloader_mock.args.skip = skip_extensions
|
||||||
|
downloader_mock.args.skip_domain = skip_domains
|
||||||
|
result = RedditDownloader._create_download_filter(downloader_mock)
|
||||||
|
|
||||||
|
assert isinstance(result, DownloadFilter)
|
||||||
|
assert result.excluded_domains == skip_domains
|
||||||
|
assert result.excluded_extensions == skip_extensions
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_time', 'expected'), (
|
||||||
|
('all', 'all'),
|
||||||
|
('hour', 'hour'),
|
||||||
|
('day', 'day'),
|
||||||
|
('week', 'week'),
|
||||||
|
('random', 'all'),
|
||||||
|
('', 'all'),
|
||||||
|
))
|
||||||
|
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
|
||||||
|
downloader_mock.args.time = test_time
|
||||||
|
result = RedditDownloader._create_time_filter(downloader_mock)
|
||||||
|
|
||||||
|
assert isinstance(result, RedditTypes.TimeType)
|
||||||
|
assert result.name.lower() == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_sort', 'expected'), (
|
||||||
|
('', 'hot'),
|
||||||
|
('hot', 'hot'),
|
||||||
|
('controversial', 'controversial'),
|
||||||
|
('new', 'new'),
|
||||||
|
))
|
||||||
|
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
|
||||||
|
downloader_mock.args.sort = test_sort
|
||||||
|
result = RedditDownloader._create_sort_filter(downloader_mock)
|
||||||
|
|
||||||
|
assert isinstance(result, RedditTypes.SortType)
|
||||||
|
assert result.name.lower() == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
|
||||||
|
('{POSTID}', '{SUBREDDIT}'),
|
||||||
|
('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
|
||||||
|
))
|
||||||
|
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
|
||||||
|
downloader_mock.args.set_file_scheme = test_file_scheme
|
||||||
|
downloader_mock.args.set_folder_scheme = test_folder_scheme
|
||||||
|
result = RedditDownloader._create_file_name_formatter(downloader_mock)
|
||||||
|
|
||||||
|
assert isinstance(result, FileNameFormatter)
|
||||||
|
assert result.file_format_string == test_file_scheme
|
||||||
|
assert result.directory_format_string == test_folder_scheme
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
|
||||||
|
('', ''),
|
||||||
|
('{POSTID}', ''),
|
||||||
|
('', '{SUBREDDIT}'),
|
||||||
|
('test', '{SUBREDDIT}'),
|
||||||
|
('{POSTID}', 'test'),
|
||||||
|
))
|
||||||
|
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
|
||||||
|
downloader_mock.args.set_file_scheme = test_file_scheme
|
||||||
|
downloader_mock.args.set_folder_scheme = test_folder_scheme
|
||||||
|
with pytest.raises(BulkDownloaderException):
|
||||||
|
RedditDownloader._create_file_name_formatter(downloader_mock)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_create_authenticator(downloader_mock: MagicMock):
|
||||||
|
result = RedditDownloader._create_authenticator(downloader_mock)
|
||||||
|
assert isinstance(result, SiteAuthenticator)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.parametrize('test_submission_ids', (
|
||||||
|
('lvpf4l',),
|
||||||
|
('lvpf4l', 'lvqnsn'),
|
||||||
|
('lvpf4l', 'lvqnsn', 'lvl9kd'),
|
||||||
|
))
|
||||||
|
def test_get_submissions_from_link(
|
||||||
|
test_submission_ids: list[str],
|
||||||
|
reddit_instance: praw.Reddit,
|
||||||
|
downloader_mock: MagicMock):
|
||||||
|
downloader_mock.args.link = test_submission_ids
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
results = RedditDownloader._get_submissions_from_link(downloader_mock)
|
||||||
|
assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
|
||||||
|
assert len(results[0]) == len(test_submission_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_load_config(downloader_mock: MagicMock):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
|
||||||
|
(('Futurology',), 10),
|
||||||
|
(('Futurology',), 20),
|
||||||
|
(('Futurology', 'Python'), 10),
|
||||||
|
(('Futurology',), 100),
|
||||||
|
(('Futurology',), 0),
|
||||||
|
))
|
||||||
|
def test_get_subreddit_normal(
|
||||||
|
test_subreddits: list[str],
|
||||||
|
limit: int,
|
||||||
|
downloader_mock: MagicMock,
|
||||||
|
reddit_instance: praw.Reddit):
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
|
downloader_mock.args.limit = limit
|
||||||
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
|
results = [sub for res in results for sub in res]
|
||||||
|
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||||
|
assert all([res.subreddit.display_name for res in results])
|
||||||
|
if limit is not None:
|
||||||
|
assert len(results) == (limit * len(test_subreddits))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
|
||||||
|
(('Python',), 'scraper', 10),
|
||||||
|
(('Python',), '', 10),
|
||||||
|
))
|
||||||
|
def test_get_subreddit_search(
|
||||||
|
test_subreddits: list[str],
|
||||||
|
search_term: str,
|
||||||
|
limit: int,
|
||||||
|
downloader_mock: MagicMock,
|
||||||
|
reddit_instance: praw.Reddit):
|
||||||
|
downloader_mock.reddit_instance = reddit_instance
|
||||||
|
downloader_mock.args.subreddit = test_subreddits
|
||||||
|
downloader_mock.args.limit = limit
|
||||||
|
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||||
|
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||||
|
downloader_mock.args.search = search_term
|
||||||
|
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||||
|
results = [sub for res in results for sub in res]
|
||||||
|
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||||
|
assert all([res.subreddit.display_name for res in results])
|
||||||
|
if limit is not None:
|
||||||
|
assert len(results) == (limit * len(test_subreddits))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_get_subreddits_search_bad():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_get_multireddits():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_get_user_submissions():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_get_user_upvoted():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_get_user_saved():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_download_submission():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_download_submission_file_exists():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.online
|
||||||
|
@pytest.mark.reddit
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_download_submission_hash_exists():
|
||||||
|
raise NotImplementedError
|
||||||
Reference in New Issue
Block a user