Add some tests for RedditDownloader
This commit is contained in:
@@ -13,10 +13,10 @@ import appdirs
|
||||
import praw
|
||||
import praw.models
|
||||
|
||||
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
||||
from bulkredditdownloader.download_filter import DownloadFilter
|
||||
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
|
||||
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
||||
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
||||
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -107,7 +107,11 @@ class RedditDownloader:
|
||||
if self.args.subreddit:
|
||||
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
|
||||
if self.args.search:
|
||||
return [reddit.search(self.args.search, sort=self.sort_filter.name.lower()) for reddit in subreddits]
|
||||
return [
|
||||
reddit.search(
|
||||
self.args.search,
|
||||
sort=self.sort_filter.name.lower(),
|
||||
limit=self.args.limit) for reddit in subreddits]
|
||||
else:
|
||||
sort_function = self._determine_sort_function()
|
||||
return [sort_function(reddit, limit=self.args.limit) for reddit in subreddits]
|
||||
@@ -116,8 +120,8 @@ class RedditDownloader:
|
||||
|
||||
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
||||
supplied_submissions = []
|
||||
for url in self.args.link:
|
||||
supplied_submissions.append(self.reddit_instance.submission(url=url))
|
||||
for sub_id in self.args.link:
|
||||
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
|
||||
return [supplied_submissions]
|
||||
|
||||
def _determine_sort_function(self):
|
||||
@@ -162,29 +166,22 @@ class RedditDownloader:
|
||||
return []
|
||||
|
||||
def _create_file_name_formatter(self) -> FileNameFormatter:
|
||||
return FileNameFormatter(self.args.set_filename, self.args.set_folderpath)
|
||||
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
|
||||
|
||||
def _create_time_filter(self) -> RedditTypes.TimeType:
|
||||
try:
|
||||
return RedditTypes.TimeType[self.args.sort.upper()]
|
||||
return RedditTypes.TimeType[self.args.time.upper()]
|
||||
except (KeyError, AttributeError):
|
||||
return RedditTypes.TimeType.ALL
|
||||
|
||||
def _create_sort_filter(self) -> RedditTypes.SortType:
|
||||
try:
|
||||
return RedditTypes.SortType[self.args.time.upper()]
|
||||
return RedditTypes.SortType[self.args.sort.upper()]
|
||||
except (KeyError, AttributeError):
|
||||
return RedditTypes.SortType.HOT
|
||||
|
||||
def _create_download_filter(self) -> DownloadFilter:
|
||||
formats = {
|
||||
"videos": [".mp4", ".webm"],
|
||||
"images": [".jpg", ".jpeg", ".png", ".bmp"],
|
||||
"gifs": [".gif"],
|
||||
"self": []
|
||||
}
|
||||
excluded_extensions = [extension for ext_type in self.args.skip for extension in formats.get(ext_type, ())]
|
||||
return DownloadFilter(excluded_extensions, self.args.skip_domain)
|
||||
return DownloadFilter(self.args.skip, self.args.skip_domain)
|
||||
|
||||
def _create_authenticator(self) -> SiteAuthenticator:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user