Add some tests for RedditDownloader

This commit is contained in:
Serene-Arc
2021-03-03 12:53:53 +10:00
committed by Ali Parlakci
parent ea42471932
commit 9e6ec9f1ca
2 changed files with 279 additions and 15 deletions

View File

@@ -13,10 +13,10 @@ import appdirs
import praw
import praw.models
from bulkredditdownloader.site_authenticator import SiteAuthenticator
from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_authenticator import SiteAuthenticator
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
logger = logging.getLogger(__name__)
@@ -107,7 +107,11 @@ class RedditDownloader:
if self.args.subreddit:
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
if self.args.search:
return [reddit.search(self.args.search, sort=self.sort_filter.name.lower()) for reddit in subreddits]
return [
reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit) for reddit in subreddits]
else:
sort_function = self._determine_sort_function()
return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits]
@@ -116,8 +120,8 @@ class RedditDownloader:
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for url in self.args.link:
supplied_submissions.append(self.reddit_instance.submission(url=url))
for sub_id in self.args.link:
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
return [supplied_submissions]
def _determine_sort_function(self):
@@ -162,29 +166,22 @@ class RedditDownloader:
return []
def _create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.set_filename, self.args.set_folderpath)
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
def _create_time_filter(self) -> RedditTypes.TimeType:
try:
return RedditTypes.TimeType[self.args.sort.upper()]
return RedditTypes.TimeType[self.args.time.upper()]
except (KeyError, AttributeError):
return RedditTypes.TimeType.ALL
def _create_sort_filter(self) -> RedditTypes.SortType:
try:
return RedditTypes.SortType[self.args.time.upper()]
return RedditTypes.SortType[self.args.sort.upper()]
except (KeyError, AttributeError):
return RedditTypes.SortType.HOT
def _create_download_filter(self) -> DownloadFilter:
formats = {
"videos": [".mp4", ".webm"],
"images": [".jpg", ".jpeg", ".png", ".bmp"],
"gifs": [".gif"],
"self": []
}
excluded_extensions = [extension for ext_type in self.args.skip for extension in formats.get(ext_type, ())]
return DownloadFilter(excluded_extensions, self.args.skip_domain)
return DownloadFilter(self.args.skip, self.args.skip_domain)
def _create_authenticator(self) -> SiteAuthenticator:
raise NotImplementedError