Add some more tests for RedditDownloader
This commit is contained in:
@@ -8,13 +8,15 @@ import socket
|
||||
from datetime import datetime
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import appdirs
|
||||
import praw
|
||||
import praw.models
|
||||
import prawcore
|
||||
|
||||
import bulkredditdownloader.errors as errors
|
||||
from bulkredditdownloader.download_filter import DownloadFilter
|
||||
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
|
||||
from bulkredditdownloader.file_name_formatter import FileNameFormatter
|
||||
from bulkredditdownloader.site_authenticator import SiteAuthenticator
|
||||
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
|
||||
@@ -54,6 +56,7 @@ class RedditDownloader:
|
||||
self.sort_filter = self._create_sort_filter()
|
||||
self.file_name_formatter = self._create_file_name_formatter()
|
||||
self.authenticator = self._create_authenticator()
|
||||
self._resolve_user_name()
|
||||
self._determine_directories()
|
||||
self._create_file_logger()
|
||||
self.master_hash_list = []
|
||||
@@ -118,6 +121,10 @@ class RedditDownloader:
|
||||
else:
|
||||
return []
|
||||
|
||||
def _resolve_user_name(self):
|
||||
if self.args.user == 'me':
|
||||
self.args.user = self.reddit_instance.user.me()
|
||||
|
||||
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
|
||||
supplied_submissions = []
|
||||
for sub_id in self.args.link:
|
||||
@@ -135,35 +142,52 @@ class RedditDownloader:
|
||||
sort_function = praw.models.Subreddit.hot
|
||||
return sort_function
|
||||
|
||||
def _get_multireddits(self) -> list[praw.models.ListingGenerator]:
|
||||
def _get_multireddits(self) -> list[Iterator]:
|
||||
if self.args.multireddit:
|
||||
if self.authenticated:
|
||||
return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in self.args.multireddit]
|
||||
if self.args.user:
|
||||
sort_function = self._determine_sort_function()
|
||||
return [
|
||||
sort_function(self.reddit_instance.multireddit(
|
||||
self.args.user,
|
||||
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
|
||||
else:
|
||||
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
|
||||
else:
|
||||
raise RedditAuthenticationError('Accessing multireddits requires authentication')
|
||||
raise errors.RedditAuthenticationError('Accessing multireddits requires authentication')
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_user_data(self) -> list[praw.models.ListingGenerator]:
|
||||
if any((self.args.upvoted, self.args.submitted, self.args.saved)):
|
||||
if self.authenticated:
|
||||
generators = []
|
||||
sort_function = self._determine_sort_function()
|
||||
def _get_user_data(self) -> list[Iterator]:
|
||||
if self.args.user:
|
||||
if not self._check_user_existence(self.args.user):
|
||||
raise errors.RedditUserError(f'User {self.args.user} does not exist')
|
||||
generators = []
|
||||
sort_function = self._determine_sort_function()
|
||||
if self.args.submitted:
|
||||
generators.append(
|
||||
sort_function(
|
||||
self.reddit_instance.redditor(self.args.user).submissions,
|
||||
limit=self.args.limit))
|
||||
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
|
||||
raise errors.RedditAuthenticationError('Accessing user lists requires authentication')
|
||||
else:
|
||||
if self.args.upvoted:
|
||||
generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
|
||||
if self.args.submitted:
|
||||
generators.append(
|
||||
sort_function(
|
||||
self.reddit_instance.redditor(self.args.user).submissions,
|
||||
limit=self.args.limit))
|
||||
if self.args.saved:
|
||||
generators.append(self.reddit_instance.redditor(self.args.user).saved)
|
||||
|
||||
return generators
|
||||
else:
|
||||
raise RedditAuthenticationError('Accessing user lists requires authentication')
|
||||
return generators
|
||||
else:
|
||||
return []
|
||||
raise errors.BulkDownloaderException('A user must be supplied to download user data')
|
||||
|
||||
def _check_user_existence(self, name: str) -> bool:
|
||||
user = self.reddit_instance.redditor(name=name)
|
||||
try:
|
||||
if not user.id:
|
||||
return False
|
||||
except prawcore.exceptions.NotFound:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _create_file_name_formatter(self) -> FileNameFormatter:
|
||||
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
|
||||
@@ -198,10 +222,10 @@ class RedditDownloader:
|
||||
try:
|
||||
downloader_class = DownloadFactory.pull_lever(submission.url)
|
||||
downloader = downloader_class(submission)
|
||||
except NotADownloadableLinkError as e:
|
||||
except errors.NotADownloadableLinkError as e:
|
||||
logger.error('Could not download submission {}: {}'.format(submission.name, e))
|
||||
return
|
||||
|
||||
|
||||
if self.args.no_download:
|
||||
logger.info('Skipping download for submission {}'.format(submission.id))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user