Add some more tests for RedditDownloader

This commit is contained in:
Serene-Arc
2021-03-05 13:31:40 +10:00
committed by Ali Parlakci
parent 6f86dbd552
commit b705c31630
2 changed files with 158 additions and 51 deletions

View File

@@ -8,13 +8,15 @@ import socket
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Iterator
import appdirs
import praw
import praw.models
import prawcore
import bulkredditdownloader.errors as errors
from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.errors import NotADownloadableLinkError, RedditAuthenticationError
from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.site_authenticator import SiteAuthenticator
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
@@ -54,6 +56,7 @@ class RedditDownloader:
self.sort_filter = self._create_sort_filter()
self.file_name_formatter = self._create_file_name_formatter()
self.authenticator = self._create_authenticator()
self._resolve_user_name()
self._determine_directories()
self._create_file_logger()
self.master_hash_list = []
@@ -118,6 +121,10 @@ class RedditDownloader:
else:
return []
def _resolve_user_name(self):
if self.args.user == 'me':
self.args.user = self.reddit_instance.user.me()
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
@@ -135,35 +142,52 @@ class RedditDownloader:
sort_function = praw.models.Subreddit.hot
return sort_function
def _get_multireddits(self) -> list[praw.models.ListingGenerator]:
def _get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
if self.authenticated:
return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in self.args.multireddit]
if self.args.user:
sort_function = self._determine_sort_function()
return [
sort_function(self.reddit_instance.multireddit(
self.args.user,
m_reddit_choice), limit=self.args.limit) for m_reddit_choice in self.args.multireddit]
else:
raise errors.BulkDownloaderException('A user must be provided to download a multireddit')
else:
raise RedditAuthenticationError('Accessing multireddits requires authentication')
raise errors.RedditAuthenticationError('Accessing multireddits requires authentication')
else:
return []
def _get_user_data(self) -> list[praw.models.ListingGenerator]:
if any((self.args.upvoted, self.args.submitted, self.args.saved)):
if self.authenticated:
generators = []
sort_function = self._determine_sort_function()
def _get_user_data(self) -> list[Iterator]:
if self.args.user:
if not self._check_user_existence(self.args.user):
raise errors.RedditUserError(f'User {self.args.user} does not exist')
generators = []
sort_function = self._determine_sort_function()
if self.args.submitted:
generators.append(
sort_function(
self.reddit_instance.redditor(self.args.user).submissions,
limit=self.args.limit))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
raise errors.RedditAuthenticationError('Accessing user lists requires authentication')
else:
if self.args.upvoted:
generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
if self.args.submitted:
generators.append(
sort_function(
self.reddit_instance.redditor(self.args.user).submissions,
limit=self.args.limit))
if self.args.saved:
generators.append(self.reddit_instance.redditor(self.args.user).saved)
return generators
else:
raise RedditAuthenticationError('Accessing user lists requires authentication')
return generators
else:
return []
raise errors.BulkDownloaderException('A user must be supplied to download user data')
def _check_user_existence(self, name: str) -> bool:
user = self.reddit_instance.redditor(name=name)
try:
if not user.id:
return False
except prawcore.exceptions.NotFound:
return False
return True
def _create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.set_file_scheme, self.args.set_folder_scheme)
@@ -198,10 +222,10 @@ class RedditDownloader:
try:
downloader_class = DownloadFactory.pull_lever(submission.url)
downloader = downloader_class(submission)
except NotADownloadableLinkError as e:
except errors.NotADownloadableLinkError as e:
logger.error('Could not download submission {}: {}'.format(submission.name, e))
return
if self.args.no_download:
logger.info('Skipping download for submission {}'.format(submission.id))
else: