Rename module

Serene-Arc authored 2021-04-12 17:58:32 +10:00 (committed by Ali Parlakci)
parent 3da58dbd5d
commit bd9f276acc
57 changed files with 139 additions and 139 deletions

bdfr/__init__.py Normal file (0 lines)

bdfr/__main__.py Normal file (117 lines)

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
import logging
import sys
import click
from bdfr.archiver import Archiver
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
logger = logging.getLogger()
_common_options = [
click.argument('directory', type=str),
click.option('--config', type=str, default=None),
click.option('-v', '--verbose', default=None, count=True),
click.option('-l', '--link', multiple=True, default=None, type=str),
click.option('-s', '--subreddit', multiple=True, default=None, type=str),
click.option('-m', '--multireddit', multiple=True, default=None, type=str),
click.option('-L', '--limit', default=None, type=int),
click.option('--authenticate', is_flag=True, default=None),
click.option('--submitted', is_flag=True, default=None),
click.option('--upvoted', is_flag=True, default=None),
click.option('--saved', is_flag=True, default=None),
click.option('--search', default=None, type=str),
click.option('-u', '--user', type=str, default=None),
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new',
'controversial', 'rising', 'relevance')), default=None),
]
def _add_common_options(func):
for opt in _common_options:
func = opt(func)
return func
@click.group()
def cli():
pass
@cli.command('download')
@click.option('--exclude-id', default=None, multiple=True)
@click.option('--exclude-id-file', default=None, multiple=True)
@click.option('--file-scheme', default=None, type=str)
@click.option('--folder-scheme', default=None, type=str)
@click.option('--make-hard-links', is_flag=True, default=None)
@click.option('--max-wait-time', type=int, default=None)
@click.option('--no-dupes', is_flag=True, default=None)
@click.option('--search-existing', is_flag=True, default=None)
@click.option('--skip', default=None, multiple=True)
@click.option('--skip-domain', default=None, multiple=True)
@_add_common_options
@click.pass_context
def cli_download(context: click.Context, **_):
config = Configuration()
config.process_click_arguments(context)
setup_logging(config.verbose)
try:
reddit_downloader = RedditDownloader(config)
reddit_downloader.download()
except Exception:
logger.exception('Downloader exited unexpectedly')
raise
else:
logger.info('Program complete')
@cli.command('archive')
@_add_common_options
@click.option('--all-comments', is_flag=True, default=None)
@click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None)
@click.pass_context
def cli_archive(context: click.Context, **_):
config = Configuration()
config.process_click_arguments(context)
setup_logging(config.verbose)
try:
reddit_archiver = Archiver(config)
reddit_archiver.download()
except Exception:
        logger.exception('Archiver exited unexpectedly')
raise
else:
logger.info('Program complete')
def setup_logging(verbosity: int):
class StreamExceptionFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
result = not (record.levelno == logging.ERROR and record.exc_info)
return result
logger.setLevel(1)
stream = logging.StreamHandler(sys.stdout)
stream.addFilter(StreamExceptionFilter())
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
stream.setFormatter(formatter)
logger.addHandler(stream)
if verbosity <= 0:
stream.setLevel(logging.INFO)
elif verbosity == 1:
stream.setLevel(logging.DEBUG)
else:
stream.setLevel(9)
logging.getLogger('praw').setLevel(logging.CRITICAL)
logging.getLogger('prawcore').setLevel(logging.CRITICAL)
logging.getLogger('urllib3').setLevel(logging.CRITICAL)
if __name__ == '__main__':
cli()

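The `_common_options` list and `_add_common_options` above implement a shared-decorator pattern: the `download` and `archive` commands receive the same arguments by looping over the option decorators instead of stacking them on each command. A minimal, self-contained sketch of the same pattern (the `example` command and its two options are invented for illustration, not part of the BDFR source):

#!/usr/bin/env python3
import click

_shared_options = [
    click.argument('directory', type=str),
    click.option('-L', '--limit', type=int, default=None),
]

def _add_shared_options(func):
    # Apply each click decorator in turn; equivalent to stacking them
    # directly above the command function
    for opt in _shared_options:
        func = opt(func)
    return func

@click.command()
@_add_shared_options
def example(directory: str, limit: int):
    click.echo(f'directory={directory}, limit={limit}')

if __name__ == '__main__':
    example()
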
bdfr/archive_entry/__init__.py Normal file (2 lines)

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# coding=utf-8

bdfr/archive_entry/base_archive_entry.py Normal file (36 lines)

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# coding=utf-8
from abc import ABC, abstractmethod
from praw.models import Comment, Submission
class BaseArchiveEntry(ABC):
def __init__(self, source: (Comment, Submission)):
self.source = source
self.post_details: dict = {}
@abstractmethod
def compile(self) -> dict:
raise NotImplementedError
@staticmethod
def _convert_comment_to_dict(in_comment: Comment) -> dict:
out_dict = {
'author': in_comment.author.name if in_comment.author else 'DELETED',
'id': in_comment.id,
'score': in_comment.score,
'subreddit': in_comment.subreddit.display_name,
'submission': in_comment.submission.id,
'stickied': in_comment.stickied,
'body': in_comment.body,
'is_submitter': in_comment.is_submitter,
'created_utc': in_comment.created_utc,
'parent_id': in_comment.parent_id,
'replies': [],
}
in_comment.replies.replace_more(0)
for reply in in_comment.replies:
out_dict['replies'].append(BaseArchiveEntry._convert_comment_to_dict(reply))
return out_dict

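`_convert_comment_to_dict` recurses through `replies`, so an entire comment tree serialises to nested dicts that terminate in empty `replies` lists. An illustrative example of the resulting shape for one comment with a single reply (all field values here are invented placeholders, not real Reddit data):

example_comment = {
    'author': 'some_user',
    'id': 'abc1234',
    'score': 10,
    'subreddit': 'Python',
    'submission': 'xyz987',
    'stickied': False,
    'body': 'Parent comment',
    'is_submitter': False,
    'created_utc': 1618207112.0,
    'parent_id': 't3_xyz987',
    'replies': [
        # recursion terminates at leaf comments with empty 'replies'
        {'author': 'another_user', 'body': 'Child comment', 'replies': []},
    ],
}
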
bdfr/archive_entry/comment_archive_entry.py Normal file (21 lines)

@@ -0,0 +1,21 @@
#!/usr/bin/env python3
# coding=utf-8
import logging
import praw.models
from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
logger = logging.getLogger(__name__)
class CommentArchiveEntry(BaseArchiveEntry):
def __init__(self, comment: praw.models.Comment):
super(CommentArchiveEntry, self).__init__(comment)
def compile(self) -> dict:
self.source.refresh()
self.post_details = self._convert_comment_to_dict(self.source)
self.post_details['submission_title'] = self.source.submission.title
return self.post_details

bdfr/archive_entry/submission_archive_entry.py Normal file (47 lines)

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# coding=utf-8
import logging
import praw.models
from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
logger = logging.getLogger(__name__)
class SubmissionArchiveEntry(BaseArchiveEntry):
def __init__(self, submission: praw.models.Submission):
super(SubmissionArchiveEntry, self).__init__(submission)
def compile(self) -> dict:
comments = self._get_comments()
self._get_post_details()
out = self.post_details
out['comments'] = comments
return out
def _get_post_details(self):
self.post_details = {
'title': self.source.title,
'name': self.source.name,
'url': self.source.url,
'selftext': self.source.selftext,
'score': self.source.score,
'upvote_ratio': self.source.upvote_ratio,
'permalink': self.source.permalink,
'id': self.source.id,
'author': self.source.author.name if self.source.author else 'DELETED',
'link_flair_text': self.source.link_flair_text,
'num_comments': self.source.num_comments,
'over_18': self.source.over_18,
'created_utc': self.source.created_utc,
}
def _get_comments(self) -> list[dict]:
logger.debug(f'Retrieving full comment tree for submission {self.source.id}')
comments = []
self.source.comments.replace_more(0)
for top_level_comment in self.source.comments:
comments.append(self._convert_comment_to_dict(top_level_comment))
return comments

bdfr/archiver.py Normal file (96 lines)

@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# coding=utf-8
import json
import logging
import re
from typing import Iterator
import dict2xml
import praw.models
import yaml
from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
from bdfr.exceptions import ArchiverError
from bdfr.resource import Resource
logger = logging.getLogger(__name__)
class Archiver(RedditDownloader):
def __init__(self, args: Configuration):
super(Archiver, self).__init__(args)
def download(self):
for generator in self.reddit_lists:
for submission in generator:
logger.debug(f'Attempting to archive submission {submission.id}')
self._write_entry(submission)
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
if len(sub_id) == 6:
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
elif re.match(r'^\w{7}$', sub_id):
supplied_submissions.append(self.reddit_instance.comment(id=sub_id))
else:
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
return [supplied_submissions]
def _get_user_data(self) -> list[Iterator]:
results = super(Archiver, self)._get_user_data()
if self.args.user and self.args.all_comments:
sort = self._determine_sort_function()
logger.debug(f'Retrieving comments of user {self.args.user}')
results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit))
return results
@staticmethod
def _pull_lever_entry_factory(praw_item: (praw.models.Submission, praw.models.Comment)) -> BaseArchiveEntry:
if isinstance(praw_item, praw.models.Submission):
return SubmissionArchiveEntry(praw_item)
elif isinstance(praw_item, praw.models.Comment):
return CommentArchiveEntry(praw_item)
else:
raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}')
def _write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)):
archive_entry = self._pull_lever_entry_factory(praw_item)
if self.args.format == 'json':
self._write_entry_json(archive_entry)
elif self.args.format == 'xml':
self._write_entry_xml(archive_entry)
elif self.args.format == 'yaml':
self._write_entry_yaml(archive_entry)
else:
raise ArchiverError(f'Unknown format {self.args.format} given')
logger.info(f'Record for entry item {praw_item.id} written to disk')
def _write_entry_json(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', '.json')
content = json.dumps(entry.compile())
self._write_content_to_disk(resource, content)
def _write_entry_xml(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', '.xml')
content = dict2xml.dict2xml(entry.compile(), wrap='root')
self._write_content_to_disk(resource, content)
def _write_entry_yaml(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', '.yaml')
content = yaml.dump(entry.compile())
self._write_content_to_disk(resource, content)
def _write_content_to_disk(self, resource: Resource, content: str):
file_path = self.file_name_formatter.format_path(resource, self.download_directory)
file_path.parent.mkdir(exist_ok=True, parents=True)
with open(file_path, 'w') as file:
logger.debug(
f'Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}'
f' format at {file_path}')
file.write(content)

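Each `_write_entry_*` method pairs a `Resource` carrying the matching extension with one serializer. A quick sketch of what the three formats produce for a small dict, assuming the `dict2xml` and `pyyaml` packages imported above are installed:

import json
import dict2xml
import yaml

sample = {'id': 'abc123', 'score': 42}
print(json.dumps(sample))                      # {"id": "abc123", "score": 42}
print(dict2xml.dict2xml(sample, wrap='root'))  # <root><id>abc123</id>...</root>
print(yaml.dump(sample))                       # id: abc123 ...
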
bdfr/configuration.py Normal file (46 lines)

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
# coding=utf-8
from argparse import Namespace
from typing import Optional
import click
class Configuration(Namespace):
def __init__(self):
super(Configuration, self).__init__()
self.authenticate = False
self.config = None
self.directory: str = '.'
self.exclude_id = []
self.exclude_id_file = []
self.limit: Optional[int] = None
self.link: list[str] = []
self.max_wait_time = None
self.multireddit: list[str] = []
self.no_dupes: bool = False
self.saved: bool = False
self.search: Optional[str] = None
self.search_existing: bool = False
self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
self.folder_scheme: str = '{SUBREDDIT}'
self.skip: list[str] = []
self.skip_domain: list[str] = []
self.sort: str = 'hot'
self.submitted: bool = False
self.subreddit: list[str] = []
self.time: str = 'all'
self.upvoted: bool = False
self.user: Optional[str] = None
self.verbose: int = 0
self.make_hard_links = False
# Archiver-specific options
self.format = 'json'
self.all_comments = False
def process_click_arguments(self, context: click.Context):
for arg_key in context.params.keys():
if arg_key in vars(self) and context.params[arg_key] is not None:
vars(self)[arg_key] = context.params[arg_key]

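Because every click option in `__main__.py` defaults to `None`, `process_click_arguments` only overwrites a `Configuration` attribute when the user actually supplied a value, and unknown keys are ignored. A standalone sketch of that merge rule, with plain dicts standing in for the `Configuration` attributes and the `click.Context` params:

defaults = {'sort': 'hot', 'limit': None}               # Configuration defaults
cli_params = {'sort': None, 'limit': 50, 'unknown': 1}  # what click collected

for key, value in cli_params.items():
    if key in defaults and value is not None:
        defaults[key] = value

print(defaults)  # {'sort': 'hot', 'limit': 50} -- unset options keep defaults
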
bdfr/default_config.cfg Normal file (6 lines)

@@ -0,0 +1,6 @@
[DEFAULT]
client_id = U-6gk4ZCh3IeNQ
client_secret = 7CZHY6AmKweZME5s50SfDGylaPg
scopes = identity, history, read, save
backup_log_count = 3
max_wait_time = 120

bdfr/download_filter.py Normal file (44 lines)

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# coding=utf-8
import logging
import re
logger = logging.getLogger(__name__)
class DownloadFilter:
def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None):
self.excluded_extensions = excluded_extensions
self.excluded_domains = excluded_domains
def check_url(self, url: str) -> bool:
"""Return whether a URL is allowed or not"""
if not self._check_extension(url):
return False
elif not self._check_domain(url):
return False
else:
return True
def _check_extension(self, url: str) -> bool:
if not self.excluded_extensions:
return True
combined_extensions = '|'.join(self.excluded_extensions)
pattern = re.compile(r'.*({})$'.format(combined_extensions))
if re.match(pattern, url):
logger.log(9, f'Url "{url}" matched with "{str(pattern)}"')
return False
else:
return True
def _check_domain(self, url: str) -> bool:
if not self.excluded_domains:
return True
combined_domains = '|'.join(self.excluded_domains)
pattern = re.compile(r'https?://.*({}).*'.format(combined_domains))
if re.match(pattern, url):
logger.log(9, f'Url "{url}" matched with "{str(pattern)}"')
return False
else:
return True

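A usage sketch of the filter (assuming `from bdfr.download_filter import DownloadFilter`): URLs are rejected when they end in an excluded extension or are served from an excluded domain, and a filter built with no arguments allows everything.

test_filter = DownloadFilter(
    excluded_extensions=['mp4', 'gif'],
    excluded_domains=['youtube.com'],
)
assert test_filter.check_url('https://example.com/picture.png') is True
assert test_filter.check_url('https://example.com/video.mp4') is False
assert test_filter.check_url('https://youtube.com/watch?v=abc123') is False
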
bdfr/downloader.py Normal file (437 lines)

@@ -0,0 +1,437 @@
#!/usr/bin/env python3
# coding=utf-8
import configparser
import hashlib
import importlib.resources
import logging
import logging.handlers
import os
import re
import shutil
import socket
from datetime import datetime
from enum import Enum, auto
from multiprocessing import Pool
from pathlib import Path
from typing import Iterator
import appdirs
import praw
import praw.exceptions
import praw.models
import prawcore
import bdfr.exceptions as errors
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.download_factory import DownloadFactory
logger = logging.getLogger(__name__)
def _calc_hash(existing_file: Path):
with open(existing_file, 'rb') as file:
file_hash = hashlib.md5(file.read()).hexdigest()
return existing_file, file_hash
class RedditTypes:
class SortType(Enum):
HOT = auto()
RISING = auto()
CONTROVERSIAL = auto()
NEW = auto()
        RELEVANCE = auto()
class TimeType(Enum):
HOUR = auto()
DAY = auto()
WEEK = auto()
MONTH = auto()
YEAR = auto()
ALL = auto()
class RedditDownloader:
def __init__(self, args: Configuration):
self.args = args
self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
self.run_time = datetime.now().isoformat()
self._setup_internal_objects()
self.reddit_lists = self._retrieve_reddit_lists()
def _setup_internal_objects(self):
self._determine_directories()
self._load_config()
self._create_file_logger()
self._read_config()
self.download_filter = self._create_download_filter()
logger.log(9, 'Created download filter')
self.time_filter = self._create_time_filter()
logger.log(9, 'Created time filter')
self.sort_filter = self._create_sort_filter()
logger.log(9, 'Created sort filter')
self.file_name_formatter = self._create_file_name_formatter()
        logger.log(9, 'Created file name formatter')
self._create_reddit_instance()
self._resolve_user_name()
self.excluded_submission_ids = self._read_excluded_ids()
if self.args.search_existing:
self.master_hash_list = self.scan_existing_files(self.download_directory)
else:
self.master_hash_list = {}
self.authenticator = self._create_authenticator()
logger.log(9, 'Created site authenticator')
def _read_config(self):
"""Read any cfg values that need to be processed"""
if self.args.max_wait_time is None:
if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'):
self.cfg_parser.set('DEFAULT', 'max_wait_time', '120')
                logger.log(9, 'Wrote default download wait time to config file')
self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time')
logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
# Update config on disk
with open(self.config_location, 'w') as file:
self.cfg_parser.write(file)
def _create_reddit_instance(self):
if self.args.authenticate:
logger.debug('Using authenticated Reddit instance')
if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
logger.log(9, 'Commencing OAuth2 authentication')
scopes = self.cfg_parser.get('DEFAULT', 'scopes')
scopes = OAuth2Authenticator.split_scopes(scopes)
oauth2_authenticator = OAuth2Authenticator(
scopes,
self.cfg_parser.get('DEFAULT', 'client_id'),
self.cfg_parser.get('DEFAULT', 'client_secret'),
)
token = oauth2_authenticator.retrieve_new_token()
self.cfg_parser['DEFAULT']['user_token'] = token
with open(self.config_location, 'w') as file:
self.cfg_parser.write(file, True)
token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)
self.authenticated = True
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(),
token_manager=token_manager,
)
else:
logger.debug('Using unauthenticated Reddit instance')
self.authenticated = False
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(),
)
def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
master_list = []
master_list.extend(self._get_subreddits())
logger.log(9, 'Retrieved subreddits')
master_list.extend(self._get_multireddits())
logger.log(9, 'Retrieved multireddits')
master_list.extend(self._get_user_data())
logger.log(9, 'Retrieved user data')
master_list.extend(self._get_submissions_from_link())
logger.log(9, 'Retrieved submissions for given links')
return master_list
def _determine_directories(self):
self.download_directory = Path(self.args.directory).resolve().expanduser()
self.config_directory = Path(self.config_directories.user_config_dir)
self.download_directory.mkdir(exist_ok=True, parents=True)
self.config_directory.mkdir(exist_ok=True, parents=True)
def _load_config(self):
self.cfg_parser = configparser.ConfigParser()
if self.args.config:
if (cfg_path := Path(self.args.config)).exists():
self.cfg_parser.read(cfg_path)
self.config_location = cfg_path
return
possible_paths = [
Path('./config.cfg'),
Path('./default_config.cfg'),
Path(self.config_directory, 'config.cfg'),
Path(self.config_directory, 'default_config.cfg'),
]
self.config_location = None
for path in possible_paths:
if path.resolve().expanduser().exists():
self.config_location = path
logger.debug(f'Loading configuration from {path}')
break
if not self.config_location:
self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0]
shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
if not self.config_location:
raise errors.BulkDownloaderException('Could not find a configuration file to load')
self.cfg_parser.read(self.config_location)
def _create_file_logger(self):
main_logger = logging.getLogger()
log_path = Path(self.config_directory, 'log_output.txt')
backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
file_handler = logging.handlers.RotatingFileHandler(
log_path,
mode='a',
backupCount=backup_count,
)
if log_path.exists():
file_handler.doRollover()
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
file_handler.setFormatter(formatter)
file_handler.setLevel(0)
main_logger.addHandler(file_handler)
@staticmethod
def _sanitise_subreddit_name(subreddit: str) -> str:
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
match = re.match(pattern, subreddit)
if not match:
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
return match.group(1)
@staticmethod
def _split_args_input(subreddit_entries: list[str]) -> set[str]:
all_subreddits = []
split_pattern = re.compile(r'[,;]\s?')
for entry in subreddit_entries:
results = re.split(split_pattern, entry)
all_subreddits.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results])
return set(all_subreddits)
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
if self.args.subreddit:
out = []
sort_function = self._determine_sort_function()
for reddit in self._split_args_input(self.args.subreddit):
try:
reddit = self.reddit_instance.subreddit(reddit)
if self.args.search:
out.append(
reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
))
logger.debug(
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
else:
out.append(sort_function(reddit, limit=self.args.limit))
logger.debug(f'Added submissions from subreddit {reddit}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
return out
else:
return []
def _resolve_user_name(self):
if self.args.user == 'me':
if self.authenticated:
self.args.user = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {self.args.user}')
else:
self.args.user = None
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
if len(sub_id) == 6:
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
else:
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
return [supplied_submissions]
def _determine_sort_function(self):
if self.sort_filter is RedditTypes.SortType.NEW:
sort_function = praw.models.Subreddit.new
elif self.sort_filter is RedditTypes.SortType.RISING:
sort_function = praw.models.Subreddit.rising
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
sort_function = praw.models.Subreddit.controversial
else:
sort_function = praw.models.Subreddit.hot
return sort_function
def _get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
out = []
sort_function = self._determine_sort_function()
for multi in self._split_args_input(self.args.multireddit):
try:
multi = self.reddit_instance.multireddit(self.args.user, multi)
if not multi.subreddits:
raise errors.BulkDownloaderException
out.append(sort_function(multi, limit=self.args.limit))
logger.debug(f'Added submissions from multireddit {multi}')
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
return out
else:
return []
def _get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if self.args.user:
if not self._check_user_existence(self.args.user):
logger.error(f'User {self.args.user} does not exist')
return []
generators = []
sort_function = self._determine_sort_function()
if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
generators.append(
sort_function(
self.reddit_instance.redditor(self.args.user).submissions,
limit=self.args.limit,
))
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication')
else:
if self.args.upvoted:
logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit))
if self.args.saved:
logger.debug(f'Retrieving saved posts of user {self.args.user}')
generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit))
return generators
else:
logger.warning('A user must be supplied to download user data')
return []
else:
return []
def _check_user_existence(self, name: str) -> bool:
user = self.reddit_instance.redditor(name=name)
try:
if not user.id:
return False
except prawcore.exceptions.NotFound:
return False
return True
def _create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme)
def _create_time_filter(self) -> RedditTypes.TimeType:
try:
return RedditTypes.TimeType[self.args.time.upper()]
except (KeyError, AttributeError):
return RedditTypes.TimeType.ALL
def _create_sort_filter(self) -> RedditTypes.SortType:
try:
return RedditTypes.SortType[self.args.sort.upper()]
except (KeyError, AttributeError):
return RedditTypes.SortType.HOT
def _create_download_filter(self) -> DownloadFilter:
return DownloadFilter(self.args.skip, self.args.skip_domain)
def _create_authenticator(self) -> SiteAuthenticator:
return SiteAuthenticator(self.cfg_parser)
def download(self):
for generator in self.reddit_lists:
for submission in generator:
if submission.id in self.excluded_submission_ids:
logger.debug(f'Submission {submission.id} in exclusion list, skipping')
continue
else:
logger.debug(f'Attempting to download submission {submission.id}')
self._download_submission(submission)
def _download_submission(self, submission: praw.models.Submission):
if not isinstance(submission, praw.models.Submission):
logger.warning(f'{submission.id} is not a submission')
return
if not self.download_filter.check_url(submission.url):
logger.debug(f'Download filter removed submission {submission.id} with URL {submission.url}')
return
try:
downloader_class = DownloadFactory.pull_lever(submission.url)
downloader = downloader_class(submission)
logger.debug(f'Using {downloader_class.__name__} with url {submission.url}')
except errors.NotADownloadableLinkError as e:
logger.error(f'Could not download submission {submission.id}: {e}')
return
try:
content = downloader.find_resources(self.authenticator)
except errors.SiteDownloaderError as e:
logger.error(f'Site {downloader_class.__name__} failed to download submission {submission.id}: {e}')
return
for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
if destination.exists():
logger.debug(f'File {destination} already exists, continuing')
else:
try:
res.download(self.args.max_wait_time)
except errors.BulkDownloaderException as e:
logger.error(
f'Failed to download resource {res.url} with downloader {downloader_class.__name__}: {e}')
return
resource_hash = res.hash.hexdigest()
destination.parent.mkdir(parents=True, exist_ok=True)
if resource_hash in self.master_hash_list:
if self.args.no_dupes:
logger.info(
f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere')
return
elif self.args.make_hard_links:
self.master_hash_list[resource_hash].link_to(destination)
logger.info(
f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}')
return
with open(destination, 'wb') as file:
file.write(res.content)
logger.debug(f'Written file to {destination}')
self.master_hash_list[resource_hash] = destination
logger.debug(f'Hash added to master list: {resource_hash}')
logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}')
@staticmethod
def scan_existing_files(directory: Path) -> dict[str, Path]:
files = []
for (dirpath, dirnames, filenames) in os.walk(directory):
files.extend([Path(dirpath, file) for file in filenames])
logger.info(f'Calculating hashes for {len(files)} files')
pool = Pool(15)
results = pool.map(_calc_hash, files)
pool.close()
hash_list = {res[1]: res[0] for res in results}
return hash_list
def _read_excluded_ids(self) -> set[str]:
out = []
out.extend(self.args.exclude_id)
for id_file in self.args.exclude_id_file:
id_file = Path(id_file).resolve().expanduser()
if not id_file.exists():
logger.warning(f'ID exclusion file at {id_file} does not exist')
continue
with open(id_file, 'r') as file:
for line in file:
out.append(line.strip())
return set(out)

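`scan_existing_files` underpins `--search-existing` and `--no-dupes`: every file under the download directory is MD5-hashed in a worker pool, and new downloads are compared against the resulting `master_hash_list`. A standalone sketch of the same hash-then-index step, without the multiprocessing pool:

import hashlib
from pathlib import Path

def calc_hash(existing_file: Path) -> tuple[Path, str]:
    # Same idea as _calc_hash above: MD5 over the whole file contents
    with open(existing_file, 'rb') as file:
        return existing_file, hashlib.md5(file.read()).hexdigest()

files = [path for path in Path('.').iterdir() if path.is_file()]
hash_list = {digest: path for path, digest in map(calc_hash, files)}
print(f'Indexed {len(hash_list)} unique files by MD5')
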
bdfr/exceptions.py Normal file (28 lines)

@@ -0,0 +1,28 @@
#!/usr/bin/env python3
class BulkDownloaderException(Exception):
pass
class RedditUserError(BulkDownloaderException):
pass
class RedditAuthenticationError(RedditUserError):
pass
class ArchiverError(BulkDownloaderException):
pass
class SiteDownloaderError(BulkDownloaderException):
pass
class NotADownloadableLinkError(SiteDownloaderError):
pass
class ResourceNotFound(SiteDownloaderError):
pass

bdfr/file_name_formatter.py Normal file (158 lines)

@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# coding=utf-8
import logging
import platform
import re
from pathlib import Path
from typing import Optional
from praw.models import Comment, Submission
from bdfr.exceptions import BulkDownloaderException
from bdfr.resource import Resource
logger = logging.getLogger(__name__)
class FileNameFormatter:
key_terms = (
'date',
'flair',
'postid',
'redditor',
'subreddit',
'title',
'upvotes',
)
def __init__(self, file_format_string: str, directory_format_string: str):
if not self.validate_string(file_format_string):
raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string')
self.file_format_string = file_format_string
self.directory_format_string: list[str] = directory_format_string.split('/')
@staticmethod
def _format_name(submission: (Comment, Submission), format_string: str) -> str:
if isinstance(submission, Submission):
attributes = FileNameFormatter._generate_name_dict_from_submission(submission)
elif isinstance(submission, Comment):
attributes = FileNameFormatter._generate_name_dict_from_comment(submission)
else:
raise BulkDownloaderException(f'Cannot name object {type(submission).__name__}')
result = format_string
for key in attributes.keys():
if re.search(fr'(?i).*{{{key}}}.*', result):
key_value = attributes.get(key, 'unknown')
key_value = bytes(key_value, 'utf-8').decode('unicode-escape')
result = re.sub(fr'(?i){{{key}}}', key_value, result,)
logger.log(9, f'Found key string {key} in name')
result = result.replace('/', '')
if platform.system() == 'Windows':
result = FileNameFormatter._format_for_windows(result)
return result
@staticmethod
def _generate_name_dict_from_submission(submission: Submission) -> dict:
submission_attributes = {
'title': submission.title,
'subreddit': submission.subreddit.display_name,
'redditor': submission.author.name if submission.author else 'DELETED',
'postid': submission.id,
'upvotes': submission.score,
'flair': submission.link_flair_text,
'date': submission.created_utc
}
return submission_attributes
@staticmethod
def _generate_name_dict_from_comment(comment: Comment) -> dict:
comment_attributes = {
'title': comment.submission.title,
'subreddit': comment.subreddit.display_name,
'redditor': comment.author.name if comment.author else 'DELETED',
'postid': comment.id,
'upvotes': comment.score,
'flair': '',
'date': comment.created_utc,
}
return comment_attributes
def format_path(
self,
resource: Resource,
destination_directory: Path,
index: Optional[int] = None,
) -> Path:
subfolder = Path(
destination_directory,
*[self._format_name(resource.source_submission, part) for part in self.directory_format_string]
)
index = f'_{str(index)}' if index else ''
if not resource.extension:
raise BulkDownloaderException(f'Resource from {resource.url} has no extension')
ending = index + resource.extension
file_name = str(self._format_name(resource.source_submission, self.file_format_string))
file_name = self._limit_file_name_length(file_name, ending)
try:
file_path = Path(subfolder, file_name)
except TypeError:
raise BulkDownloaderException(f'Could not determine path name: {subfolder}, {index}, {resource.extension}')
return file_path
@staticmethod
def _limit_file_name_length(filename: str, ending: str) -> str:
possible_id = re.search(r'((?:_\w{6})?$)', filename)
if possible_id:
ending = possible_id.group(1) + ending
filename = filename[:possible_id.start()]
max_length_chars = 255 - len(ending)
max_length_bytes = 255 - len(ending.encode('utf-8'))
while len(filename) > max_length_chars or len(filename.encode('utf-8')) > max_length_bytes:
filename = filename[:-1]
return filename + ending
def format_resource_paths(
self,
resources: list[Resource],
destination_directory: Path,
) -> list[tuple[Path, Resource]]:
out = []
if len(resources) == 1:
out.append((self.format_path(resources[0], destination_directory, None), resources[0]))
else:
for i, res in enumerate(resources, start=1):
logger.log(9, f'Formatting filename with index {i}')
out.append((self.format_path(res, destination_directory, i), res))
return out
@staticmethod
def validate_string(test_string: str) -> bool:
if not test_string:
return False
result = any([f'{{{key}}}' in test_string.lower() for key in FileNameFormatter.key_terms])
if result:
if 'POSTID' not in test_string:
logger.warning(
'Some files might not be downloaded due to name conflicts as filenames are'
                    ' not guaranteed to be unique without {POSTID}')
return True
else:
return False
@staticmethod
def _format_for_windows(input_string: str) -> str:
invalid_characters = r'<>:"\/|?*'
for char in invalid_characters:
input_string = input_string.replace(char, '')
input_string = FileNameFormatter._strip_emojis(input_string)
return input_string
@staticmethod
def _strip_emojis(input_string: str) -> str:
result = input_string.encode('ascii', errors='ignore').decode('utf-8')
return result

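Two of the helpers above can be exercised without a praw object (assuming `from bdfr.file_name_formatter import FileNameFormatter`): `validate_string` requires at least one known `{KEY}` token, and `_limit_file_name_length` trims a long name while preserving the trailing ID suffix and the extension.

assert FileNameFormatter.validate_string('{REDDITOR}_{TITLE}_{POSTID}') is True
assert FileNameFormatter.validate_string('no format keys here') is False

# 300 characters plus a post-ID suffix is over the 255-character limit
long_name = 'a' * 300 + '_abc123'
trimmed = FileNameFormatter._limit_file_name_length(long_name, '.png')
print(len(trimmed), trimmed[-11:])  # 255 _abc123.png
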
bdfr/oauth2.py Normal file (107 lines)

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# coding=utf-8
import configparser
import logging
import random
import re
import socket
from pathlib import Path
import praw
import requests
from bdfr.exceptions import BulkDownloaderException, RedditAuthenticationError
logger = logging.getLogger(__name__)
class OAuth2Authenticator:
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
self._check_scopes(wanted_scopes)
self.scopes = wanted_scopes
self.client_id = client_id
self.client_secret = client_secret
@staticmethod
def _check_scopes(wanted_scopes: set[str]):
response = requests.get('https://www.reddit.com/api/v1/scopes.json',
headers={'User-Agent': 'fetch-scopes test'})
known_scopes = [scope for scope, data in response.json().items()]
known_scopes.append('*')
for scope in wanted_scopes:
if scope not in known_scopes:
raise BulkDownloaderException(f'Scope {scope} is not known to reddit')
@staticmethod
def split_scopes(scopes: str) -> set[str]:
scopes = re.split(r'[,: ]+', scopes)
return set(scopes)
def retrieve_new_token(self) -> str:
reddit = praw.Reddit(
redirect_uri='http://localhost:7634',
user_agent='obtain_refresh_token for BDFR',
client_id=self.client_id,
client_secret=self.client_secret)
state = str(random.randint(0, 65000))
url = reddit.auth.url(self.scopes, state, 'permanent')
logger.warning('Authentication action required before the program can proceed')
logger.warning(f'Authenticate at {url}')
client = self.receive_connection()
data = client.recv(1024).decode('utf-8')
param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&')
params = {key: value for (key, value) in [token.split('=') for token in param_tokens]}
if state != params['state']:
self.send_message(client)
raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}')
elif 'error' in params:
self.send_message(client)
raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
self.send_message(client, "<script>alert('You can go back to terminal window now.')</script>")
refresh_token = reddit.auth.authorize(params["code"])
return refresh_token
@staticmethod
def receive_connection() -> socket.socket:
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind(('localhost', 7634))
logger.log(9, 'Server listening on localhost:7634')
server.listen(1)
client = server.accept()[0]
server.close()
logger.log(9, 'Server closed')
return client
@staticmethod
    def send_message(client: socket.socket, message: str = ''):
client.send(f'HTTP/1.1 200 OK\r\n\r\n{message}'.encode('utf-8'))
client.close()
class OAuth2TokenManager(praw.reddit.BaseTokenManager):
def __init__(self, config: configparser.ConfigParser, config_location: Path):
super(OAuth2TokenManager, self).__init__()
self.config = config
self.config_location = config_location
def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
if authorizer.refresh_token is None:
if self.config.has_option('DEFAULT', 'user_token'):
authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
logger.log(9, 'Loaded OAuth2 token for authoriser')
else:
raise RedditAuthenticationError('No auth token loaded in configuration')
def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
with open(self.config_location, 'w') as file:
self.config.write(file, True)
logger.log(9, f'Written OAuth2 token from authoriser to {self.config_location}')

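A small usage sketch (assuming `from bdfr.oauth2 import OAuth2Authenticator`): scope strings read from the config file may be separated by commas, colons, or spaces, and `split_scopes` normalises them to a set before they are validated against Reddit's published scope list.

scopes = OAuth2Authenticator.split_scopes('identity, history read:save')
print(sorted(scopes))  # ['history', 'identity', 'read', 'save']
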
bdfr/resource.py Normal file (81 lines)

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# coding=utf-8
import hashlib
import logging
import re
import time
from typing import Optional
import _hashlib
import requests
from praw.models import Submission
from bdfr.exceptions import BulkDownloaderException
logger = logging.getLogger(__name__)
class Resource:
def __init__(self, source_submission: Submission, url: str, extension: str = None):
self.source_submission = source_submission
self.content: Optional[bytes] = None
self.url = url
self.hash: Optional[_hashlib.HASH] = None
self.extension = extension
if not self.extension:
self.extension = self._determine_extension()
@staticmethod
def retry_download(url: str, max_wait_time: int) -> Optional[bytes]:
wait_time = 60
try:
response = requests.get(url)
if response.status_code == 200:
return response.content
elif response.status_code in (
301,
400,
401,
403,
404,
407,
410,
500,
501,
502,
):
raise BulkDownloaderException(
f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
else:
raise requests.exceptions.ConnectionError(f'Response code {response.status_code}')
except requests.exceptions.ConnectionError as e:
            logger.warning(f'Error occurred downloading from {url}, waiting {wait_time} seconds: {e}')
time.sleep(wait_time)
            if wait_time < max_wait_time:
                # Recurse with a reduced time budget so retries terminate
                return Resource.retry_download(url, max_wait_time - wait_time)
else:
logger.error(f'Max wait time exceeded for resource at url {url}')
raise
def download(self, max_wait_time: int):
if not self.content:
try:
content = self.retry_download(self.url, max_wait_time)
except requests.exceptions.ConnectionError as e:
raise BulkDownloaderException(f'Could not download resource: {e}')
except BulkDownloaderException:
raise
if content:
self.content = content
if not self.hash and self.content:
self.create_hash()
def create_hash(self):
self.hash = hashlib.md5(self.content)
def _determine_extension(self) -> str:
extension_pattern = re.compile(r'.*(\..{3,5})(?:\?.*)?$')
match = re.search(extension_pattern, self.url)
if match:
return match.group(1)

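When no extension is supplied, `_determine_extension` parses it from the URL, ignoring any query string. A usage sketch (assuming `from bdfr.resource import Resource`; `None` stands in for a real praw `Submission`, which this path never touches):

res = Resource(None, 'https://example.com/picture.jpg?width=640')
print(res.extension)  # .jpg
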
bdfr/site_authenticator.py Normal file (9 lines)

@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# coding=utf-8
import configparser
class SiteAuthenticator:
def __init__(self, cfg: configparser.ConfigParser):
self.imgur_authentication = None

bdfr/site_downloaders/__init__.py Normal file (0 lines)

bdfr/site_downloaders/base_downloader.py Normal file (33 lines)

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# coding=utf-8
import logging
from abc import ABC, abstractmethod
from typing import Optional
import requests
from praw.models import Submission
from bdfr.exceptions import ResourceNotFound
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
logger = logging.getLogger(__name__)
class BaseDownloader(ABC):
def __init__(self, post: Submission, typical_extension: Optional[str] = None):
self.post = post
self.typical_extension = typical_extension
@abstractmethod
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
"""Return list of all un-downloaded Resources from submission"""
raise NotImplementedError
@staticmethod
def retrieve_url(url: str, cookies: dict = None, headers: dict = None) -> requests.Response:
res = requests.get(url, cookies=cookies, headers=headers)
if res.status_code != 200:
raise ResourceNotFound(f'Server responded with {res.status_code} to {url}')
return res

bdfr/site_downloaders/direct.py Normal file (17 lines)

@@ -0,0 +1,17 @@
#!/usr/bin/env python3
from typing import Optional
from praw.models import Submission
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.resource import Resource
from bdfr.site_downloaders.base_downloader import BaseDownloader
class Direct(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
return [Resource(self.post, self.post.url)]

bdfr/site_downloaders/download_factory.py Normal file (50 lines)

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# coding=utf-8
import re
from typing import Type
from bdfr.exceptions import NotADownloadableLinkError
from bdfr.site_downloaders.base_downloader import BaseDownloader
from bdfr.site_downloaders.direct import Direct
from bdfr.site_downloaders.erome import Erome
from bdfr.site_downloaders.gallery import Gallery
from bdfr.site_downloaders.gfycat import Gfycat
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork
from bdfr.site_downloaders.imgur import Imgur
from bdfr.site_downloaders.redgifs import Redgifs
from bdfr.site_downloaders.self_post import SelfPost
from bdfr.site_downloaders.vreddit import VReddit
from bdfr.site_downloaders.youtube import Youtube
class DownloadFactory:
@staticmethod
def pull_lever(url: str) -> Type[BaseDownloader]:
url_beginning = r'\s*(https?://(www\.)?)'
if re.match(url_beginning + r'(i\.)?imgur.*\.gifv$', url):
return Imgur
elif re.match(url_beginning + r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', url):
return Direct
elif re.match(url_beginning + r'erome\.com.*', url):
return Erome
elif re.match(url_beginning + r'reddit\.com/gallery/.*', url):
return Gallery
elif re.match(url_beginning + r'gfycat\.', url):
return Gfycat
elif re.match(url_beginning + r'gifdeliverynetwork', url):
return GifDeliveryNetwork
elif re.match(url_beginning + r'(m\.)?imgur.*', url):
return Imgur
elif re.match(url_beginning + r'redgifs.com', url):
return Redgifs
elif re.match(url_beginning + r'reddit\.com/r/', url):
return SelfPost
elif re.match(url_beginning + r'v\.redd\.it', url):
return VReddit
elif re.match(url_beginning + r'(m\.)?youtu\.?be', url):
return Youtube
elif re.match(url_beginning + r'i\.redd\.it.*', url):
return Direct
else:
raise NotADownloadableLinkError(f'No downloader module exists for url {url}')

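A usage sketch of the factory (assuming the imports shown at the top of the module, plus `from bdfr.exceptions import NotADownloadableLinkError`): `pull_lever` returns a downloader class, not an instance, and raises for unsupported sites.

assert DownloadFactory.pull_lever('https://i.redd.it/example.png') is Direct
assert DownloadFactory.pull_lever('https://gfycat.com/example') is Gfycat
try:
    DownloadFactory.pull_lever('https://unsupported.example.com/page')
except NotADownloadableLinkError as error:
    print(error)  # No downloader module exists for url ...
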
bdfr/site_downloaders/erome.py Normal file (45 lines)

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
import logging
import re
from typing import Optional
import bs4
from praw.models import Submission
from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
logger = logging.getLogger(__name__)
class Erome(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
links = self._get_links(self.post.url)
if not links:
raise SiteDownloaderError('Erome parser could not find any links')
out = []
for link in links:
if not re.match(r'https?://.*', link):
link = 'https://' + link
out.append(Resource(self.post, link))
return out
@staticmethod
def _get_links(url: str) -> set[str]:
page = Erome.retrieve_url(url)
soup = bs4.BeautifulSoup(page.text, 'html.parser')
front_images = soup.find_all('img', attrs={'class': 'lasyload'})
out = [im.get('data-src') for im in front_images]
videos = soup.find_all('source')
out.extend([vid.get('src') for vid in videos])
return set(out)

bdfr/site_downloaders/gallery.py Normal file (40 lines)

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
import logging
import re
from typing import Optional
import bs4
from praw.models import Submission
from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
logger = logging.getLogger(__name__)
class Gallery(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
image_urls = self._get_links(self.post.url)
if not image_urls:
raise SiteDownloaderError('No images found in Reddit gallery')
return [Resource(self.post, url) for url in image_urls]
@staticmethod
def _get_links(url: str) -> list[str]:
resource_headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'
' Chrome/67.0.3396.87 Safari/537.36 OPR/54.0.2952.64',
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8',
}
page = Gallery.retrieve_url(url, headers=resource_headers)
soup = bs4.BeautifulSoup(page.text, 'html.parser')
links = soup.findAll('a', attrs={'target': '_blank', 'href': re.compile(r'https://preview\.redd\.it.*')})
links = [link.get('href') for link in links]
return links

bdfr/site_downloaders/gfycat.py Normal file (41 lines)

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
import json
import re
from typing import Optional
from bs4 import BeautifulSoup
from praw.models import Submission
from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork
class Gfycat(GifDeliveryNetwork):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
return super().find_resources(authenticator)
@staticmethod
def _get_link(url: str) -> str:
gfycat_id = re.match(r'.*/(.*?)/?$', url).group(1)
url = 'https://gfycat.com/' + gfycat_id
response = Gfycat.retrieve_url(url)
if 'gifdeliverynetwork' in response.url:
return GifDeliveryNetwork._get_link(url)
soup = BeautifulSoup(response.text, 'html.parser')
content = soup.find('script', attrs={'data-react-helmet': 'true', 'type': 'application/ld+json'})
try:
out = json.loads(content.contents[0])['video']['contentUrl']
except (IndexError, KeyError) as e:
raise SiteDownloaderError(f'Failed to download Gfycat link {url}: {e}')
except json.JSONDecodeError as e:
raise SiteDownloaderError(f'Did not receive valid JSON data: {e}')
return out

bdfr/site_downloaders/gif_delivery_network.py Normal file (36 lines)

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
from typing import Optional
from bs4 import BeautifulSoup
from praw.models import Submission
from bdfr.exceptions import NotADownloadableLinkError, SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
class GifDeliveryNetwork(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
media_url = self._get_link(self.post.url)
return [Resource(self.post, media_url, '.mp4')]
@staticmethod
def _get_link(url: str) -> str:
page = GifDeliveryNetwork.retrieve_url(url)
soup = BeautifulSoup(page.text, 'html.parser')
content = soup.find('source', attrs={'id': 'mp4Source', 'type': 'video/mp4'})
try:
out = content['src']
if not out:
raise KeyError
except KeyError:
raise SiteDownloaderError('Could not find source link')
return out

bdfr/site_downloaders/imgur.py Normal file (79 lines)

@@ -0,0 +1,79 @@
#!/usr/bin/env python3
import json
import re
from typing import Optional
import bs4
from praw.models import Submission
from bdfr.exceptions import NotADownloadableLinkError, SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
class Imgur(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
self.raw_data = {}
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
self.raw_data = self._get_data(self.post.url)
out = []
if 'album_images' in self.raw_data:
images = self.raw_data['album_images']
for image in images['images']:
out.append(self._download_image(image))
else:
out.append(self._download_image(self.raw_data))
return out
def _download_image(self, image: dict) -> Resource:
image_url = 'https://i.imgur.com/' + image['hash'] + self._validate_extension(image['ext'])
return Resource(self.post, image_url)
@staticmethod
def _get_data(link: str) -> dict:
if re.match(r'.*\.gifv$', link):
link = link.replace('i.imgur', 'imgur')
            link = link.removesuffix('.gifv')  # rstrip would strip any trailing g/i/f/v characters
res = Imgur.retrieve_url(link, cookies={'over18': '1', 'postpagebeta': '0'})
soup = bs4.BeautifulSoup(res.text, 'html.parser')
scripts = soup.find_all('script', attrs={'type': 'text/javascript'})
scripts = [script.string.replace('\n', '') for script in scripts if script.string]
script_regex = re.compile(r'\s*\(function\(widgetFactory\)\s*{\s*widgetFactory\.mergeConfig\(\'gallery\'')
chosen_script = list(filter(lambda s: re.search(script_regex, s), scripts))
if len(chosen_script) != 1:
raise SiteDownloaderError(f'Could not read page source from {link}')
chosen_script = chosen_script[0]
outer_regex = re.compile(r'widgetFactory\.mergeConfig\(\'gallery\', ({.*})\);')
inner_regex = re.compile(r'image\s*:(.*),\s*group')
try:
image_dict = re.search(outer_regex, chosen_script).group(1)
image_dict = re.search(inner_regex, image_dict).group(1)
except AttributeError:
            raise SiteDownloaderError('Could not find image dictionary in page source')
try:
image_dict = json.loads(image_dict)
except json.JSONDecodeError as e:
raise SiteDownloaderError(f'Could not parse received dict as JSON: {e}')
return image_dict
@staticmethod
def _validate_extension(extension_suffix: str) -> str:
possible_extensions = ('.jpg', '.png', '.mp4', '.gif')
selection = [ext for ext in possible_extensions if ext == extension_suffix]
if len(selection) == 1:
return selection[0]
else:
raise SiteDownloaderError(f'"{extension_suffix}" is not recognized as a valid extension for Imgur')

bdfr/site_downloaders/redgifs.py Normal file (52 lines)

@@ -0,0 +1,52 @@
#!/usr/bin/env python3
import json
import re
from typing import Optional
from bs4 import BeautifulSoup
from praw.models import Submission
from bdfr.exceptions import NotADownloadableLinkError, SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork
class Redgifs(GifDeliveryNetwork):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
return super().find_resources(authenticator)
@staticmethod
def _get_link(url: str) -> str:
try:
redgif_id = re.match(r'.*/(.*?)/?$', url).group(1)
except AttributeError:
raise SiteDownloaderError(f'Could not extract Redgifs ID from {url}')
url = 'https://redgifs.com/watch/' + redgif_id
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'
' Chrome/67.0.3396.87 Safari/537.36 OPR/54.0.2952.64',
}
page = Redgifs.retrieve_url(url, headers=headers)
soup = BeautifulSoup(page.text, 'html.parser')
content = soup.find('script', attrs={'data-react-helmet': 'true', 'type': 'application/ld+json'})
if content is None:
raise SiteDownloaderError('Could not read the page source')
try:
out = json.loads(content.contents[0])['video']['contentUrl']
except (IndexError, KeyError):
raise SiteDownloaderError('Failed to find JSON data in page')
except json.JSONDecodeError as e:
raise SiteDownloaderError(f'Received data was not valid JSON: {e}')
return out

bdfr/site_downloaders/self_post.py Normal file (43 lines)

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
import logging
from typing import Optional
from praw.models import Submission
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
logger = logging.getLogger(__name__)
class SelfPost(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
out = Resource(self.post, self.post.url, '.txt')
out.content = self.export_to_string().encode('utf-8')
out.create_hash()
return [out]
def export_to_string(self) -> str:
"""Self posts are formatted here"""
content = ("## ["
+ self.post.fullname
+ "]("
+ self.post.url
+ ")\n"
+ self.post.selftext
+ "\n\n---\n\n"
+ "submitted to [r/"
+ self.post.subreddit.title
+ "](https://www.reddit.com/r/"
+ self.post.subreddit.title
+ ") by [u/"
+ (self.post.author.name if self.post.author else "DELETED")
+ "](https://www.reddit.com/user/"
+ (self.post.author.name if self.post.author else "DELETED")
+ ")")
return content

bdfr/site_downloaders/vreddit.py Normal file (21 lines)

@@ -0,0 +1,21 @@
#!/usr/bin/env python3
import logging
from typing import Optional
from praw.models import Submission
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.youtube import Youtube
logger = logging.getLogger(__name__)
class VReddit(Youtube):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
out = super()._download_video({})
return [out]

bdfr/site_downloaders/youtube.py Normal file (50 lines)

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
import logging
import tempfile
from pathlib import Path
from typing import Optional
import youtube_dl
from praw.models import Submission
from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
logger = logging.getLogger(__name__)
class Youtube(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
ytdl_options = {
'format': 'best',
'playlistend': 1,
'nooverwrites': True,
}
out = self._download_video(ytdl_options)
return [out]
def _download_video(self, ytdl_options: dict) -> Resource:
ytdl_options['quiet'] = True
with tempfile.TemporaryDirectory() as temp_dir:
download_path = Path(temp_dir).resolve()
ytdl_options['outtmpl'] = str(download_path) + '/' + 'test.%(ext)s'
try:
with youtube_dl.YoutubeDL(ytdl_options) as ydl:
ydl.download([self.post.url])
except youtube_dl.DownloadError as e:
raise SiteDownloaderError(f'Youtube download failed: {e}')
downloaded_file = list(download_path.iterdir())[0]
extension = downloaded_file.suffix
with open(downloaded_file, 'rb') as file:
content = file.read()
out = Resource(self.post, self.post.url, extension)
out.content = content
out.create_hash()
return out

bdfr/tests/__init__.py Normal file (0 lines)

bdfr/tests/archive_entry/__init__.py Normal file (2 lines)

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# coding=utf-8

bdfr/tests/archive_entry/test_comment_archive_entry.py Normal file (38 lines)

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# coding=utf-8
import praw
import pytest
from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_comment_id', 'expected_dict'), (
('gstd4hk', {
'author': 'james_pic',
'subreddit': 'Python',
'submission': 'mgi4op',
'submission_title': '76% Faster CPython',
}),
))
def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_instance: praw.Reddit):
comment = reddit_instance.comment(id=test_comment_id)
test_entry = CommentArchiveEntry(comment)
result = test_entry.compile()
assert all([result.get(key) == expected_dict[key] for key in expected_dict.keys()])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_comment_id', 'expected_min_comments'), (
('gstd4hk', 4),
('gsvyste', 3),
('gsxnvvb', 5),
))
def test_get_comment_replies(test_comment_id: str, expected_min_comments: int, reddit_instance: praw.Reddit):
comment = reddit_instance.comment(id=test_comment_id)
test_entry = CommentArchiveEntry(comment)
result = test_entry.compile()
assert len(result.get('replies')) >= expected_min_comments

bdfr/tests/archive_entry/test_submission_archive_entry.py Normal file (36 lines)

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# coding=utf-8
import praw
import pytest
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'min_comments'), (
('m3reby', 27),
))
def test_get_comments(test_submission_id: str, min_comments: int, reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id)
test_archive_entry = SubmissionArchiveEntry(test_submission)
results = test_archive_entry._get_comments()
assert len(results) >= min_comments
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_dict'), (
('m3reby', {
'author': 'sinjen-tos',
'id': 'm3reby',
'link_flair_text': 'image',
}),
('m3kua3', {'author': 'DELETED'}),
))
def test_get_post_details(test_submission_id: str, expected_dict: dict, reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id)
test_archive_entry = SubmissionArchiveEntry(test_submission)
test_archive_entry._get_post_details()
assert all([test_archive_entry.post_details.get(key) == expected_dict[key] for key in expected_dict.keys()])

34
bdfr/tests/conftest.py Normal file
View File

@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# coding=utf-8
import configparser
import socket
from pathlib import Path
import praw
import pytest
from bdfr.oauth2 import OAuth2TokenManager
@pytest.fixture(scope='session')
def reddit_instance():
rd = praw.Reddit(client_id='U-6gk4ZCh3IeNQ', client_secret='7CZHY6AmKweZME5s50SfDGylaPg', user_agent='test')
return rd
@pytest.fixture(scope='session')
def authenticated_reddit_instance():
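# OAuth2 integration tests need a refresh token in an optional test_config.cfg; skip when it is missing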
test_config_path = Path('test_config.cfg')
if not test_config_path.exists():
pytest.skip('Refresh token must be provided to authenticate with OAuth2')
cfg_parser = configparser.ConfigParser()
cfg_parser.read(test_config_path)
if not cfg_parser.has_option('DEFAULT', 'user_token'):
pytest.skip('Refresh token must be provided to authenticate with OAuth2')
token_manager = OAuth2TokenManager(cfg_parser, test_config_path)
reddit_instance = praw.Reddit(client_id=cfg_parser.get('DEFAULT', 'client_id'),
client_secret=cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(),
token_manager=token_manager)
return reddit_instance

View File

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import Mock
import pytest
from bdfr.resource import Resource
from bdfr.site_downloaders.direct import Direct
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
('https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4', '48f9bd4dbec1556d7838885612b13b39'),
('https://giant.gfycat.com/DazzlingSilkyIguana.mp4', '808941b48fc1e28713d36dd7ed9dc648'),
))
def test_download_resource(test_url: str, expected_hash: str):
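# A bare Mock with a url attribute is enough here; Direct only needs the submission URL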
mock_submission = Mock()
mock_submission.url = test_url
test_site = Direct(mock_submission)
resources = test_site.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
assert resources[0].hash.hexdigest() == expected_hash

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# coding=utf-8
import praw
import pytest
from bdfr.exceptions import NotADownloadableLinkError
from bdfr.site_downloaders.base_downloader import BaseDownloader
from bdfr.site_downloaders.direct import Direct
from bdfr.site_downloaders.download_factory import DownloadFactory
from bdfr.site_downloaders.erome import Erome
from bdfr.site_downloaders.gallery import Gallery
from bdfr.site_downloaders.gfycat import Gfycat
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork
from bdfr.site_downloaders.imgur import Imgur
from bdfr.site_downloaders.redgifs import Redgifs
from bdfr.site_downloaders.self_post import SelfPost
from bdfr.site_downloaders.vreddit import VReddit
from bdfr.site_downloaders.youtube import Youtube
@pytest.mark.parametrize(('test_submission_url', 'expected_class'), (
('https://v.redd.it/9z1dnk3xr5k61', VReddit),
('https://www.reddit.com/r/TwoXChromosomes/comments/lu29zn/i_refuse_to_live_my_life'
'_in_anything_but_comfort/', SelfPost),
('https://i.imgur.com/bZx1SJQ.jpg', Direct),
('https://i.redd.it/affyv0axd5k61.png', Direct),
('https://imgur.com/3ls94yv.jpeg', Direct),
('https://i.imgur.com/BuzvZwb.gifv', Imgur),
('https://imgur.com/BuzvZwb.gifv', Imgur),
('https://i.imgur.com/6fNdLst.gif', Direct),
('https://imgur.com/a/MkxAzeg', Imgur),
('https://www.reddit.com/gallery/lu93m7', Gallery),
('https://gfycat.com/concretecheerfulfinwhale', Gfycat),
('https://www.erome.com/a/NWGw0F09', Erome),
('https://youtube.com/watch?v=Gv8Wz74FjVA', Youtube),
('https://redgifs.com/watch/courageousimpeccablecanvasback', Redgifs),
('https://www.gifdeliverynetwork.com/repulsivefinishedandalusianhorse', GifDeliveryNetwork),
('https://youtu.be/DevfjHOhuFc', Youtube),
('https://m.youtube.com/watch?v=kr-FeojxzUM', Youtube),
('https://i.imgur.com/3SKrQfK.jpg?1', Direct),
('https://dynasty-scans.com/system/images_images/000/017/819/original/80215103_p0.png?1612232781', Direct),
('https://m.imgur.com/a/py3RW0j', Imgur),
))
def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownloader, reddit_instance: praw.Reddit):
result = DownloadFactory.pull_lever(test_submission_url)
assert result is expected_class
@pytest.mark.parametrize('test_url', (
'random.com',
'bad',
'https://www.google.com/',
'https://www.google.com',
'https://www.google.com/test',
'https://www.google.com/test/',
))
def test_factory_lever_bad(test_url: str):
with pytest.raises(NotADownloadableLinkError):
DownloadFactory.pull_lever(test_url)

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import MagicMock
import pytest
from bdfr.site_downloaders.erome import Erome
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_urls'), (
('https://www.erome.com/a/vqtPuLXh', (
'https://s11.erome.com/365/vqtPuLXh/KH2qBT99_480p.mp4',
)),
('https://www.erome.com/a/ORhX0FZz', (
'https://s4.erome.com/355/ORhX0FZz/9IYQocM9_480p.mp4',
'https://s4.erome.com/355/ORhX0FZz/9eEDc8xm_480p.mp4',
'https://s4.erome.com/355/ORhX0FZz/EvApC7Rp_480p.mp4',
'https://s4.erome.com/355/ORhX0FZz/LruobtMs_480p.mp4',
'https://s4.erome.com/355/ORhX0FZz/TJNmSUU5_480p.mp4',
'https://s4.erome.com/355/ORhX0FZz/X11Skh6Z_480p.mp4',
'https://s4.erome.com/355/ORhX0FZz/bjlTkpn7_480p.mp4'
)),
))
def test_get_link(test_url: str, expected_urls: tuple[str]):
result = Erome._get_links(test_url)
assert set(result) == set(expected_urls)
@pytest.mark.online
@pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hashes'), (
('https://www.erome.com/a/vqtPuLXh', {
'5da2a8d60d87bed279431fdec8e7d72f'
}),
('https://www.erome.com/i/ItASD33e', {
'b0d73fedc9ce6995c2f2c4fdb6f11eff'
}),
('https://www.erome.com/a/lGrcFxmb', {
'0e98f9f527a911dcedde4f846bb5b69f',
'25696ae364750a5303fc7d7dc78b35c1',
'63775689f438bd393cde7db6d46187de',
'a1abf398cfd4ef9cfaf093ceb10c746a',
'bd9e1a4ea5ef0d6ba47fb90e337c2d14'
}),
))
def test_download_resource(test_url: str, expected_hashes: set[str]):
# Individual hashes can't be compared here: Erome doesn't return exactly the same file from request to
# request, so the hashes change randomly; only the number of downloaded resources is checked
mock_submission = MagicMock()
mock_submission.url = test_url
test_site = Erome(mock_submission)
resources = test_site.find_resources()
[res.download(120) for res in resources]
resource_hashes = [res.hash.hexdigest() for res in resources]
assert len(resource_hashes) == len(expected_hashes)

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# coding=utf-8
import praw
import pytest
from bdfr.site_downloaders.gallery import Gallery
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), (
('https://www.reddit.com/gallery/m6lvrh', {
'https://preview.redd.it/18nzv9ch0hn61.jpg?width=4160&'
'format=pjpg&auto=webp&s=470a825b9c364e0eace0036882dcff926f821de8',
'https://preview.redd.it/jqkizcch0hn61.jpg?width=4160&'
'format=pjpg&auto=webp&s=ae4f552a18066bb6727676b14f2451c5feecf805',
'https://preview.redd.it/k0fnqzbh0hn61.jpg?width=4160&'
'format=pjpg&auto=webp&s=c6a10fececdc33983487c16ad02219fd3fc6cd76',
'https://preview.redd.it/m3gamzbh0hn61.jpg?width=4160&'
'format=pjpg&auto=webp&s=0dd90f324711851953e24873290b7f29ec73c444'
}),
('https://www.reddit.com/gallery/ljyy27', {
'https://preview.redd.it/04vxj25uqih61.png?width=92&'
'format=png&auto=webp&s=6513f3a5c5128ee7680d402cab5ea4fb2bbeead4',
'https://preview.redd.it/0fnx83kpqih61.png?width=241&'
'format=png&auto=webp&s=655e9deb6f499c9ba1476eaff56787a697e6255a',
'https://preview.redd.it/7zkmr1wqqih61.png?width=237&'
'format=png&auto=webp&s=19de214e634cbcad9959f19570c616e29be0c0b0',
'https://preview.redd.it/u37k5gxrqih61.png?width=443&'
'format=png&auto=webp&s=e74dae31841fe4a2545ffd794d3b25b9ff0eb862'
}),
))
def test_gallery_get_links(test_url: str, expected: set[str]):
results = Gallery._get_links(test_url)
assert set(results) == expected
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hashes'), (
('m6lvrh', {
'6c8a892ae8066cbe119218bcaac731e1',
'93ce177f8cb7994906795f4615114d13',
'9a293adf19354f14582608cf22124574',
'b73e2c3daee02f99404644ea02f1ae65'
}),
('ljyy27', {
'1bc38bed88f9c4770e22a37122d5c941',
'2539a92b78f3968a069df2dffe2279f9',
'37dea50281c219b905e46edeefc1a18d',
'ec4924cf40549728dcf53dd40bc7a73c'
}),
))
def test_gallery_download(test_submission_id: str, expected_hashes: set[str], reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id)
gallery = Gallery(test_submission)
results = gallery.find_resources()
[res.download(120) for res in results]
hashes = [res.hash.hexdigest() for res in results]
assert set(hashes) == expected_hashes

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import Mock
import pytest
from bdfr.resource import Resource
from bdfr.site_downloaders.gfycat import Gfycat
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_url'), (
('https://gfycat.com/definitivecaninecrayfish', 'https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4'),
('https://gfycat.com/dazzlingsilkyiguana', 'https://giant.gfycat.com/DazzlingSilkyIguana.mp4'),
('https://gfycat.com/webbedimpurebutterfly', 'https://thumbs2.redgifs.com/WebbedImpureButterfly.mp4'),
))
def test_get_link(test_url: str, expected_url: str):
result = Gfycat._get_link(test_url)
assert result == expected_url
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
('https://gfycat.com/definitivecaninecrayfish', '48f9bd4dbec1556d7838885612b13b39'),
('https://gfycat.com/dazzlingsilkyiguana', '808941b48fc1e28713d36dd7ed9dc648'),
))
def test_download_resource(test_url: str, expected_hash: str):
mock_submission = Mock()
mock_submission.url = test_url
test_site = Gfycat(mock_submission)
resources = test_site.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
assert resources[0].hash.hexdigest() == expected_hash

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import Mock
import pytest
from bdfr.resource import Resource
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), (
('https://www.gifdeliverynetwork.com/regalshoddyhorsechestnutleafminer',
'https://thumbs2.redgifs.com/RegalShoddyHorsechestnutleafminer.mp4'),
('https://www.gifdeliverynetwork.com/maturenexthippopotamus',
'https://thumbs2.redgifs.com/MatureNextHippopotamus.mp4'),
))
def test_get_link(test_url: str, expected: str):
result = GifDeliveryNetwork._get_link(test_url)
assert result == expected
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
('https://www.gifdeliverynetwork.com/maturenexthippopotamus', '9bec0a9e4163a43781368ed5d70471df'),
('https://www.gifdeliverynetwork.com/regalshoddyhorsechestnutleafminer', '8afb4e2c090a87140230f2352bf8beba'),
))
def test_download_resource(test_url: str, expected_hash: str):
mock_submission = Mock()
mock_submission.url = test_url
test_site = GifDeliveryNetwork(mock_submission)
resources = test_site.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
assert resources[0].hash.hexdigest() == expected_hash

View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import Mock
import pytest
from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_downloaders.imgur import Imgur
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_gen_dict', 'expected_image_dict'), (
(
'https://imgur.com/a/xWZsDDP',
{'num_images': '1', 'id': 'xWZsDDP', 'hash': 'xWZsDDP'},
[
{'hash': 'ypa8YfS', 'title': '', 'ext': '.png', 'animated': False}
]
),
(
'https://imgur.com/gallery/IjJJdlC',
{'num_images': 1, 'id': 384898055, 'hash': 'IjJJdlC'},
[
{'hash': 'CbbScDt',
'description': 'watch when he gets it',
'ext': '.gif',
'animated': True,
'has_sound': False
}
],
),
(
'https://imgur.com/a/dcc84Gt',
{'num_images': '4', 'id': 'dcc84Gt', 'hash': 'dcc84Gt'},
[
{'hash': 'ylx0Kle', 'ext': '.jpg', 'title': ''},
{'hash': 'TdYfKbK', 'ext': '.jpg', 'title': ''},
{'hash': 'pCxGbe8', 'ext': '.jpg', 'title': ''},
{'hash': 'TSAkikk', 'ext': '.jpg', 'title': ''},
]
),
(
'https://m.imgur.com/a/py3RW0j',
{'num_images': '1', 'id': 'py3RW0j', 'hash': 'py3RW0j'},
[
{'hash': 'K24eQmK', 'has_sound': False, 'ext': '.jpg'}
],
),
))
def test_get_data_album(test_url: str, expected_gen_dict: dict, expected_image_dict: list[dict]):
result = Imgur._get_data(test_url)
assert all([result.get(key) == expected_gen_dict[key] for key in expected_gen_dict.keys()])
# Check if all the keys from the test dict are correct in at least one of the album entries
assert any([all([image.get(key) == image_dict[key] for key in image_dict.keys()])
for image_dict in expected_image_dict for image in result['album_images']['images']])
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_image_dict'), (
(
'https://i.imgur.com/dLk3FGY.gifv',
{'hash': 'dLk3FGY', 'title': '', 'ext': '.mp4', 'animated': True}
),
(
'https://imgur.com/BuzvZwb.gifv',
{
'hash': 'BuzvZwb',
'title': '',
'description': 'Akron Glass Works',
'animated': True,
'mimetype': 'video/mp4'
},
),
))
def test_get_data_gif(test_url: str, expected_image_dict: dict):
result = Imgur._get_data(test_url)
assert all([result.get(key) == expected_image_dict[key] for key in expected_image_dict.keys()])
@pytest.mark.parametrize('test_extension', (
'.gif',
'.png',
'.jpg',
'.mp4'
))
def test_imgur_extension_validation_good(test_extension: str):
result = Imgur._validate_extension(test_extension)
assert result == test_extension
@pytest.mark.parametrize('test_extension', (
'.jpeg',
'bad',
'.avi',
'.test',
'.flac',
))
def test_imgur_extension_validation_bad(test_extension: str):
with pytest.raises(SiteDownloaderError):
Imgur._validate_extension(test_extension)
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hashes'), (
(
'https://imgur.com/a/xWZsDDP',
('f551d6e6b0fef2ce909767338612e31b',)
),
(
'https://imgur.com/gallery/IjJJdlC',
('7227d4312a9779b74302724a0cfa9081',),
),
(
'https://imgur.com/a/dcc84Gt',
(
'cf1158e1de5c3c8993461383b96610cf',
'28d6b791a2daef8aa363bf5a3198535d',
'248ef8f2a6d03eeb2a80d0123dbaf9b6',
'029c475ce01b58fdf1269d8771d33913',
),
),
))
def test_find_resources(test_url: str, expected_hashes: tuple[str]):
mock_download = Mock()
mock_download.url = test_url
downloader = Imgur(mock_download)
results = downloader.find_resources()
assert all([isinstance(res, Resource) for res in results])
[res.download(120) for res in results]
hashes = set([res.hash.hexdigest() for res in results])
assert len(results) == len(expected_hashes)
assert hashes == set(expected_hashes)

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import Mock
import pytest
from bdfr.resource import Resource
from bdfr.site_downloaders.redgifs import Redgifs
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), (
('https://redgifs.com/watch/frighteningvictorioussalamander',
'https://thumbs2.redgifs.com/FrighteningVictoriousSalamander.mp4'),
('https://redgifs.com/watch/springgreendecisivetaruca',
'https://thumbs2.redgifs.com/SpringgreenDecisiveTaruca.mp4'),
))
def test_get_link(test_url: str, expected: str):
result = Redgifs._get_link(test_url)
assert result == expected
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
('https://redgifs.com/watch/frighteningvictorioussalamander', '4007c35d9e1f4b67091b5f12cffda00a'),
('https://redgifs.com/watch/springgreendecisivetaruca', '8dac487ac49a1f18cc1b4dabe23f0869'),
))
def test_download_resource(test_url: str, expected_hash: str):
mock_submission = Mock()
mock_submission.url = test_url
test_site = Redgifs(mock_submission)
resources = test_site.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
assert resources[0].hash.hexdigest() == expected_hash

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# coding=utf-8
import praw
import pytest
from bdfr.resource import Resource
from bdfr.site_downloaders.self_post import SelfPost
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hash'), (
('ltmivt', '7d2c9e4e989e5cf2dca2e55a06b1c4f6'),
('ltoaan', '221606386b614d6780c2585a59bd333f'),
('d3sc8o', 'c1ff2b6bd3f6b91381dcd18dfc4ca35f'),
))
def test_find_resource(test_submission_id: str, expected_hash: str, reddit_instance: praw.Reddit):
submission = reddit_instance.submission(id=test_submission_id)
downloader = SelfPost(submission)
results = downloader.find_resources()
assert len(results) == 1
assert isinstance(results[0], Resource)
assert results[0].hash.hexdigest() == expected_hash

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# coding=utf-8
import praw
import pytest
from bdfr.resource import Resource
from bdfr.site_downloaders.vreddit import VReddit
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hash'), (
('lu8l8g', '93a15642d2f364ae39f00c6d1be354ff'),
))
def test_find_resources(test_submission_id: str, expected_hash: str, reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id)
downloader = VReddit(test_submission)
resources = downloader.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
assert resources[0].hash.hexdigest() == expected_hash

View File

@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import MagicMock
import pytest
from bdfr.resource import Resource
from bdfr.site_downloaders.youtube import Youtube
@pytest.mark.online
@pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
('https://www.youtube.com/watch?v=uSm2VDgRIUs', '3c79a62898028987f94161e0abccbddf'),
('https://www.youtube.com/watch?v=m-tKnjFwleU', '30314930d853afff8ebc7d8c36a5b833'),
))
def test_find_resources(test_url: str, expected_hash: str):
test_submission = MagicMock()
test_submission.url = test_url
downloader = Youtube(test_submission)
resources = downloader.find_resources()
assert len(resources) == 1
assert isinstance(resources[0], Resource)
resources[0].download(120)
assert resources[0].hash.hexdigest() == expected_hash

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# coding=utf-8
from pathlib import Path
from unittest.mock import MagicMock
import praw
import pytest
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.archiver import Archiver
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm3reby',
))
def test_write_submission_json(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
archiver_mock = MagicMock()
test_path = Path(tmp_path, 'test.json')
test_submission = reddit_instance.submission(id=test_submission_id)
archiver_mock.file_name_formatter.format_path.return_value = test_path
test_entry = SubmissionArchiveEntry(test_submission)
Archiver._write_entry_json(archiver_mock, test_entry)
archiver_mock._write_content_to_disk.assert_called_once()
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm3reby',
))
def test_write_submission_xml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
archiver_mock = MagicMock()
test_path = Path(tmp_path, 'test.xml')
test_submission = reddit_instance.submission(id=test_submission_id)
archiver_mock.file_name_formatter.format_path.return_value = test_path
test_entry = SubmissionArchiveEntry(test_submission)
Archiver._write_entry_xml(archiver_mock, test_entry)
archiver_mock._write_content_to_disk.assert_called_once()
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm3reby',
))
def test_write_submission_yaml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
archiver_mock = MagicMock()
archiver_mock.download_directory = tmp_path
test_path = Path(tmp_path, 'test.yaml')
test_submission = reddit_instance.submission(id=test_submission_id)
archiver_mock.file_name_formatter.format_path.return_value = test_path
test_entry = SubmissionArchiveEntry(test_submission)
Archiver._write_entry_yaml(archiver_mock, test_entry)
archiver_mock._write_content_to_disk.assert_called_once()

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# coding=utf-8
from unittest.mock import MagicMock
import pytest
from bdfr.configuration import Configuration
@pytest.mark.parametrize('arg_dict', (
{'directory': 'test_dir'},
{
'directory': 'test_dir',
'no_dupes': True,
},
))
def test_process_click_context(arg_dict: dict):
test_config = Configuration()
test_context = MagicMock()
test_context.params = arg_dict
test_config.process_click_arguments(test_context)
test_config = vars(test_config)
assert all([test_config[arg] == arg_dict[arg] for arg in arg_dict.keys()])

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# coding=utf-8
import pytest
from bdfr.download_filter import DownloadFilter
@pytest.fixture()
def download_filter() -> DownloadFilter:
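# A filter excluding the mp4/mp3 extensions and the test.com and reddit.com domains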
return DownloadFilter(['mp4', 'mp3'], ['test.com', 'reddit.com'])
@pytest.mark.parametrize(('test_url', 'expected'), (
('test.mp4', False),
('test.avi', True),
('test.random.mp3', False),
))
def test_filter_extension(test_url: str, expected: bool, download_filter: DownloadFilter):
result = download_filter._check_extension(test_url)
assert result == expected
@pytest.mark.parametrize(('test_url', 'expected'), (
('test.mp4', True),
('http://reddit.com/test.mp4', False),
('http://reddit.com/test.gif', False),
('https://www.example.com/test.mp4', True),
('https://www.example.com/test.png', True),
))
def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter):
result = download_filter._check_domain(test_url)
assert result == expected
@pytest.mark.parametrize(('test_url', 'expected'), (
('test.mp4', False),
('test.gif', True),
('https://www.example.com/test.mp4', False),
('https://www.example.com/test.png', True),
('http://reddit.com/test.mp4', False),
('http://reddit.com/test.gif', False),
))
def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter):
result = download_filter.check_url(test_url)
assert result == expected
@pytest.mark.parametrize('test_url', (
'test.mp3',
'test.mp4',
'http://reddit.com/test.mp4',
't',
))
def test_filter_empty_filter(test_url: str):
download_filter = DownloadFilter()
result = download_filter.check_url(test_url)
assert result is True

View File

@@ -0,0 +1,430 @@
#!/usr/bin/env python3
# coding=utf-8
import re
from pathlib import Path
from typing import Iterator
from unittest.mock import MagicMock
import praw
import praw.models
import pytest
from bdfr.__main__ import setup_logging
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
from bdfr.downloader import RedditDownloader, RedditTypes
from bdfr.exceptions import BulkDownloaderException
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.site_authenticator import SiteAuthenticator
@pytest.fixture()
def args() -> Configuration:
args = Configuration()
return args
@pytest.fixture()
def downloader_mock(args: Configuration):
downloader_mock = MagicMock()
downloader_mock.args = args
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
downloader_mock._split_args_input = RedditDownloader._split_args_input
downloader_mock.master_hash_list = {}
return downloader_mock
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
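# Flatten the iterators from each source and check that every item is a praw Submission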
results = [sub for res in results for sub in res]
assert all([isinstance(res, praw.models.Submission) for res in results])
if result_limit is not None:
assert len(results) == result_limit
return results
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
downloader_mock.args.directory = tmp_path / 'test'
downloader_mock.config_directories.user_config_dir = tmp_path
RedditDownloader._determine_directories(downloader_mock)
assert Path(tmp_path / 'test').exists()
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
([], []),
(['.test'], ['test.com'],),
))
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
downloader_mock.args.skip = skip_extensions
downloader_mock.args.skip_domain = skip_domains
result = RedditDownloader._create_download_filter(downloader_mock)
assert isinstance(result, DownloadFilter)
assert result.excluded_domains == skip_domains
assert result.excluded_extensions == skip_extensions
@pytest.mark.parametrize(('test_time', 'expected'), (
('all', 'all'),
('hour', 'hour'),
('day', 'day'),
('week', 'week'),
('random', 'all'),
('', 'all'),
))
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.time = test_time
result = RedditDownloader._create_time_filter(downloader_mock)
assert isinstance(result, RedditTypes.TimeType)
assert result.name.lower() == expected
@pytest.mark.parametrize(('test_sort', 'expected'), (
('', 'hot'),
('hot', 'hot'),
('controversial', 'controversial'),
('new', 'new'),
))
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
downloader_mock.args.sort = test_sort
result = RedditDownloader._create_sort_filter(downloader_mock)
assert isinstance(result, RedditTypes.SortType)
assert result.name.lower() == expected
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
('{POSTID}', '{SUBREDDIT}'),
('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
('{POSTID}', 'test'),
('{POSTID}', ''),
('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
))
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme
result = RedditDownloader._create_file_name_formatter(downloader_mock)
assert isinstance(result, FileNameFormatter)
assert result.file_format_string == test_file_scheme
assert result.directory_format_string == test_folder_scheme.split('/')
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
('', ''),
('', '{SUBREDDIT}'),
('test', '{SUBREDDIT}'),
))
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
downloader_mock.args.file_scheme = test_file_scheme
downloader_mock.args.folder_scheme = test_folder_scheme
with pytest.raises(BulkDownloaderException):
RedditDownloader._create_file_name_formatter(downloader_mock)
def test_create_authenticator(downloader_mock: MagicMock):
result = RedditDownloader._create_authenticator(downloader_mock)
assert isinstance(result, SiteAuthenticator)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_ids', (
('lvpf4l',),
('lvpf4l', 'lvqnsn'),
('lvpf4l', 'lvqnsn', 'lvl9kd'),
))
def test_get_submissions_from_link(
test_submission_ids: list[str],
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock.args.link = test_submission_ids
downloader_mock.reddit_instance = reddit_instance
results = RedditDownloader._get_submissions_from_link(downloader_mock)
assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
assert len(results[0]) == len(test_submission_ids)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
(('Futurology',), 10),
(('Futurology', 'Mindustry, Python'), 10),
(('Futurology',), 20),
(('Futurology', 'Python'), 10),
(('Futurology',), 100),
(('Futurology',), 0),
))
def test_get_subreddit_normal(
test_subreddits: list[str],
limit: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock)
test_subreddits = downloader_mock._split_args_input(test_subreddits)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
(('Python',), 'scraper', 10),
(('Python',), '', 10),
(('Python',), 'djsdsgewef', 0),
))
def test_get_subreddit_search(
test_subreddits: list[str],
search_term: str,
limit: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.args.limit = limit
downloader_mock.args.search = search_term
downloader_mock.args.subreddit = test_subreddits
downloader_mock.reddit_instance = reddit_instance
downloader_mock.sort_filter = RedditTypes.SortType.HOT
results = RedditDownloader._get_subreddits(downloader_mock)
results = assert_all_results_are_submissions(
(limit * len(test_subreddits)) if limit else None, results)
assert all([res.subreddit.display_name in test_subreddits for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
('helen_darten', ('cuteanimalpics',), 10),
('korfor', ('chess',), 100),
))
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
test_user: str,
test_multireddits: list[str],
limit: int,
reddit_instance: praw.Reddit,
downloader_mock: MagicMock):
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.limit = limit
downloader_mock.args.multireddit = test_multireddits
downloader_mock.args.user = test_user
downloader_mock.reddit_instance = reddit_instance
results = RedditDownloader._get_multireddits(downloader_mock)
assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
('danigirl3694', 10),
('danigirl3694', 50),
('CapitanHam', None),
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
downloader_mock.args.limit = limit
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
downloader_mock.args.submitted = True
downloader_mock.args.user = test_user
downloader_mock.authenticated = False
downloader_mock.reddit_instance = reddit_instance
results = RedditDownloader._get_user_data(downloader_mock)
results = assert_all_results_are_submissions(limit, results)
assert all([res.author.name == test_user for res in results])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.parametrize('test_flag', (
'upvoted',
'saved',
))
def test_get_user_authenticated_lists(
test_flag: str,
downloader_mock: MagicMock,
authenticated_reddit_instance: praw.Reddit,
):
downloader_mock.args.__dict__[test_flag] = True
downloader_mock.reddit_instance = authenticated_reddit_instance
downloader_mock.args.user = 'me'
downloader_mock.args.limit = 10
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
downloader_mock.sort_filter = RedditTypes.SortType.HOT
RedditDownloader._resolve_user_name(downloader_mock)
results = RedditDownloader._get_user_data(downloader_mock)
assert_all_results_are_submissions(10, results)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
('ljyy27', 4),
))
def test_download_submission(
test_submission_id: str,
expected_files_len: int,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
tmp_path: Path):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
RedditDownloader._download_submission(downloader_mock, submission)
folder_contents = list(tmp_path.iterdir())
assert len(folder_contents) == expected_files_len
@pytest.mark.online
@pytest.mark.reddit
def test_download_submission_file_exists(
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
tmp_path: Path,
capsys: pytest.CaptureFixture
):
setup_logging(3)
downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path
submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
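# Pre-create the file the formatter will produce so the downloader finds it already on disk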
Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
RedditDownloader._download_submission(downloader_mock, submission)
folder_contents = list(tmp_path.iterdir())
output = capsys.readouterr()
assert len(folder_contents) == 1
assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'test_hash'), (
('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'),
))
def test_download_submission_hash_exists(
test_submission_id: str,
test_hash: str,
downloader_mock: MagicMock,
reddit_instance: praw.Reddit,
tmp_path: Path,
capsys: pytest.CaptureFixture
):
setup_logging(3)
downloader_mock.reddit_instance = reddit_instance
downloader_mock.download_filter.check_url.return_value = True
downloader_mock.args.folder_scheme = ''
downloader_mock.args.no_dupes = True
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
downloader_mock.download_directory = tmp_path
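# Seed the master hash list so the downloaded resource registers as a duplicate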
downloader_mock.master_hash_list = {test_hash: None}
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
RedditDownloader._download_submission(downloader_mock, submission)
folder_contents = list(tmp_path.iterdir())
output = capsys.readouterr()
assert len(folder_contents) == 0
assert re.search(r'Resource hash .*? downloaded elsewhere', output.out)
@pytest.mark.parametrize(('test_name', 'expected'), (
('Mindustry', 'Mindustry'),
('Futurology', 'Futurology'),
('r/Mindustry', 'Mindustry'),
('TrollXChromosomes', 'TrollXChromosomes'),
('r/TrollXChromosomes', 'TrollXChromosomes'),
('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
('https://www.reddit.com/r/Futurology/', 'Futurology'),
('https://www.reddit.com/r/Futurology', 'Futurology'),
))
def test_sanitise_subreddit_name(test_name: str, expected: str):
result = RedditDownloader._sanitise_subreddit_name(test_name)
assert result == expected
def test_search_existing_files():
results = RedditDownloader.scan_existing_files(Path('.'))
assert len(results.keys()) >= 40
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
(['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
(['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
))
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
results = RedditDownloader._split_args_input(test_subreddit_entries)
assert results == expected
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
'm1hqw6',
))
def test_mark_hard_link(
test_submission_id: str,
downloader_mock: MagicMock,
tmp_path: Path,
reddit_instance: praw.Reddit
):
downloader_mock.reddit_instance = reddit_instance
downloader_mock.args.make_hard_links = True
downloader_mock.download_directory = tmp_path
downloader_mock.args.folder_scheme = ''
downloader_mock.args.file_scheme = '{POSTID}'
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
original = Path(tmp_path, f'{test_submission_id}.png')
RedditDownloader._download_submission(downloader_mock, submission)
assert original.exists()
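# Download the same submission again under a different scheme; with hard links enabled both names must share one inode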
downloader_mock.args.file_scheme = 'test2_{POSTID}'
downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
RedditDownloader._download_submission(downloader_mock, submission)
test_file_1_stats = original.stat()
test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino
assert test_file_1_stats.st_nlink == 2
assert test_file_1_stats.st_ino == test_file_2_inode
@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
(('aaaaaa',), (), 1),
(('aaaaaa',), ('aaaaaa',), 0),
((), ('aaaaaa',), 0),
(('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1),
))
def test_excluded_ids(test_ids: tuple[str], test_excluded: tuple[str], expected_len: int, downloader_mock: MagicMock):
downloader_mock.excluded_submission_ids = test_excluded
test_submissions = []
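# Build minimal mock submissions; only the id attribute matters for the exclusion check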
for test_id in test_ids:
m = MagicMock()
m.id = test_id
test_submissions.append(m)
downloader_mock.reddit_lists = [test_submissions]
RedditDownloader.download(downloader_mock)
assert downloader_mock._download_submission.call_count == expected_len
def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
test_file = tmp_path / 'test.txt'
test_file.write_text('aaaaaa\nbbbbbb')
downloader_mock.args.exclude_id_file = [test_file]
results = RedditDownloader._read_excluded_ids(downloader_mock)
assert results == {'aaaaaa', 'bbbbbb'}

View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# coding=utf-8
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock
import praw.models
import pytest
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.resource import Resource
@pytest.fixture()
def submission() -> MagicMock:
test = MagicMock()
test.title = 'name'
test.subreddit.display_name = 'randomreddit'
test.author.name = 'person'
test.id = '12345'
test.score = 1000
test.link_flair_text = 'test_flair'
test.created_utc = 123456789
test.__class__ = praw.models.Submission
return test
@pytest.fixture(scope='session')
def reddit_submission(reddit_instance: praw.Reddit) -> praw.models.Submission:
return reddit_instance.submission(id='lgilgt')
@pytest.mark.parametrize(('format_string', 'expected'), (
('{SUBREDDIT}', 'randomreddit'),
('{REDDITOR}', 'person'),
('{POSTID}', '12345'),
('{UPVOTES}', '1000'),
('{FLAIR}', 'test_flair'),
('{DATE}', '123456789'),
('{REDDITOR}_{TITLE}_{POSTID}', 'person_name_12345'),
('{RANDOM}', '{RANDOM}'),
))
def test_format_name_mock(format_string: str, expected: str, submission: MagicMock):
result = FileNameFormatter._format_name(submission, format_string)
assert result == expected
@pytest.mark.parametrize(('test_string', 'expected'), (
('', False),
('test', False),
('{POSTID}', True),
('POSTID', False),
('{POSTID}_test', True),
('test_{TITLE}', True),
('TITLE_POSTID', False),
))
def test_check_format_string_validity(test_string: str, expected: bool):
result = FileNameFormatter.validate_string(test_string)
assert result == expected
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string', 'expected'), (
('{SUBREDDIT}', 'Mindustry'),
('{REDDITOR}', 'Gamer_player_boi'),
('{POSTID}', 'lgilgt'),
('{FLAIR}', 'Art'),
('{SUBREDDIT}_{TITLE}', 'Mindustry_Toxopid that is NOT humane >:('),
('{REDDITOR}_{TITLE}_{POSTID}', 'Gamer_player_boi_Toxopid that is NOT humane >:(_lgilgt')
))
def test_format_name_real(format_string: str, expected: str, reddit_submission: praw.models.Submission):
result = FileNameFormatter._format_name(reddit_submission, format_string)
assert result == expected
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'expected'), (
(
'{SUBREDDIT}',
'{POSTID}',
'test/Mindustry/lgilgt.png',
),
(
'{SUBREDDIT}',
'{TITLE}_{POSTID}',
'test/Mindustry/Toxopid that is NOT humane >:(_lgilgt.png',
),
(
'{SUBREDDIT}',
'{REDDITOR}_{TITLE}_{POSTID}',
'test/Mindustry/Gamer_player_boi_Toxopid that is NOT humane >:(_lgilgt.png',
),
))
def test_format_full(
format_string_directory: str,
format_string_file: str,
expected: str,
reddit_submission: praw.models.Submission):
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
test_formatter = FileNameFormatter(format_string_file, format_string_directory)
result = test_formatter.format_path(test_resource, Path('test'))
assert str(result) == expected
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file'), (
('{SUBREDDIT}', '{POSTID}'),
('{SUBREDDIT}', '{UPVOTES}'),
('{SUBREDDIT}', '{UPVOTES}{POSTID}'),
))
def test_format_full_conform(
format_string_directory: str,
format_string_file: str,
reddit_submission: praw.models.Submission):
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
test_formatter = FileNameFormatter(format_string_file, format_string_directory)
test_formatter.format_path(test_resource, Path('test'))
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'index', 'expected'), (
('{SUBREDDIT}', '{POSTID}', None, 'test/Mindustry/lgilgt.png'),
('{SUBREDDIT}', '{POSTID}', 1, 'test/Mindustry/lgilgt_1.png'),
('{SUBREDDIT}', '{POSTID}', 2, 'test/Mindustry/lgilgt_2.png'),
('{SUBREDDIT}', '{TITLE}_{POSTID}', 2, 'test/Mindustry/Toxopid that is NOT humane >:(_lgilgt_2.png'),
))
def test_format_full_with_index_suffix(
format_string_directory: str,
format_string_file: str,
index: Optional[int],
expected: str,
reddit_submission: praw.models.Submission):
test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
test_formatter = FileNameFormatter(format_string_file, format_string_directory)
result = test_formatter.format_path(test_resource, Path('test'), index)
assert str(result) == expected
def test_format_multiple_resources():
mocks = []
for i in range(1, 5):
new_mock = MagicMock()
new_mock.url = 'https://example.com/test.png'
new_mock.extension = '.png'
new_mock.source_submission.title = 'test'
new_mock.source_submission.__class__ = praw.models.Submission
mocks.append(new_mock)
test_formatter = FileNameFormatter('{TITLE}', '')
results = test_formatter.format_resource_paths(mocks, Path('.'))
results = set([str(res[0]) for res in results])
assert results == {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'}
@pytest.mark.parametrize(('test_filename', 'test_ending'), (
('A' * 300, '.png'),
('A' * 300, '_1.png'),
('a' * 300, '_1000.jpeg'),
('😍💕✨' * 100, '_1.png'),
))
def test_limit_filename_length(test_filename: str, test_ending: str):
result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
assert len(result) <= 255
assert len(result.encode('utf-8')) <= 255
assert isinstance(result, str)
@pytest.mark.parametrize(('test_filename', 'test_ending', 'expected_end'), (
('test_aaaaaa', '_1.png', 'test_aaaaaa_1.png'),
('test_aataaa', '_1.png', 'test_aataaa_1.png'),
('test_abcdef', '_1.png', 'test_abcdef_1.png'),
('test_aaaaaa', '.png', 'test_aaaaaa.png'),
('test', '_1.png', 'test_1.png'),
('test_m1hqw6', '_1.png', 'test_m1hqw6_1.png'),
('A' * 300 + '_bbbccc', '.png', '_bbbccc.png'),
('A' * 300 + '_bbbccc', '_1000.jpeg', '_bbbccc_1000.jpeg'),
('😍💕✨' * 100 + '_aaa1aa', '_1.png', '_aaa1aa_1.png'),
))
def test_preserve_id_append_when_shortening(test_filename: str, test_ending: str, expected_end: str):
result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
assert len(result) <= 255
assert len(result.encode('utf-8')) <= 255
assert isinstance(result, str)
assert result.endswith(expected_end)
def test_shorten_filenames(submission: MagicMock, tmp_path: Path):
submission.title = 'A' * 300
submission.author.name = 'test'
submission.subreddit.display_name = 'test'
submission.id = 'BBBBBB'
test_resource = Resource(submission, 'www.example.com/empty', '.jpeg')
test_formatter = FileNameFormatter('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}')
result = test_formatter.format_path(test_resource, tmp_path)
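# No explicit assert: creating the file proves the shortened path is valid for the filesystem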
result.parent.mkdir(parents=True)
result.touch()
@pytest.mark.parametrize(('test_string', 'expected'), (
('test', 'test'),
('test😍', 'test'),
('test.png', 'test.png'),
('test*', 'test'),
('test**', 'test'),
('test?*', 'test'),
('test_???.png', 'test_.png'),
('test_???😍.png', 'test_.png'),
))
def test_format_file_name_for_windows(test_string: str, expected: str):
result = FileNameFormatter._format_for_windows(test_string)
assert result == expected
@pytest.mark.parametrize(('test_string', 'expected'), (
('test', 'test'),
('test😍', 'test'),
('😍', ''),
))
def test_strip_emojis(test_string: str, expected: str):
result = FileNameFormatter._strip_emojis(test_string)
assert result == expected
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected'), (
('mfuteh', {
'title': 'Why Do Interviewers Ask Linked List Questions?',
'redditor': 'mjgardner',
}),
))
def test_generate_dict_for_submission(test_submission_id: str, expected: dict, reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id)
result = FileNameFormatter._generate_name_dict_from_submission(test_submission)
assert all([result.get(key) == expected[key] for key in expected.keys()])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_comment_id', 'expected'), (
('gsq0yuw', {
'title': 'Why Do Interviewers Ask Linked List Questions?',
'redditor': 'Doctor-Dapper',
'postid': 'gsq0yuw',
'flair': '',
}),
))
def test_generate_dict_for_comment(test_comment_id: str, expected: dict, reddit_instance: praw.Reddit):
test_comment = reddit_instance.comment(id=test_comment_id)
result = FileNameFormatter._generate_name_dict_from_comment(test_comment)
assert all([result.get(key) == expected[key] for key in expected.keys()])
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme', 'test_comment_id', 'expected_name'), (
('{POSTID}', '', 'gsoubde', 'gsoubde.json'),
('{REDDITOR}_{POSTID}', '', 'gsoubde', 'DELETED_gsoubde.json'),
))
def test_format_archive_entry_comment(
test_file_scheme: str,
test_folder_scheme: str,
test_comment_id: str,
expected_name: str,
tmp_path: Path,
reddit_instance: praw.Reddit,
):
test_comment = reddit_instance.comment(id=test_comment_id)
test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme)
test_entry = Resource(test_comment, '', '.json')
result = test_formatter.format_path(test_entry, tmp_path)
assert result.name == expected_name
@pytest.mark.parametrize(('test_folder_scheme', 'expected'), (
('{REDDITOR}/{SUBREDDIT}', 'person/randomreddit'),
('{POSTID}/{SUBREDDIT}/{REDDITOR}', '12345/randomreddit/person'),
))
def test_multilevel_folder_scheme(
test_folder_scheme: str,
expected: str,
tmp_path: Path,
submission: MagicMock,
):
test_formatter = FileNameFormatter('{POSTID}', test_folder_scheme)
test_resource = MagicMock()
test_resource.source_submission = submission
test_resource.extension = '.png'
result = test_formatter.format_path(test_resource, tmp_path)
result = result.relative_to(tmp_path)
assert str(result.parent) == expected
assert len(result.parents) == (len(expected.split('/')) + 1)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'test_file_scheme', 'expected'), (
('mecwk7', '{TITLE}', 'My cats paws are so cute'), # Unicode escape in title
))
def test_edge_case_names(test_submission_id: str, test_file_scheme: str, expected: str, reddit_instance: praw.Reddit):
test_submission = reddit_instance.submission(id=test_submission_id)
result = FileNameFormatter._format_name(test_submission, test_file_scheme)
assert result == expected

View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
# coding=utf-8
import re
from pathlib import Path
import pytest
from click.testing import CliRunner
from bdfr.__main__ import cli
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['-s', 'Mindustry', '-L', 1],
['-s', 'r/Mindustry', '-L', 1],
['-s', 'r/mindustry', '-L', 1],
['-s', 'mindustry', '-L', 1],
['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 1],
['-s', 'r/TrollXChromosomes/', '-L', 1],
['-s', 'TrollXChromosomes/', '-L', 1],
['-s', 'trollxchromosomes', '-L', 1],
['-s', 'trollxchromosomes,mindustry,python', '-L', 1],
['-s', 'trollxchromosomes, mindustry, python', '-L', 1],
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day'],
['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new'],
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new'],
['-s', 'trollxchromosomes', '-L', 1, '--search', 'women'],
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--search', 'women'],
['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new', '--search', 'women'],
['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new', '--search', 'women'],
))
def test_cli_download_subreddits(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Added submissions from subreddit ' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['-l', 'm2601g'],
['-l', 'https://www.reddit.com/r/TrollXChromosomes/comments/m2601g/its_a_step_in_the_right_direction/'],
['-l', 'm3hxzd'], # Really long title used to overflow filename limit
['-l', 'm3kua3'], # Has a deleted user
['-l', 'm5bqkf'], # Resource leading to a 404
))
def test_cli_download_links(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--sort', 'rising'],
['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week'],
['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week', '--sort', 'rising'],
))
def test_cli_download_multireddit(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Added submissions from multireddit ' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'helen_darten', '-m', 'xxyyzzqwerty', '-L', 10],
))
def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Failed to get submissions for multireddit' in result.output
assert 'received 404 HTTP response' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
['--user', 'me', '--saved', '--authenticate', '-L', 10],
['--user', 'me', '--submitted', '--authenticate', '-L', 10],
['--user', 'djnish', '--submitted', '-L', 10],
['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'],
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'],
['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial', '--time', 'month'],
))
def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Downloaded submission ' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'me', '-L', 10, '--folder-scheme', ''],
))
def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'To use "me" as a user, an authenticated Reddit instance must be used' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--subreddit', 'python', '-L', 10, '--search-existing'],
))
def test_cli_download_search_existing(test_args: list[str], tmp_path: Path):
Path(tmp_path, 'test.txt').touch()
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Calculating hashes for' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--subreddit', 'tumblr', '-L', '25', '--skip', 'png', '--skip', 'jpg'],
))
def test_cli_download_download_filters(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Download filter removed submission' in result.output
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--subreddit', 'all', '-L', '100', '--sort', 'new'],
))
def test_cli_download_long(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['-l', 'gstd4hk'],
['-l', 'm2601g'],
))
def test_cli_archive_single(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert re.search(r'Writing entry .*? to file in .*? format', result.output)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--subreddit', 'Mindustry', '-L', 25],
['--subreddit', 'Mindustry', '-L', 25, '--format', 'xml'],
['--subreddit', 'Mindustry', '-L', 25, '--format', 'yaml'],
['--subreddit', 'Mindustry', '-L', 25, '--sort', 'new'],
['--subreddit', 'Mindustry', '-L', 25, '--time', 'day'],
['--subreddit', 'Mindustry', '-L', 25, '--time', 'day', '--sort', 'new'],
))
def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert re.search(r'Writing entry .*? to file in .*? format', result.output)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
))
def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--subreddit', 'all', '-L', '100'],
['--subreddit', 'all', '-L', '100', '--sort', 'new'],
))
def test_cli_archive_long(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert re.search(r'Writing entry .*? to file in .*? format', result.output)
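# Soft failures (nonexistent user, unauthenticated 'me') should be logged without aborting the run.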
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--user', 'sdclhgsolgjeroij', '--submitted', '-L', '10'],
['--user', 'me', '--upvoted', '-L', '10'],
['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', '10'],
))
def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
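# Invalid --time/--sort values are rejected by Click's choice validation, so the exit code is nonzero.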
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--time', 'random'],
['--sort', 'random'],
))
def test_cli_download_hard_fail(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code != 0
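# No --config argument: the downloader should fall back to the default configuration location.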
def test_cli_download_use_default_config(tmp_path: Path):
runner = CliRunner()
test_args = ['download', '-vv', str(tmp_path)]
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
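# A submission whose ID is in the exclusion list should be skipped before any download occurs.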
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['-l', 'm2601g', '--exclude-id', 'm2601g'],
))
def test_cli_download_links_exclusion(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'in exclusion list' in result.output
assert 'Downloaded submission ' not in result.output
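# Schemes built only from non-unique fields such as the title risk name collisions, so a warning is expected.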
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
['--file-scheme', '{TITLE}'],
['--file-scheme', '{TITLE}_test_{SUBREDDIT}'],
))
def test_cli_file_scheme_warning(test_args: list[str], tmp_path: Path):
runner = CliRunner()
test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
result = runner.invoke(cli, test_args)
assert result.exit_code == 0
assert 'Some files might not be downloaded due to name conflicts' in result.output

71
bdfr/tests/test_oauth2.py Normal file
View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# coding=utf-8
import configparser
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from bdfr.exceptions import BulkDownloaderException
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
@pytest.fixture()
def example_config() -> configparser.ConfigParser:
out = configparser.ConfigParser()
config_dict = {'DEFAULT': {'user_token': 'example'}}
out.read_dict(config_dict)
return out
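# Known scopes, including the '*' wildcard, should pass validation without raising.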
@pytest.mark.online
@pytest.mark.parametrize('test_scopes', (
{'history', },
{'history', 'creddits'},
{'account', 'flair'},
{'*', },
))
def test_check_scopes(test_scopes: set[str]):
OAuth2Authenticator._check_scopes(test_scopes)
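# split_scopes should accept space-, comma-, and comma-space-separated scope strings.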
@pytest.mark.parametrize(('test_scopes', 'expected'), (
('history', {'history', }),
('history creddits', {'history', 'creddits'}),
('history, creddits, account', {'history', 'creddits', 'account'}),
('history,creddits,account,flair', {'history', 'creddits', 'account', 'flair'}),
))
def test_split_scopes(test_scopes: str, expected: set[str]):
result = OAuth2Authenticator.split_scopes(test_scopes)
assert result == expected
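# Unknown scopes should raise a BulkDownloaderException.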
@pytest.mark.online
@pytest.mark.parametrize('test_scopes', (
{'random', },
{'scope', 'another_scope'},
))
def test_check_scopes_bad(test_scopes: set[str]):
with pytest.raises(BulkDownloaderException):
OAuth2Authenticator._check_scopes(test_scopes)
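# The token manager bridges the authoriser's refresh callbacks and the config file: the pre-refresh
# hook loads the stored token and the post-refresh hook writes the new one back to disk.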
def test_token_manager_read(example_config: configparser.ConfigParser):
mock_authoriser = MagicMock()
mock_authoriser.refresh_token = None
test_manager = OAuth2TokenManager(example_config, MagicMock())
test_manager.pre_refresh_callback(mock_authoriser)
assert mock_authoriser.refresh_token == example_config.get('DEFAULT', 'user_token')
def test_token_manager_write(example_config: configparser.ConfigParser, tmp_path: Path):
test_path = tmp_path / 'test.cfg'
mock_authoriser = MagicMock()
mock_authoriser.refresh_token = 'changed_token'
test_manager = OAuth2TokenManager(example_config, test_path)
test_manager.post_refresh_callback(mock_authoriser)
assert example_config.get('DEFAULT', 'user_token') == 'changed_token'
with open(test_path, 'r') as file:
file_contents = file.read()
assert 'user_token = changed_token' in file_contents

32
bdfr/tests/test_resource.py Normal file
View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# coding=utf-8
import pytest
from unittest.mock import MagicMock
from bdfr.resource import Resource
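# Extension detection should take the final suffix and ignore any URL query string.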
@pytest.mark.parametrize(('test_url', 'expected'), (
('test.png', '.png'),
('another.mp4', '.mp4'),
('test.jpeg', '.jpeg'),
('http://www.random.com/resource.png', '.png'),
('https://www.resource.com/test/example.jpg', '.jpg'),
('hard.png.mp4', '.mp4'),
('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'),
))
def test_resource_get_extension(test_url: str, expected: str):
test_resource = Resource(MagicMock(), test_url)
result = test_resource._determine_extension()
assert result == expected
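# Downloading a stable remote file should yield bytes matching the expected hash digest;
# the argument to download() is presumably the maximum wait time, mirroring --max-wait-time.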
@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
('https://www.iana.org/_img/2013.1/iana-logo-header.svg', '426b3ac01d3584c820f3b7f5985d6623'),
))
def test_download_online_resource(test_url: str, expected_hash: str):
test_resource = Resource(MagicMock(), test_url)
test_resource.download(120)
assert test_resource.hash.hexdigest() == expected_hash