Rename module
0 bdfr/__init__.py Normal file
117 bdfr/__main__.py Normal file
@@ -0,0 +1,117 @@
#!/usr/bin/env python3

import logging
import sys

import click

from bdfr.archiver import Archiver
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader

logger = logging.getLogger()

_common_options = [
    click.argument('directory', type=str),
    click.option('--config', type=str, default=None),
    click.option('-v', '--verbose', default=None, count=True),
    click.option('-l', '--link', multiple=True, default=None, type=str),
    click.option('-s', '--subreddit', multiple=True, default=None, type=str),
    click.option('-m', '--multireddit', multiple=True, default=None, type=str),
    click.option('-L', '--limit', default=None, type=int),
    click.option('--authenticate', is_flag=True, default=None),
    click.option('--submitted', is_flag=True, default=None),
    click.option('--upvoted', is_flag=True, default=None),
    click.option('--saved', is_flag=True, default=None),
    click.option('--search', default=None, type=str),
    click.option('-u', '--user', type=str, default=None),
    click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
    click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new',
                                                    'controversial', 'rising', 'relevance')), default=None),
]


def _add_common_options(func):
    for opt in _common_options:
        func = opt(func)
    return func


@click.group()
def cli():
    pass


@cli.command('download')
@click.option('--exclude-id', default=None, multiple=True)
@click.option('--exclude-id-file', default=None, multiple=True)
@click.option('--file-scheme', default=None, type=str)
@click.option('--folder-scheme', default=None, type=str)
@click.option('--make-hard-links', is_flag=True, default=None)
@click.option('--max-wait-time', type=int, default=None)
@click.option('--no-dupes', is_flag=True, default=None)
@click.option('--search-existing', is_flag=True, default=None)
@click.option('--skip', default=None, multiple=True)
@click.option('--skip-domain', default=None, multiple=True)
@_add_common_options
@click.pass_context
def cli_download(context: click.Context, **_):
    config = Configuration()
    config.process_click_arguments(context)
    setup_logging(config.verbose)
    try:
        reddit_downloader = RedditDownloader(config)
        reddit_downloader.download()
    except Exception:
        logger.exception('Downloader exited unexpectedly')
        raise
    else:
        logger.info('Program complete')


@cli.command('archive')
@_add_common_options
@click.option('--all-comments', is_flag=True, default=None)
@click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None)
@click.pass_context
def cli_archive(context: click.Context, **_):
    config = Configuration()
    config.process_click_arguments(context)
    setup_logging(config.verbose)
    try:
        reddit_archiver = Archiver(config)
        reddit_archiver.download()
    except Exception:
        logger.exception('Archiver exited unexpectedly')
        raise
    else:
        logger.info('Program complete')


def setup_logging(verbosity: int):
    class StreamExceptionFilter(logging.Filter):
        def filter(self, record: logging.LogRecord) -> bool:
            result = not (record.levelno == logging.ERROR and record.exc_info)
            return result

    logger.setLevel(1)
    stream = logging.StreamHandler(sys.stdout)
    stream.addFilter(StreamExceptionFilter())

    formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
    stream.setFormatter(formatter)

    logger.addHandler(stream)
    if verbosity <= 0:
        stream.setLevel(logging.INFO)
    elif verbosity == 1:
        stream.setLevel(logging.DEBUG)
    else:
        stream.setLevel(9)
    logging.getLogger('praw').setLevel(logging.CRITICAL)
    logging.getLogger('prawcore').setLevel(logging.CRITICAL)
    logging.getLogger('urllib3').setLevel(logging.CRITICAL)


if __name__ == '__main__':
    cli()
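For reference, a minimal standalone sketch of the shared-option pattern that _add_common_options implements above: a list of stored click decorators is applied to a command function in a loop, so both subcommands get identical options. The names _common and demo are hypothetical and not part of the commit:

import click

_common = [click.option('--limit', type=int, default=None)]

def add_common(func):
    # apply each stored option decorator in turn, as _add_common_options does
    for opt in _common:
        func = opt(func)
    return func

@click.command()
@add_common
def demo(limit):
    click.echo(f'limit={limit}')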
2 bdfr/archive_entry/__init__.py Normal file
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# coding=utf-8
36 bdfr/archive_entry/base_archive_entry.py Normal file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# coding=utf-8

from abc import ABC, abstractmethod

from praw.models import Comment, Submission


class BaseArchiveEntry(ABC):
    def __init__(self, source: (Comment, Submission)):
        self.source = source
        self.post_details: dict = {}

    @abstractmethod
    def compile(self) -> dict:
        raise NotImplementedError

    @staticmethod
    def _convert_comment_to_dict(in_comment: Comment) -> dict:
        out_dict = {
            'author': in_comment.author.name if in_comment.author else 'DELETED',
            'id': in_comment.id,
            'score': in_comment.score,
            'subreddit': in_comment.subreddit.display_name,
            'submission': in_comment.submission.id,
            'stickied': in_comment.stickied,
            'body': in_comment.body,
            'is_submitter': in_comment.is_submitter,
            'created_utc': in_comment.created_utc,
            'parent_id': in_comment.parent_id,
            'replies': [],
        }
        in_comment.replies.replace_more(0)
        for reply in in_comment.replies:
            out_dict['replies'].append(BaseArchiveEntry._convert_comment_to_dict(reply))
        return out_dict
21 bdfr/archive_entry/comment_archive_entry.py Normal file
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
# coding=utf-8

import logging

import praw.models

from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry

logger = logging.getLogger(__name__)


class CommentArchiveEntry(BaseArchiveEntry):
    def __init__(self, comment: praw.models.Comment):
        super().__init__(comment)

    def compile(self) -> dict:
        self.source.refresh()
        self.post_details = self._convert_comment_to_dict(self.source)
        self.post_details['submission_title'] = self.source.submission.title
        return self.post_details
47 bdfr/archive_entry/submission_archive_entry.py Normal file
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# coding=utf-8

import logging

import praw.models

from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry

logger = logging.getLogger(__name__)


class SubmissionArchiveEntry(BaseArchiveEntry):
    def __init__(self, submission: praw.models.Submission):
        super().__init__(submission)

    def compile(self) -> dict:
        comments = self._get_comments()
        self._get_post_details()
        out = self.post_details
        out['comments'] = comments
        return out

    def _get_post_details(self):
        self.post_details = {
            'title': self.source.title,
            'name': self.source.name,
            'url': self.source.url,
            'selftext': self.source.selftext,
            'score': self.source.score,
            'upvote_ratio': self.source.upvote_ratio,
            'permalink': self.source.permalink,
            'id': self.source.id,
            'author': self.source.author.name if self.source.author else 'DELETED',
            'link_flair_text': self.source.link_flair_text,
            'num_comments': self.source.num_comments,
            'over_18': self.source.over_18,
            'created_utc': self.source.created_utc,
        }

    def _get_comments(self) -> list[dict]:
        logger.debug(f'Retrieving full comment tree for submission {self.source.id}')
        comments = []
        self.source.comments.replace_more(0)
        for top_level_comment in self.source.comments:
            comments.append(self._convert_comment_to_dict(top_level_comment))
        return comments
96 bdfr/archiver.py Normal file
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# coding=utf-8

import json
import logging
import re
from typing import Iterator

import dict2xml
import praw.models
import yaml

from bdfr.archive_entry.base_archive_entry import BaseArchiveEntry
from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry
from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.configuration import Configuration
from bdfr.downloader import RedditDownloader
from bdfr.exceptions import ArchiverError
from bdfr.resource import Resource

logger = logging.getLogger(__name__)


class Archiver(RedditDownloader):
    def __init__(self, args: Configuration):
        super().__init__(args)

    def download(self):
        for generator in self.reddit_lists:
            for submission in generator:
                logger.debug(f'Attempting to archive submission {submission.id}')
                self._write_entry(submission)

    def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
        supplied_submissions = []
        for sub_id in self.args.link:
            if len(sub_id) == 6:
                supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
            elif re.match(r'^\w{7}$', sub_id):
                supplied_submissions.append(self.reddit_instance.comment(id=sub_id))
            else:
                supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
        return [supplied_submissions]

    def _get_user_data(self) -> list[Iterator]:
        results = super()._get_user_data()
        if self.args.user and self.args.all_comments:
            sort = self._determine_sort_function()
            logger.debug(f'Retrieving comments of user {self.args.user}')
            results.append(sort(self.reddit_instance.redditor(self.args.user).comments, limit=self.args.limit))
        return results

    @staticmethod
    def _pull_lever_entry_factory(praw_item: (praw.models.Submission, praw.models.Comment)) -> BaseArchiveEntry:
        if isinstance(praw_item, praw.models.Submission):
            return SubmissionArchiveEntry(praw_item)
        elif isinstance(praw_item, praw.models.Comment):
            return CommentArchiveEntry(praw_item)
        else:
            raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}')

    def _write_entry(self, praw_item: (praw.models.Submission, praw.models.Comment)):
        archive_entry = self._pull_lever_entry_factory(praw_item)
        if self.args.format == 'json':
            self._write_entry_json(archive_entry)
        elif self.args.format == 'xml':
            self._write_entry_xml(archive_entry)
        elif self.args.format == 'yaml':
            self._write_entry_yaml(archive_entry)
        else:
            raise ArchiverError(f'Unknown format {self.args.format} given')
        logger.info(f'Record for entry item {praw_item.id} written to disk')

    def _write_entry_json(self, entry: BaseArchiveEntry):
        resource = Resource(entry.source, '', '.json')
        content = json.dumps(entry.compile())
        self._write_content_to_disk(resource, content)

    def _write_entry_xml(self, entry: BaseArchiveEntry):
        resource = Resource(entry.source, '', '.xml')
        content = dict2xml.dict2xml(entry.compile(), wrap='root')
        self._write_content_to_disk(resource, content)

    def _write_entry_yaml(self, entry: BaseArchiveEntry):
        resource = Resource(entry.source, '', '.yaml')
        content = yaml.dump(entry.compile())
        self._write_content_to_disk(resource, content)

    def _write_content_to_disk(self, resource: Resource, content: str):
        file_path = self.file_name_formatter.format_path(resource, self.download_directory)
        file_path.parent.mkdir(exist_ok=True, parents=True)
        with open(file_path, 'w') as file:
            logger.debug(
                f'Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}'
                f' format at {file_path}')
            file.write(content)
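The link handling in _get_submissions_from_link dispatches on the shape of each argument: six characters are treated as a submission ID, seven word characters as a comment ID, and anything else as a full URL. A hedged standalone sketch of that rule (classify_link is a hypothetical helper, not part of the commit):

import re

def classify_link(link_arg: str) -> str:
    # mirrors the dispatch in Archiver._get_submissions_from_link
    if len(link_arg) == 6:
        return 'submission'
    elif re.match(r'^\w{7}$', link_arg):
        return 'comment'
    return 'url'

assert classify_link('abc123') == 'submission'
assert classify_link('abcd123') == 'comment'
assert classify_link('https://redd.it/abc123') == 'url'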
46 bdfr/configuration.py Normal file
@@ -0,0 +1,46 @@
#!/usr/bin/env python3
# coding=utf-8

from argparse import Namespace
from typing import Optional

import click


class Configuration(Namespace):
    def __init__(self):
        super().__init__()
        self.authenticate = False
        self.config = None
        self.directory: str = '.'
        self.exclude_id = []
        self.exclude_id_file = []
        self.limit: Optional[int] = None
        self.link: list[str] = []
        self.max_wait_time = None
        self.multireddit: list[str] = []
        self.no_dupes: bool = False
        self.saved: bool = False
        self.search: Optional[str] = None
        self.search_existing: bool = False
        self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
        self.folder_scheme: str = '{SUBREDDIT}'
        self.skip: list[str] = []
        self.skip_domain: list[str] = []
        self.sort: str = 'hot'
        self.submitted: bool = False
        self.subreddit: list[str] = []
        self.time: str = 'all'
        self.upvoted: bool = False
        self.user: Optional[str] = None
        self.verbose: int = 0
        self.make_hard_links = False

        # Archiver-specific options
        self.format = 'json'
        self.all_comments = False

    def process_click_arguments(self, context: click.Context):
        for arg_key in context.params.keys():
            if arg_key in vars(self) and context.params[arg_key] is not None:
                vars(self)[arg_key] = context.params[arg_key]
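A short sketch of how process_click_arguments merges CLI values over the defaults; the click context is simulated here with SimpleNamespace since only its params mapping is read, and this snippet assumes the bdfr package is importable:

from types import SimpleNamespace
from bdfr.configuration import Configuration

config = Configuration()
fake_context = SimpleNamespace(params={'limit': 10, 'sort': None, 'unknown_key': 'x'})
config.process_click_arguments(fake_context)
assert config.limit == 10    # a non-None CLI value overrides the default
assert config.sort == 'hot'  # None leaves the default untouched
# 'unknown_key' is ignored: only attributes already on Configuration are copied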
6 bdfr/default_config.cfg Normal file
@@ -0,0 +1,6 @@
[DEFAULT]
client_id = U-6gk4ZCh3IeNQ
client_secret = 7CZHY6AmKweZME5s50SfDGylaPg
scopes = identity, history, read, save
backup_log_count = 3
max_wait_time = 120
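The file is plain INI and is read with configparser in the downloader; for instance, assuming the file above is on disk at this path:

import configparser

parser = configparser.ConfigParser()
parser.read('bdfr/default_config.cfg')
assert parser.getint('DEFAULT', 'max_wait_time') == 120
assert parser.get('DEFAULT', 'scopes') == 'identity, history, read, save'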
44 bdfr/download_filter.py Normal file
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# coding=utf-8

import logging
import re

logger = logging.getLogger(__name__)


class DownloadFilter:
    def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None):
        self.excluded_extensions = excluded_extensions
        self.excluded_domains = excluded_domains

    def check_url(self, url: str) -> bool:
        """Return True if the URL passes both the extension and domain filters"""
        if not self._check_extension(url):
            return False
        elif not self._check_domain(url):
            return False
        else:
            return True

    def _check_extension(self, url: str) -> bool:
        if not self.excluded_extensions:
            return True
        combined_extensions = '|'.join(self.excluded_extensions)
        pattern = re.compile(r'.*({})$'.format(combined_extensions))
        if re.match(pattern, url):
            logger.log(9, f'URL "{url}" matched with "{pattern}"')
            return False
        else:
            return True

    def _check_domain(self, url: str) -> bool:
        if not self.excluded_domains:
            return True
        combined_domains = '|'.join(self.excluded_domains)
        pattern = re.compile(r'https?://.*({}).*'.format(combined_domains))
        if re.match(pattern, url):
            logger.log(9, f'URL "{url}" matched with "{pattern}"')
            return False
        else:
            return True
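Example usage, assuming the package is importable; both checks are regex-based, so an excluded extension matches any URL ending with it and an excluded domain matches anywhere after the scheme:

from bdfr.download_filter import DownloadFilter

df = DownloadFilter(excluded_extensions=['mp4'], excluded_domains=['youtube.com'])
assert df.check_url('https://example.com/picture.jpg') is True
assert df.check_url('https://example.com/video.mp4') is False    # extension excluded
assert df.check_url('https://youtube.com/watch?v=abc') is False  # domain excluded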
437 bdfr/downloader.py Normal file
@@ -0,0 +1,437 @@
#!/usr/bin/env python3
# coding=utf-8

import configparser
import hashlib
import importlib.resources
import logging
import logging.handlers
import os
import re
import shutil
import socket
from datetime import datetime
from enum import Enum, auto
from multiprocessing import Pool
from pathlib import Path
from typing import Iterator

import appdirs
import praw
import praw.exceptions
import praw.models
import prawcore

import bdfr.exceptions as errors
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.download_factory import DownloadFactory

logger = logging.getLogger(__name__)


def _calc_hash(existing_file: Path):
    with open(existing_file, 'rb') as file:
        file_hash = hashlib.md5(file.read()).hexdigest()
        return existing_file, file_hash


class RedditTypes:
    class SortType(Enum):
        HOT = auto()
        RISING = auto()
        CONTROVERSIAL = auto()
        NEW = auto()
        RELEVANCE = auto()

    class TimeType(Enum):
        HOUR = auto()
        DAY = auto()
        WEEK = auto()
        MONTH = auto()
        YEAR = auto()
        ALL = auto()


class RedditDownloader:
    def __init__(self, args: Configuration):
        self.args = args
        self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
        self.run_time = datetime.now().isoformat()
        self._setup_internal_objects()

        self.reddit_lists = self._retrieve_reddit_lists()

    def _setup_internal_objects(self):
        self._determine_directories()
        self._load_config()
        self._create_file_logger()

        self._read_config()

        self.download_filter = self._create_download_filter()
        logger.log(9, 'Created download filter')
        self.time_filter = self._create_time_filter()
        logger.log(9, 'Created time filter')
        self.sort_filter = self._create_sort_filter()
        logger.log(9, 'Created sort filter')
        self.file_name_formatter = self._create_file_name_formatter()
        logger.log(9, 'Created file name formatter')

        self._create_reddit_instance()
        self._resolve_user_name()

        self.excluded_submission_ids = self._read_excluded_ids()

        if self.args.search_existing:
            self.master_hash_list = self.scan_existing_files(self.download_directory)
        else:
            self.master_hash_list = {}
        self.authenticator = self._create_authenticator()
        logger.log(9, 'Created site authenticator')

    def _read_config(self):
        """Read any cfg values that need to be processed"""
        if self.args.max_wait_time is None:
            if not self.cfg_parser.has_option('DEFAULT', 'max_wait_time'):
                self.cfg_parser.set('DEFAULT', 'max_wait_time', '120')
                logger.log(9, 'Wrote default download wait time to config file')
            self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time')
            logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
        # Update config on disk
        with open(self.config_location, 'w') as file:
            self.cfg_parser.write(file)

    def _create_reddit_instance(self):
        if self.args.authenticate:
            logger.debug('Using authenticated Reddit instance')
            if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
                logger.log(9, 'Commencing OAuth2 authentication')
                scopes = self.cfg_parser.get('DEFAULT', 'scopes')
                scopes = OAuth2Authenticator.split_scopes(scopes)
                oauth2_authenticator = OAuth2Authenticator(
                    scopes,
                    self.cfg_parser.get('DEFAULT', 'client_id'),
                    self.cfg_parser.get('DEFAULT', 'client_secret'),
                )
                token = oauth2_authenticator.retrieve_new_token()
                self.cfg_parser['DEFAULT']['user_token'] = token
                with open(self.config_location, 'w') as file:
                    self.cfg_parser.write(file, True)
            token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)

            self.authenticated = True
            self.reddit_instance = praw.Reddit(
                client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
                client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
                user_agent=socket.gethostname(),
                token_manager=token_manager,
            )
        else:
            logger.debug('Using unauthenticated Reddit instance')
            self.authenticated = False
            self.reddit_instance = praw.Reddit(
                client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
                client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
                user_agent=socket.gethostname(),
            )

    def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
        master_list = []
        master_list.extend(self._get_subreddits())
        logger.log(9, 'Retrieved subreddits')
        master_list.extend(self._get_multireddits())
        logger.log(9, 'Retrieved multireddits')
        master_list.extend(self._get_user_data())
        logger.log(9, 'Retrieved user data')
        master_list.extend(self._get_submissions_from_link())
        logger.log(9, 'Retrieved submissions for given links')
        return master_list

    def _determine_directories(self):
        self.download_directory = Path(self.args.directory).expanduser().resolve()
        self.config_directory = Path(self.config_directories.user_config_dir)

        self.download_directory.mkdir(exist_ok=True, parents=True)
        self.config_directory.mkdir(exist_ok=True, parents=True)

    def _load_config(self):
        self.cfg_parser = configparser.ConfigParser()
        if self.args.config:
            if (cfg_path := Path(self.args.config)).exists():
                self.cfg_parser.read(cfg_path)
                self.config_location = cfg_path
                return
        possible_paths = [
            Path('./config.cfg'),
            Path('./default_config.cfg'),
            Path(self.config_directory, 'config.cfg'),
            Path(self.config_directory, 'default_config.cfg'),
        ]
        self.config_location = None
        for path in possible_paths:
            if path.expanduser().resolve().exists():
                self.config_location = path
                logger.debug(f'Loading configuration from {path}')
                break
        if not self.config_location:
            self.config_location = list(importlib.resources.path('bdfr', 'default_config.cfg').gen)[0]
            shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
        if not self.config_location:
            raise errors.BulkDownloaderException('Could not find a configuration file to load')
        self.cfg_parser.read(self.config_location)

    def _create_file_logger(self):
        main_logger = logging.getLogger()
        log_path = Path(self.config_directory, 'log_output.txt')
        backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
        file_handler = logging.handlers.RotatingFileHandler(
            log_path,
            mode='a',
            backupCount=backup_count,
        )
        if log_path.exists():
            file_handler.doRollover()
        formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(0)

        main_logger.addHandler(file_handler)

    @staticmethod
    def _sanitise_subreddit_name(subreddit: str) -> str:
        pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)(?:/)?$')
        match = re.match(pattern, subreddit)
        if not match:
            raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
        return match.group(1)

    @staticmethod
    def _split_args_input(subreddit_entries: list[str]) -> set[str]:
        all_subreddits = []
        split_pattern = re.compile(r'[,;]\s?')
        for entry in subreddit_entries:
            results = re.split(split_pattern, entry)
            all_subreddits.extend([RedditDownloader._sanitise_subreddit_name(name) for name in results])
        return set(all_subreddits)

    def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
        if self.args.subreddit:
            out = []
            sort_function = self._determine_sort_function()
            for reddit in self._split_args_input(self.args.subreddit):
                try:
                    reddit = self.reddit_instance.subreddit(reddit)
                    if self.args.search:
                        out.append(
                            reddit.search(
                                self.args.search,
                                sort=self.sort_filter.name.lower(),
                                limit=self.args.limit,
                            ))
                        logger.debug(
                            f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
                    else:
                        out.append(sort_function(reddit, limit=self.args.limit))
                        logger.debug(f'Added submissions from subreddit {reddit}')
                except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
                    logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
            return out
        else:
            return []

    def _resolve_user_name(self):
        if self.args.user == 'me':
            if self.authenticated:
                self.args.user = self.reddit_instance.user.me().name
                logger.log(9, f'Resolved user to {self.args.user}')
            else:
                self.args.user = None
                logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')

    def _get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
        supplied_submissions = []
        for sub_id in self.args.link:
            if len(sub_id) == 6:
                supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
            else:
                supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
        return [supplied_submissions]

    def _determine_sort_function(self):
        if self.sort_filter is RedditTypes.SortType.NEW:
            sort_function = praw.models.Subreddit.new
        elif self.sort_filter is RedditTypes.SortType.RISING:
            sort_function = praw.models.Subreddit.rising
        elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
            sort_function = praw.models.Subreddit.controversial
        else:
            sort_function = praw.models.Subreddit.hot
        return sort_function

    def _get_multireddits(self) -> list[Iterator]:
        if self.args.multireddit:
            out = []
            sort_function = self._determine_sort_function()
            for multi in self._split_args_input(self.args.multireddit):
                try:
                    multi = self.reddit_instance.multireddit(self.args.user, multi)
                    if not multi.subreddits:
                        raise errors.BulkDownloaderException
                    out.append(sort_function(multi, limit=self.args.limit))
                    logger.debug(f'Added submissions from multireddit {multi}')
                except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
                    logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
            return out
        else:
            return []

    def _get_user_data(self) -> list[Iterator]:
        if any([self.args.submitted, self.args.upvoted, self.args.saved]):
            if self.args.user:
                if not self._check_user_existence(self.args.user):
                    logger.error(f'User {self.args.user} does not exist')
                    return []
                generators = []
                sort_function = self._determine_sort_function()
                if self.args.submitted:
                    logger.debug(f'Retrieving submitted posts of user {self.args.user}')
                    generators.append(
                        sort_function(
                            self.reddit_instance.redditor(self.args.user).submissions,
                            limit=self.args.limit,
                        ))
                if not self.authenticated and any((self.args.upvoted, self.args.saved)):
                    logger.warning('Accessing user lists requires authentication')
                else:
                    if self.args.upvoted:
                        logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
                        generators.append(self.reddit_instance.redditor(self.args.user).upvoted(limit=self.args.limit))
                    if self.args.saved:
                        logger.debug(f'Retrieving saved posts of user {self.args.user}')
                        generators.append(self.reddit_instance.redditor(self.args.user).saved(limit=self.args.limit))
                return generators
            else:
                logger.warning('A user must be supplied to download user data')
                return []
        else:
            return []

    def _check_user_existence(self, name: str) -> bool:
        user = self.reddit_instance.redditor(name=name)
        try:
            if not user.id:
                return False
        except prawcore.exceptions.NotFound:
            return False
        return True

    def _create_file_name_formatter(self) -> FileNameFormatter:
        return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme)

    def _create_time_filter(self) -> RedditTypes.TimeType:
        try:
            return RedditTypes.TimeType[self.args.time.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.TimeType.ALL

    def _create_sort_filter(self) -> RedditTypes.SortType:
        try:
            return RedditTypes.SortType[self.args.sort.upper()]
        except (KeyError, AttributeError):
            return RedditTypes.SortType.HOT

    def _create_download_filter(self) -> DownloadFilter:
        return DownloadFilter(self.args.skip, self.args.skip_domain)

    def _create_authenticator(self) -> SiteAuthenticator:
        return SiteAuthenticator(self.cfg_parser)

    def download(self):
        for generator in self.reddit_lists:
            for submission in generator:
                if submission.id in self.excluded_submission_ids:
                    logger.debug(f'Submission {submission.id} in exclusion list, skipping')
                    continue
                else:
                    logger.debug(f'Attempting to download submission {submission.id}')
                    self._download_submission(submission)

    def _download_submission(self, submission: praw.models.Submission):
        if not isinstance(submission, praw.models.Submission):
            logger.warning(f'{submission.id} is not a submission')
            return
        if not self.download_filter.check_url(submission.url):
            logger.debug(f'Download filter removed submission {submission.id} with URL {submission.url}')
            return
        try:
            downloader_class = DownloadFactory.pull_lever(submission.url)
            downloader = downloader_class(submission)
            logger.debug(f'Using {downloader_class.__name__} with url {submission.url}')
        except errors.NotADownloadableLinkError as e:
            logger.error(f'Could not download submission {submission.id}: {e}')
            return

        try:
            content = downloader.find_resources(self.authenticator)
        except errors.SiteDownloaderError as e:
            logger.error(f'Site {downloader_class.__name__} failed to download submission {submission.id}: {e}')
            return
        for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
            if destination.exists():
                logger.debug(f'File {destination} already exists, continuing')
            else:
                try:
                    res.download(self.args.max_wait_time)
                except errors.BulkDownloaderException as e:
                    logger.error(
                        f'Failed to download resource {res.url} with downloader {downloader_class.__name__}: {e}')
                    return
                resource_hash = res.hash.hexdigest()
                destination.parent.mkdir(parents=True, exist_ok=True)
                if resource_hash in self.master_hash_list:
                    if self.args.no_dupes:
                        logger.info(
                            f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere')
                        return
                    elif self.args.make_hard_links:
                        self.master_hash_list[resource_hash].link_to(destination)
                        logger.info(
                            f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}')
                        return
                with open(destination, 'wb') as file:
                    file.write(res.content)
                logger.debug(f'Written file to {destination}')
                self.master_hash_list[resource_hash] = destination
                logger.debug(f'Hash added to master list: {resource_hash}')
                logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}')

    @staticmethod
    def scan_existing_files(directory: Path) -> dict[str, Path]:
        files = []
        for (dirpath, dirnames, filenames) in os.walk(directory):
            files.extend([Path(dirpath, file) for file in filenames])
        logger.info(f'Calculating hashes for {len(files)} files')

        pool = Pool(15)
        results = pool.map(_calc_hash, files)
        pool.close()

        hash_list = {res[1]: res[0] for res in results}
        return hash_list

    def _read_excluded_ids(self) -> set[str]:
        out = []
        out.extend(self.args.exclude_id)
        for id_file in self.args.exclude_id_file:
            id_file = Path(id_file).expanduser().resolve()
            if not id_file.exists():
                logger.warning(f'ID exclusion file at {id_file} does not exist')
                continue
            with open(id_file, 'r') as file:
                for line in file:
                    out.append(line.strip())
        return set(out)
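The subreddit helpers accept bare names, r/-prefixed names, and full URLs, and split comma- or semicolon-separated arguments; for example (both are static methods, so no Reddit instance is needed, though the package's dependencies must be installed for the import):

from bdfr.downloader import RedditDownloader

assert RedditDownloader._sanitise_subreddit_name('https://www.reddit.com/r/Python/') == 'Python'
assert RedditDownloader._sanitise_subreddit_name('r/learnpython') == 'learnpython'
assert RedditDownloader._split_args_input(['pics, art;askreddit']) == {'pics', 'art', 'askreddit'}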
28 bdfr/exceptions.py Normal file
@@ -0,0 +1,28 @@
#!/usr/bin/env python3


class BulkDownloaderException(Exception):
    pass


class RedditUserError(BulkDownloaderException):
    pass


class RedditAuthenticationError(RedditUserError):
    pass


class ArchiverError(BulkDownloaderException):
    pass


class SiteDownloaderError(BulkDownloaderException):
    pass


class NotADownloadableLinkError(SiteDownloaderError):
    pass


class ResourceNotFound(SiteDownloaderError):
    pass
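Every error in the module derives from BulkDownloaderException, so callers can catch the whole family with one handler; for example:

from bdfr.exceptions import BulkDownloaderException, NotADownloadableLinkError

try:
    raise NotADownloadableLinkError('no downloader module exists for this url')
except BulkDownloaderException as e:
    print(e)  # SiteDownloaderError and its children land here too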
158 bdfr/file_name_formatter.py Normal file
@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# coding=utf-8

import logging
import platform
import re
from pathlib import Path
from typing import Optional

from praw.models import Comment, Submission

from bdfr.exceptions import BulkDownloaderException
from bdfr.resource import Resource

logger = logging.getLogger(__name__)


class FileNameFormatter:
    key_terms = (
        'date',
        'flair',
        'postid',
        'redditor',
        'subreddit',
        'title',
        'upvotes',
    )

    def __init__(self, file_format_string: str, directory_format_string: str):
        if not self.validate_string(file_format_string):
            raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string')
        self.file_format_string = file_format_string
        self.directory_format_string: list[str] = directory_format_string.split('/')

    @staticmethod
    def _format_name(submission: (Comment, Submission), format_string: str) -> str:
        if isinstance(submission, Submission):
            attributes = FileNameFormatter._generate_name_dict_from_submission(submission)
        elif isinstance(submission, Comment):
            attributes = FileNameFormatter._generate_name_dict_from_comment(submission)
        else:
            raise BulkDownloaderException(f'Cannot name object {type(submission).__name__}')
        result = format_string
        for key in attributes.keys():
            if re.search(fr'(?i).*{{{key}}}.*', result):
                key_value = str(attributes.get(key, 'unknown'))  # numeric attributes such as upvotes must be cast
                key_value = bytes(key_value, 'utf-8').decode('unicode-escape')
                result = re.sub(fr'(?i){{{key}}}', key_value, result)
                logger.log(9, f'Found key string {key} in name')

        result = result.replace('/', '')

        if platform.system() == 'Windows':
            result = FileNameFormatter._format_for_windows(result)

        return result

    @staticmethod
    def _generate_name_dict_from_submission(submission: Submission) -> dict:
        submission_attributes = {
            'title': submission.title,
            'subreddit': submission.subreddit.display_name,
            'redditor': submission.author.name if submission.author else 'DELETED',
            'postid': submission.id,
            'upvotes': submission.score,
            'flair': submission.link_flair_text,
            'date': submission.created_utc
        }
        return submission_attributes

    @staticmethod
    def _generate_name_dict_from_comment(comment: Comment) -> dict:
        comment_attributes = {
            'title': comment.submission.title,
            'subreddit': comment.subreddit.display_name,
            'redditor': comment.author.name if comment.author else 'DELETED',
            'postid': comment.id,
            'upvotes': comment.score,
            'flair': '',
            'date': comment.created_utc,
        }
        return comment_attributes

    def format_path(
            self,
            resource: Resource,
            destination_directory: Path,
            index: Optional[int] = None,
    ) -> Path:
        subfolder = Path(
            destination_directory,
            *[self._format_name(resource.source_submission, part) for part in self.directory_format_string]
        )
        index = f'_{index}' if index else ''
        if not resource.extension:
            raise BulkDownloaderException(f'Resource from {resource.url} has no extension')
        ending = index + resource.extension
        file_name = str(self._format_name(resource.source_submission, self.file_format_string))
        file_name = self._limit_file_name_length(file_name, ending)

        try:
            file_path = Path(subfolder, file_name)
        except TypeError:
            raise BulkDownloaderException(f'Could not determine path name: {subfolder}, {index}, {resource.extension}')
        return file_path

    @staticmethod
    def _limit_file_name_length(filename: str, ending: str) -> str:
        possible_id = re.search(r'((?:_\w{6})?$)', filename)
        if possible_id:
            ending = possible_id.group(1) + ending
            filename = filename[:possible_id.start()]
        max_length_chars = 255 - len(ending)
        max_length_bytes = 255 - len(ending.encode('utf-8'))
        while len(filename) > max_length_chars or len(filename.encode('utf-8')) > max_length_bytes:
            filename = filename[:-1]
        return filename + ending

    def format_resource_paths(
            self,
            resources: list[Resource],
            destination_directory: Path,
    ) -> list[tuple[Path, Resource]]:
        out = []
        if len(resources) == 1:
            out.append((self.format_path(resources[0], destination_directory, None), resources[0]))
        else:
            for i, res in enumerate(resources, start=1):
                logger.log(9, f'Formatting filename with index {i}')
                out.append((self.format_path(res, destination_directory, i), res))
        return out

    @staticmethod
    def validate_string(test_string: str) -> bool:
        if not test_string:
            return False
        result = any([f'{{{key}}}' in test_string.lower() for key in FileNameFormatter.key_terms])
        if result:
            if 'POSTID' not in test_string:
                logger.warning(
                    'Some files might not be downloaded due to name conflicts as filenames are'
                    ' not guaranteed to be unique without {POSTID}')
            return True
        else:
            return False

    @staticmethod
    def _format_for_windows(input_string: str) -> str:
        invalid_characters = r'<>:"\/|?*'
        for char in invalid_characters:
            input_string = input_string.replace(char, '')
        input_string = FileNameFormatter._strip_emojis(input_string)
        return input_string

    @staticmethod
    def _strip_emojis(input_string: str) -> str:
        result = input_string.encode('ascii', errors='ignore').decode('utf-8')
        return result
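A brief sketch of the scheme strings: keys are matched case-insensitively inside braces, and validate_string warns when {POSTID} is absent since the other keys do not guarantee uniqueness. For example, assuming the package is importable:

from bdfr.file_name_formatter import FileNameFormatter

assert FileNameFormatter.validate_string('{REDDITOR}_{TITLE}_{POSTID}') is True
assert FileNameFormatter.validate_string('no keys here') is False
formatter = FileNameFormatter('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}')
# formatter.format_path(resource, Path('downloads')) would then yield paths like
# downloads/pics/someuser_some_title_abc123.jpg for a matching submission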
107 bdfr/oauth2.py Normal file
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# coding=utf-8

import configparser
import logging
import random
import re
import socket
from pathlib import Path

import praw
import requests

from bdfr.exceptions import BulkDownloaderException, RedditAuthenticationError

logger = logging.getLogger(__name__)


class OAuth2Authenticator:

    def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
        self._check_scopes(wanted_scopes)
        self.scopes = wanted_scopes
        self.client_id = client_id
        self.client_secret = client_secret

    @staticmethod
    def _check_scopes(wanted_scopes: set[str]):
        response = requests.get('https://www.reddit.com/api/v1/scopes.json',
                                headers={'User-Agent': 'fetch-scopes test'})
        known_scopes = [scope for scope, data in response.json().items()]
        known_scopes.append('*')
        for scope in wanted_scopes:
            if scope not in known_scopes:
                raise BulkDownloaderException(f'Scope {scope} is not known to reddit')

    @staticmethod
    def split_scopes(scopes: str) -> set[str]:
        scopes = re.split(r'[,: ]+', scopes)
        return set(scopes)

    def retrieve_new_token(self) -> str:
        reddit = praw.Reddit(
            redirect_uri='http://localhost:7634',
            user_agent='obtain_refresh_token for BDFR',
            client_id=self.client_id,
            client_secret=self.client_secret)
        state = str(random.randint(0, 65000))
        url = reddit.auth.url(self.scopes, state, 'permanent')
        logger.warning('Authentication action required before the program can proceed')
        logger.warning(f'Authenticate at {url}')

        client = self.receive_connection()
        data = client.recv(1024).decode('utf-8')
        param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&')
        params = {key: value for (key, value) in [token.split('=') for token in param_tokens]}

        if state != params['state']:
            self.send_message(client)
            raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}')
        elif 'error' in params:
            self.send_message(client)
            raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')

        self.send_message(client, "<script>alert('You can go back to terminal window now.')</script>")
        refresh_token = reddit.auth.authorize(params["code"])
        return refresh_token

    @staticmethod
    def receive_connection() -> socket.socket:
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind(('localhost', 7634))
        logger.log(9, 'Server listening on localhost:7634')

        server.listen(1)
        client = server.accept()[0]
        server.close()
        logger.log(9, 'Server closed')

        return client

    @staticmethod
    def send_message(client: socket.socket, message: str = ''):
        # message defaults to empty: the error paths above call send_message(client)
        client.send(f'HTTP/1.1 200 OK\r\n\r\n{message}'.encode('utf-8'))
        client.close()


class OAuth2TokenManager(praw.reddit.BaseTokenManager):
    def __init__(self, config: configparser.ConfigParser, config_location: Path):
        super().__init__()
        self.config = config
        self.config_location = config_location

    def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        if authorizer.refresh_token is None:
            if self.config.has_option('DEFAULT', 'user_token'):
                authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
                logger.log(9, 'Loaded OAuth2 token for authoriser')
            else:
                raise RedditAuthenticationError('No auth token loaded in configuration')

    def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
        with open(self.config_location, 'w') as file:
            self.config.write(file, True)
        logger.log(9, f'Written OAuth2 token from authoriser to {self.config_location}')
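The scope splitter accepts commas, colons, and spaces interchangeably, matching the format used in the config file; for example:

from bdfr.oauth2 import OAuth2Authenticator

assert OAuth2Authenticator.split_scopes('identity, history read:save') == {'identity', 'history', 'read', 'save'}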
81 bdfr/resource.py Normal file
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# coding=utf-8

import hashlib
import logging
import re
import time
from typing import Optional

import _hashlib
import requests
from praw.models import Submission

from bdfr.exceptions import BulkDownloaderException

logger = logging.getLogger(__name__)


class Resource:
    def __init__(self, source_submission: Submission, url: str, extension: str = None):
        self.source_submission = source_submission
        self.content: Optional[bytes] = None
        self.url = url
        self.hash: Optional[_hashlib.HASH] = None
        self.extension = extension
        if not self.extension:
            self.extension = self._determine_extension()

    @staticmethod
    def retry_download(url: str, max_wait_time: int, wait_time: int = 60) -> Optional[bytes]:
        # the wait grows by 60 seconds per retry and gives up once it reaches max_wait_time
        try:
            response = requests.get(url)
            if response.status_code == 200:
                return response.content
            elif response.status_code in (301, 400, 401, 403, 404, 407, 410, 500, 501, 502):
                raise BulkDownloaderException(
                    f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
            else:
                raise requests.exceptions.ConnectionError(f'Response code {response.status_code}')
        except requests.exceptions.ConnectionError as e:
            logger.warning(f'Error occurred downloading from {url}, waiting {wait_time} seconds: {e}')
            time.sleep(wait_time)
            if wait_time < max_wait_time:
                return Resource.retry_download(url, max_wait_time, wait_time + 60)
            else:
                logger.error(f'Max wait time exceeded for resource at url {url}')
                raise

    def download(self, max_wait_time: int):
        if not self.content:
            try:
                content = self.retry_download(self.url, max_wait_time)
            except requests.exceptions.ConnectionError as e:
                raise BulkDownloaderException(f'Could not download resource: {e}')
            except BulkDownloaderException:
                raise
            if content:
                self.content = content
        if not self.hash and self.content:
            self.create_hash()

    def create_hash(self):
        self.hash = hashlib.md5(self.content)

    def _determine_extension(self) -> Optional[str]:
        extension_pattern = re.compile(r'.*(\..{3,5})(?:\?.*)?$')
        match = re.search(extension_pattern, self.url)
        if match:
            return match.group(1)
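The extension regex takes the last dot followed by three to five characters and ignores any query string; a small example (the source submission is irrelevant to extension detection, so None is passed here for illustration):

from bdfr.resource import Resource

res = Resource(None, 'https://example.com/image.jpeg?width=640')
assert res.extension == '.jpeg'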
9 bdfr/site_authenticator.py Normal file
@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# coding=utf-8

import configparser


class SiteAuthenticator:
    def __init__(self, cfg: configparser.ConfigParser):
        self.imgur_authentication = None
0 bdfr/site_downloaders/__init__.py Normal file
33 bdfr/site_downloaders/base_downloader.py Normal file
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# coding=utf-8

import logging
from abc import ABC, abstractmethod
from typing import Optional

import requests
from praw.models import Submission

from bdfr.exceptions import ResourceNotFound
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator

logger = logging.getLogger(__name__)


class BaseDownloader(ABC):
    def __init__(self, post: Submission, typical_extension: Optional[str] = None):
        self.post = post
        self.typical_extension = typical_extension

    @abstractmethod
    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        """Return list of all un-downloaded Resources from submission"""
        raise NotImplementedError

    @staticmethod
    def retrieve_url(url: str, cookies: dict = None, headers: dict = None) -> requests.Response:
        res = requests.get(url, cookies=cookies, headers=headers)
        if res.status_code != 200:
            raise ResourceNotFound(f'Server responded with {res.status_code} to {url}')
        return res
17 bdfr/site_downloaders/direct.py Normal file
@@ -0,0 +1,17 @@
#!/usr/bin/env python3

from typing import Optional

from praw.models import Submission

from bdfr.site_authenticator import SiteAuthenticator
from bdfr.resource import Resource
from bdfr.site_downloaders.base_downloader import BaseDownloader


class Direct(BaseDownloader):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        return [Resource(self.post, self.post.url)]
50 bdfr/site_downloaders/download_factory.py Normal file
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# coding=utf-8

import re
from typing import Type

from bdfr.exceptions import NotADownloadableLinkError
from bdfr.site_downloaders.base_downloader import BaseDownloader
from bdfr.site_downloaders.direct import Direct
from bdfr.site_downloaders.erome import Erome
from bdfr.site_downloaders.gallery import Gallery
from bdfr.site_downloaders.gfycat import Gfycat
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork
from bdfr.site_downloaders.imgur import Imgur
from bdfr.site_downloaders.redgifs import Redgifs
from bdfr.site_downloaders.self_post import SelfPost
from bdfr.site_downloaders.vreddit import VReddit
from bdfr.site_downloaders.youtube import Youtube


class DownloadFactory:
    @staticmethod
    def pull_lever(url: str) -> Type[BaseDownloader]:
        url_beginning = r'\s*(https?://(www\.)?)'
        if re.match(url_beginning + r'(i\.)?imgur.*\.gifv$', url):
            return Imgur
        elif re.match(url_beginning + r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', url):
            return Direct
        elif re.match(url_beginning + r'erome\.com.*', url):
            return Erome
        elif re.match(url_beginning + r'reddit\.com/gallery/.*', url):
            return Gallery
        elif re.match(url_beginning + r'gfycat\.', url):
            return Gfycat
        elif re.match(url_beginning + r'gifdeliverynetwork', url):
            return GifDeliveryNetwork
        elif re.match(url_beginning + r'(m\.)?imgur.*', url):
            return Imgur
        elif re.match(url_beginning + r'redgifs.com', url):
            return Redgifs
        elif re.match(url_beginning + r'reddit\.com/r/', url):
            return SelfPost
        elif re.match(url_beginning + r'v\.redd\.it', url):
            return VReddit
        elif re.match(url_beginning + r'(m\.)?youtu\.?be', url):
            return Youtube
        elif re.match(url_beginning + r'i\.redd\.it.*', url):
            return Direct
        else:
            raise NotADownloadableLinkError(f'No downloader module exists for url {url}')
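The patterns are tried in order, so a direct file link on any host resolves via the generic file-extension branch before the host-specific ones; for example:

from bdfr.site_downloaders.download_factory import DownloadFactory
from bdfr.site_downloaders.direct import Direct
from bdfr.site_downloaders.gfycat import Gfycat

assert DownloadFactory.pull_lever('https://i.redd.it/example.png') is Direct
assert DownloadFactory.pull_lever('https://gfycat.com/SomeClip') is Gfycat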
45 bdfr/site_downloaders/erome.py Normal file
@@ -0,0 +1,45 @@
#!/usr/bin/env python3

import logging
import re
from typing import Optional

import bs4
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader

logger = logging.getLogger(__name__)


class Erome(BaseDownloader):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        links = self._get_links(self.post.url)

        if not links:
            raise SiteDownloaderError('Erome parser could not find any links')

        out = []
        for link in links:
            if not re.match(r'https?://.*', link):
                link = 'https://' + link
            out.append(Resource(self.post, link))
        return out

    @staticmethod
    def _get_links(url: str) -> set[str]:
        page = Erome.retrieve_url(url)
        soup = bs4.BeautifulSoup(page.text, 'html.parser')
        front_images = soup.find_all('img', attrs={'class': 'lasyload'})  # 'lasyload' (sic) matches the site's markup
        out = [im.get('data-src') for im in front_images]

        videos = soup.find_all('source')
        out.extend([vid.get('src') for vid in videos])

        return set(out)
40 bdfr/site_downloaders/gallery.py Normal file
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

import logging
import re
from typing import Optional

import bs4
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader

logger = logging.getLogger(__name__)


class Gallery(BaseDownloader):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        image_urls = self._get_links(self.post.url)
        if not image_urls:
            raise SiteDownloaderError('No images found in Reddit gallery')
        return [Resource(self.post, url) for url in image_urls]

    @staticmethod
    def _get_links(url: str) -> list[str]:
        resource_headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'
                          ' Chrome/67.0.3396.87 Safari/537.36 OPR/54.0.2952.64',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8',
        }
        page = Gallery.retrieve_url(url, headers=resource_headers)
        soup = bs4.BeautifulSoup(page.text, 'html.parser')

        links = soup.find_all('a', attrs={'target': '_blank', 'href': re.compile(r'https://preview\.redd\.it.*')})
        links = [link.get('href') for link in links]
        return links
41
bdfr/site_downloaders/gfycat.py
Normal file
41
bdfr/site_downloaders/gfycat.py
Normal file
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

import json
import re
from typing import Optional

from bs4 import BeautifulSoup
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork


class Gfycat(GifDeliveryNetwork):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        return super().find_resources(authenticator)

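    # Gfycat embeds the video URL in an ld+json metadata script; links that now
    # redirect to gifdeliverynetwork.com are delegated to the parent class.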
    @staticmethod
    def _get_link(url: str) -> str:
        gfycat_id = re.match(r'.*/(.*?)/?$', url).group(1)
        url = 'https://gfycat.com/' + gfycat_id

        response = Gfycat.retrieve_url(url)
        if 'gifdeliverynetwork' in response.url:
            return GifDeliveryNetwork._get_link(url)

        soup = BeautifulSoup(response.text, 'html.parser')
        content = soup.find('script', attrs={'data-react-helmet': 'true', 'type': 'application/ld+json'})
        # Guard against a missing script tag, mirroring the check in Redgifs
        if content is None:
            raise SiteDownloaderError('Could not read the page source')

        try:
            out = json.loads(content.contents[0])['video']['contentUrl']
        except (IndexError, KeyError) as e:
            raise SiteDownloaderError(f'Failed to download Gfycat link {url}: {e}')
        except json.JSONDecodeError as e:
            raise SiteDownloaderError(f'Did not receive valid JSON data: {e}')
        return out
36
bdfr/site_downloaders/gif_delivery_network.py
Normal file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3

from typing import Optional

from bs4 import BeautifulSoup
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader


class GifDeliveryNetwork(BaseDownloader):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        media_url = self._get_link(self.post.url)
        return [Resource(self.post, media_url, '.mp4')]

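    # The page embeds its video in a <source id="mp4Source"> element; soup.find
    # returns None when that element is absent, which is handled below as well.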
    @staticmethod
    def _get_link(url: str) -> str:
        page = GifDeliveryNetwork.retrieve_url(url)

        soup = BeautifulSoup(page.text, 'html.parser')
        content = soup.find('source', attrs={'id': 'mp4Source', 'type': 'video/mp4'})

        try:
            out = content['src']
            if not out:
                raise KeyError
        # TypeError covers content being None when the element is missing
        except (KeyError, TypeError):
            raise SiteDownloaderError('Could not find source link')

        return out
79
bdfr/site_downloaders/imgur.py
Normal file
@@ -0,0 +1,79 @@
#!/usr/bin/env python3

import json
import re
from typing import Optional

import bs4
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader


class Imgur(BaseDownloader):

    def __init__(self, post: Submission):
        super().__init__(post)
        self.raw_data = {}

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        self.raw_data = self._get_data(self.post.url)

        out = []
        if 'album_images' in self.raw_data:
            images = self.raw_data['album_images']
            for image in images['images']:
                out.append(self._download_image(image))
        else:
            out.append(self._download_image(self.raw_data))
        return out

    def _download_image(self, image: dict) -> Resource:
        image_url = 'https://i.imgur.com/' + image['hash'] + self._validate_extension(image['ext'])
        return Resource(self.post, image_url)

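    # Rather than using Imgur's keyed API, the post metadata is scraped out of
    # the inline 'widgetFactory.mergeConfig' script embedded in the page itself.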
    @staticmethod
    def _get_data(link: str) -> dict:
        if re.match(r'.*\.gifv$', link):
            link = link.replace('i.imgur', 'imgur')
            # rstrip('.gifv') would strip any trailing run of those characters,
            # mangling hashes ending in g, i, f or v; remove the suffix instead
            link = re.sub(r'\.gifv$', '', link)

        res = Imgur.retrieve_url(link, cookies={'over18': '1', 'postpagebeta': '0'})

        soup = bs4.BeautifulSoup(res.text, 'html.parser')
        scripts = soup.find_all('script', attrs={'type': 'text/javascript'})
        scripts = [script.string.replace('\n', '') for script in scripts if script.string]

        script_regex = re.compile(r'\s*\(function\(widgetFactory\)\s*{\s*widgetFactory\.mergeConfig\(\'gallery\'')
        chosen_script = list(filter(lambda s: re.search(script_regex, s), scripts))
        if len(chosen_script) != 1:
            raise SiteDownloaderError(f'Could not read page source from {link}')

        chosen_script = chosen_script[0]

        outer_regex = re.compile(r'widgetFactory\.mergeConfig\(\'gallery\', ({.*})\);')
        inner_regex = re.compile(r'image\s*:(.*),\s*group')
        try:
            image_dict = re.search(outer_regex, chosen_script).group(1)
            image_dict = re.search(inner_regex, image_dict).group(1)
        except AttributeError:
            raise SiteDownloaderError('Could not find image dictionary in page source')

        try:
            image_dict = json.loads(image_dict)
        except json.JSONDecodeError as e:
            raise SiteDownloaderError(f'Could not parse received dict as JSON: {e}')

        return image_dict

    @staticmethod
    def _validate_extension(extension_suffix: str) -> str:
        possible_extensions = ('.jpg', '.png', '.mp4', '.gif')
        selection = [ext for ext in possible_extensions if ext == extension_suffix]
        if len(selection) == 1:
            return selection[0]
        else:
            raise SiteDownloaderError(f'"{extension_suffix}" is not recognized as a valid extension for Imgur')
52
bdfr/site_downloaders/redgifs.py
Normal file
@@ -0,0 +1,52 @@
#!/usr/bin/env python3

import json
import re
from typing import Optional

from bs4 import BeautifulSoup
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork


class Redgifs(GifDeliveryNetwork):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        return super().find_resources(authenticator)

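    # Like Gfycat, Redgifs serves the video URL inside an ld+json script tag;
    # the URL is first normalised to the canonical watch-page form so that
    # trailing slashes and alternative link formats still resolve.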
    @staticmethod
    def _get_link(url: str) -> str:
        try:
            redgif_id = re.match(r'.*/(.*?)/?$', url).group(1)
        except AttributeError:
            raise SiteDownloaderError(f'Could not extract Redgifs ID from {url}')

        url = 'https://redgifs.com/watch/' + redgif_id

        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'
                          ' Chrome/67.0.3396.87 Safari/537.36 OPR/54.0.2952.64',
        }

        page = Redgifs.retrieve_url(url, headers=headers)

        soup = BeautifulSoup(page.text, 'html.parser')
        content = soup.find('script', attrs={'data-react-helmet': 'true', 'type': 'application/ld+json'})

        if content is None:
            raise SiteDownloaderError('Could not read the page source')

        try:
            out = json.loads(content.contents[0])['video']['contentUrl']
        except (IndexError, KeyError):
            raise SiteDownloaderError('Failed to find JSON data in page')
        except json.JSONDecodeError as e:
            raise SiteDownloaderError(f'Received data was not valid JSON: {e}')

        return out
43
bdfr/site_downloaders/self_post.py
Normal file
@@ -0,0 +1,43 @@
#!/usr/bin/env python3

import logging
from typing import Optional

from praw.models import Submission

from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader

logger = logging.getLogger(__name__)


class SelfPost(BaseDownloader):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        out = Resource(self.post, self.post.url, '.txt')
        out.content = self.export_to_string().encode('utf-8')
        out.create_hash()
        return [out]

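    # The post body is rendered as a small Markdown document: a title link,
    # the selftext, then a footer crediting the subreddit and author.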
    def export_to_string(self) -> str:
        """Format the self post as a Markdown document"""
        content = ("## ["
                   + self.post.fullname
                   + "]("
                   + self.post.url
                   + ")\n"
                   + self.post.selftext
                   + "\n\n---\n\n"
                   + "submitted to [r/"
                   + self.post.subreddit.title
                   + "](https://www.reddit.com/r/"
                   + self.post.subreddit.title
                   + ") by [u/"
                   + (self.post.author.name if self.post.author else "DELETED")
                   + "](https://www.reddit.com/user/"
                   + (self.post.author.name if self.post.author else "DELETED")
                   + ")")
        return content
21
bdfr/site_downloaders/vreddit.py
Normal file
@@ -0,0 +1,21 @@
#!/usr/bin/env python3

import logging
from typing import Optional

from praw.models import Submission

from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.youtube import Youtube

logger = logging.getLogger(__name__)


class VReddit(Youtube):
    def __init__(self, post: Submission):
        super().__init__(post)

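    # v.redd.it links go through youtube-dl as well; the empty options dict
    # means that, apart from what _download_video itself sets, youtube-dl's
    # defaults are used rather than the Youtube class's format options.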
    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        out = super()._download_video({})
        return [out]
50
bdfr/site_downloaders/youtube.py
Normal file
@@ -0,0 +1,50 @@
#!/usr/bin/env python3

import logging
import tempfile
from pathlib import Path
from typing import Optional

import youtube_dl
from praw.models import Submission

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader

logger = logging.getLogger(__name__)


class Youtube(BaseDownloader):
    def __init__(self, post: Submission):
        super().__init__(post)

    def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
        ytdl_options = {
            'format': 'best',
            'playlistend': 1,
            'nooverwrites': True,
        }
        out = self._download_video(ytdl_options)
        return [out]

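    # youtube-dl writes to disk rather than returning bytes, so the video is
    # downloaded into a temporary directory and read back into the Resource
    # before the directory is cleaned up.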
    def _download_video(self, ytdl_options: dict) -> Resource:
        ytdl_options['quiet'] = True
        with tempfile.TemporaryDirectory() as temp_dir:
            download_path = Path(temp_dir).resolve()
            ytdl_options['outtmpl'] = str(download_path) + '/' + 'test.%(ext)s'
            try:
                with youtube_dl.YoutubeDL(ytdl_options) as ydl:
                    ydl.download([self.post.url])
            except youtube_dl.DownloadError as e:
                raise SiteDownloaderError(f'Youtube download failed: {e}')

            downloaded_file = list(download_path.iterdir())[0]
            extension = downloaded_file.suffix
            with open(downloaded_file, 'rb') as file:
                content = file.read()
        out = Resource(self.post, self.post.url, extension)
        out.content = content
        out.create_hash()
        return out
0
bdfr/tests/__init__.py
Normal file
2
bdfr/tests/archive_entry/__init__.py
Normal file
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# coding=utf-8
38
bdfr/tests/archive_entry/test_comment_archive_entry.py
Normal file
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# coding=utf-8

import praw
import pytest

from bdfr.archive_entry.comment_archive_entry import CommentArchiveEntry


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_comment_id', 'expected_dict'), (
    ('gstd4hk', {
        'author': 'james_pic',
        'subreddit': 'Python',
        'submission': 'mgi4op',
        'submission_title': '76% Faster CPython',
    }),
))
def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_instance: praw.Reddit):
    comment = reddit_instance.comment(id=test_comment_id)
    test_entry = CommentArchiveEntry(comment)
    result = test_entry.compile()
    assert all([result.get(key) == expected_dict[key] for key in expected_dict.keys()])


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_comment_id', 'expected_min_comments'), (
    ('gstd4hk', 4),
    ('gsvyste', 3),
    ('gsxnvvb', 5),
))
def test_get_comment_replies(test_comment_id: str, expected_min_comments: int, reddit_instance: praw.Reddit):
    comment = reddit_instance.comment(id=test_comment_id)
    test_entry = CommentArchiveEntry(comment)
    result = test_entry.compile()
    assert len(result.get('replies')) >= expected_min_comments
36
bdfr/tests/archive_entry/test_submission_archive_entry.py
Normal file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# coding=utf-8

import praw
import pytest

from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'min_comments'), (
    ('m3reby', 27),
))
def test_get_comments(test_submission_id: str, min_comments: int, reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    test_archive_entry = SubmissionArchiveEntry(test_submission)
    results = test_archive_entry._get_comments()
    assert len(results) >= min_comments


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_dict'), (
    ('m3reby', {
        'author': 'sinjen-tos',
        'id': 'm3reby',
        'link_flair_text': 'image',
    }),
    ('m3kua3', {'author': 'DELETED'}),
))
def test_get_post_details(test_submission_id: str, expected_dict: dict, reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    test_archive_entry = SubmissionArchiveEntry(test_submission)
    test_archive_entry._get_post_details()
    assert all([test_archive_entry.post_details.get(key) == expected_dict[key] for key in expected_dict.keys()])
34
bdfr/tests/conftest.py
Normal file
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# coding=utf-8

import configparser
import socket
from pathlib import Path

import praw
import pytest

from bdfr.oauth2 import OAuth2TokenManager

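# The first fixture builds a read-only client from credentials committed with
# the test suite; the authenticated fixture is skipped unless a local
# test_config.cfg supplies a user refresh token.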
@pytest.fixture(scope='session')
def reddit_instance():
    rd = praw.Reddit(client_id='U-6gk4ZCh3IeNQ', client_secret='7CZHY6AmKweZME5s50SfDGylaPg', user_agent='test')
    return rd


@pytest.fixture(scope='session')
def authenticated_reddit_instance():
    test_config_path = Path('test_config.cfg')
    if not test_config_path.exists():
        pytest.skip('Refresh token must be provided to authenticate with OAuth2')
    cfg_parser = configparser.ConfigParser()
    cfg_parser.read(test_config_path)
    if not cfg_parser.has_option('DEFAULT', 'user_token'):
        pytest.skip('Refresh token must be provided to authenticate with OAuth2')
    token_manager = OAuth2TokenManager(cfg_parser, test_config_path)
    reddit_instance = praw.Reddit(client_id=cfg_parser.get('DEFAULT', 'client_id'),
                                  client_secret=cfg_parser.get('DEFAULT', 'client_secret'),
                                  user_agent=socket.gethostname(),
                                  token_manager=token_manager)
    return reddit_instance
0
bdfr/tests/site_downloaders/__init__.py
Normal file
25
bdfr/tests/site_downloaders/test_direct.py
Normal file
@@ -0,0 +1,25 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import Mock

import pytest

from bdfr.resource import Resource
from bdfr.site_downloaders.direct import Direct


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
    ('https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4', '48f9bd4dbec1556d7838885612b13b39'),
    ('https://giant.gfycat.com/DazzlingSilkyIguana.mp4', '808941b48fc1e28713d36dd7ed9dc648'),
))
def test_download_resource(test_url: str, expected_hash: str):
    mock_submission = Mock()
    mock_submission.url = test_url
    test_site = Direct(mock_submission)
    resources = test_site.find_resources()
    assert len(resources) == 1
    assert isinstance(resources[0], Resource)
    resources[0].download(120)
    assert resources[0].hash.hexdigest() == expected_hash
60
bdfr/tests/site_downloaders/test_download_factory.py
Normal file
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# coding=utf-8

import praw
import pytest

from bdfr.exceptions import NotADownloadableLinkError
from bdfr.site_downloaders.base_downloader import BaseDownloader
from bdfr.site_downloaders.direct import Direct
from bdfr.site_downloaders.download_factory import DownloadFactory
from bdfr.site_downloaders.erome import Erome
from bdfr.site_downloaders.gallery import Gallery
from bdfr.site_downloaders.gfycat import Gfycat
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork
from bdfr.site_downloaders.imgur import Imgur
from bdfr.site_downloaders.redgifs import Redgifs
from bdfr.site_downloaders.self_post import SelfPost
from bdfr.site_downloaders.vreddit import VReddit
from bdfr.site_downloaders.youtube import Youtube


@pytest.mark.parametrize(('test_submission_url', 'expected_class'), (
    ('https://v.redd.it/9z1dnk3xr5k61', VReddit),
    ('https://www.reddit.com/r/TwoXChromosomes/comments/lu29zn/i_refuse_to_live_my_life'
     '_in_anything_but_comfort/', SelfPost),
    ('https://i.imgur.com/bZx1SJQ.jpg', Direct),
    ('https://i.redd.it/affyv0axd5k61.png', Direct),
    ('https://imgur.com/3ls94yv.jpeg', Direct),
    ('https://i.imgur.com/BuzvZwb.gifv', Imgur),
    ('https://imgur.com/BuzvZwb.gifv', Imgur),
    ('https://i.imgur.com/6fNdLst.gif', Direct),
    ('https://imgur.com/a/MkxAzeg', Imgur),
    ('https://www.reddit.com/gallery/lu93m7', Gallery),
    ('https://gfycat.com/concretecheerfulfinwhale', Gfycat),
    ('https://www.erome.com/a/NWGw0F09', Erome),
    ('https://youtube.com/watch?v=Gv8Wz74FjVA', Youtube),
    ('https://redgifs.com/watch/courageousimpeccablecanvasback', Redgifs),
    ('https://www.gifdeliverynetwork.com/repulsivefinishedandalusianhorse', GifDeliveryNetwork),
    ('https://youtu.be/DevfjHOhuFc', Youtube),
    ('https://m.youtube.com/watch?v=kr-FeojxzUM', Youtube),
    ('https://i.imgur.com/3SKrQfK.jpg?1', Direct),
    ('https://dynasty-scans.com/system/images_images/000/017/819/original/80215103_p0.png?1612232781', Direct),
    ('https://m.imgur.com/a/py3RW0j', Imgur),
))
def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownloader, reddit_instance: praw.Reddit):
    result = DownloadFactory.pull_lever(test_submission_url)
    assert result is expected_class


@pytest.mark.parametrize('test_url', (
    'random.com',
    'bad',
    'https://www.google.com/',
    'https://www.google.com',
    'https://www.google.com/test',
    'https://www.google.com/test/',
))
def test_factory_lever_bad(test_url: str):
    with pytest.raises(NotADownloadableLinkError):
        DownloadFactory.pull_lever(test_url)
57
bdfr/tests/site_downloaders/test_erome.py
Normal file
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import MagicMock

import pytest

from bdfr.site_downloaders.erome import Erome


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_urls'), (
    ('https://www.erome.com/a/vqtPuLXh', (
        'https://s11.erome.com/365/vqtPuLXh/KH2qBT99_480p.mp4',
    )),
    ('https://www.erome.com/a/ORhX0FZz', (
        'https://s4.erome.com/355/ORhX0FZz/9IYQocM9_480p.mp4',
        'https://s4.erome.com/355/ORhX0FZz/9eEDc8xm_480p.mp4',
        'https://s4.erome.com/355/ORhX0FZz/EvApC7Rp_480p.mp4',
        'https://s4.erome.com/355/ORhX0FZz/LruobtMs_480p.mp4',
        'https://s4.erome.com/355/ORhX0FZz/TJNmSUU5_480p.mp4',
        'https://s4.erome.com/355/ORhX0FZz/X11Skh6Z_480p.mp4',
        'https://s4.erome.com/355/ORhX0FZz/bjlTkpn7_480p.mp4'
    )),
))
def test_get_link(test_url: str, expected_urls: tuple[str]):
    result = Erome._get_links(test_url)
    assert set(result) == set(expected_urls)


@pytest.mark.online
@pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hashes'), (
    ('https://www.erome.com/a/vqtPuLXh', {
        '5da2a8d60d87bed279431fdec8e7d72f'
    }),
    ('https://www.erome.com/i/ItASD33e', {
        'b0d73fedc9ce6995c2f2c4fdb6f11eff'
    }),
    ('https://www.erome.com/a/lGrcFxmb', {
        '0e98f9f527a911dcedde4f846bb5b69f',
        '25696ae364750a5303fc7d7dc78b35c1',
        '63775689f438bd393cde7db6d46187de',
        'a1abf398cfd4ef9cfaf093ceb10c746a',
        'bd9e1a4ea5ef0d6ba47fb90e337c2d14'
    }),
))
def test_download_resource(test_url: str, expected_hashes: set[str]):
    # Hashes can't be compared here: Erome doesn't return exactly the same file
    # from request to request, so only the number of resources is checked
    mock_submission = MagicMock()
    mock_submission.url = test_url
    test_site = Erome(mock_submission)
    resources = test_site.find_resources()
    [res.download(120) for res in resources]
    resource_hashes = [res.hash.hexdigest() for res in resources]
    assert len(resource_hashes) == len(expected_hashes)
60
bdfr/tests/site_downloaders/test_gallery.py
Normal file
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# coding=utf-8

import praw
import pytest

from bdfr.site_downloaders.gallery import Gallery


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), (
    ('https://www.reddit.com/gallery/m6lvrh', {
        'https://preview.redd.it/18nzv9ch0hn61.jpg?width=4160&'
        'format=pjpg&auto=webp&s=470a825b9c364e0eace0036882dcff926f821de8',
        'https://preview.redd.it/jqkizcch0hn61.jpg?width=4160&'
        'format=pjpg&auto=webp&s=ae4f552a18066bb6727676b14f2451c5feecf805',
        'https://preview.redd.it/k0fnqzbh0hn61.jpg?width=4160&'
        'format=pjpg&auto=webp&s=c6a10fececdc33983487c16ad02219fd3fc6cd76',
        'https://preview.redd.it/m3gamzbh0hn61.jpg?width=4160&'
        'format=pjpg&auto=webp&s=0dd90f324711851953e24873290b7f29ec73c444'
    }),
    ('https://www.reddit.com/gallery/ljyy27', {
        'https://preview.redd.it/04vxj25uqih61.png?width=92&'
        'format=png&auto=webp&s=6513f3a5c5128ee7680d402cab5ea4fb2bbeead4',
        'https://preview.redd.it/0fnx83kpqih61.png?width=241&'
        'format=png&auto=webp&s=655e9deb6f499c9ba1476eaff56787a697e6255a',
        'https://preview.redd.it/7zkmr1wqqih61.png?width=237&'
        'format=png&auto=webp&s=19de214e634cbcad9959f19570c616e29be0c0b0',
        'https://preview.redd.it/u37k5gxrqih61.png?width=443&'
        'format=png&auto=webp&s=e74dae31841fe4a2545ffd794d3b25b9ff0eb862'
    }),
))
def test_gallery_get_links(test_url: str, expected: set[str]):
    results = Gallery._get_links(test_url)
    assert set(results) == expected


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hashes'), (
    ('m6lvrh', {
        '6c8a892ae8066cbe119218bcaac731e1',
        '93ce177f8cb7994906795f4615114d13',
        '9a293adf19354f14582608cf22124574',
        'b73e2c3daee02f99404644ea02f1ae65'
    }),
    ('ljyy27', {
        '1bc38bed88f9c4770e22a37122d5c941',
        '2539a92b78f3968a069df2dffe2279f9',
        '37dea50281c219b905e46edeefc1a18d',
        'ec4924cf40549728dcf53dd40bc7a73c'
    }),
))
def test_gallery_download(test_submission_id: str, expected_hashes: set[str], reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    gallery = Gallery(test_submission)
    results = gallery.find_resources()
    [res.download(120) for res in results]
    hashes = [res.hash.hexdigest() for res in results]
    assert set(hashes) == expected_hashes
36
bdfr/tests/site_downloaders/test_gfycat.py
Normal file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import Mock

import pytest

from bdfr.resource import Resource
from bdfr.site_downloaders.gfycat import Gfycat


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_url'), (
    ('https://gfycat.com/definitivecaninecrayfish', 'https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4'),
    ('https://gfycat.com/dazzlingsilkyiguana', 'https://giant.gfycat.com/DazzlingSilkyIguana.mp4'),
    ('https://gfycat.com/webbedimpurebutterfly', 'https://thumbs2.redgifs.com/WebbedImpureButterfly.mp4'),
))
def test_get_link(test_url: str, expected_url: str):
    result = Gfycat._get_link(test_url)
    assert result == expected_url


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
    ('https://gfycat.com/definitivecaninecrayfish', '48f9bd4dbec1556d7838885612b13b39'),
    ('https://gfycat.com/dazzlingsilkyiguana', '808941b48fc1e28713d36dd7ed9dc648'),
))
def test_download_resource(test_url: str, expected_hash: str):
    mock_submission = Mock()
    mock_submission.url = test_url
    test_site = Gfycat(mock_submission)
    resources = test_site.find_resources()
    assert len(resources) == 1
    assert isinstance(resources[0], Resource)
    resources[0].download(120)
    assert resources[0].hash.hexdigest() == expected_hash
37
bdfr/tests/site_downloaders/test_gif_delivery_network.py
Normal file
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import Mock

import pytest

from bdfr.resource import Resource
from bdfr.site_downloaders.gif_delivery_network import GifDeliveryNetwork


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), (
    ('https://www.gifdeliverynetwork.com/regalshoddyhorsechestnutleafminer',
     'https://thumbs2.redgifs.com/RegalShoddyHorsechestnutleafminer.mp4'),
    ('https://www.gifdeliverynetwork.com/maturenexthippopotamus',
     'https://thumbs2.redgifs.com/MatureNextHippopotamus.mp4'),
))
def test_get_link(test_url: str, expected: str):
    result = GifDeliveryNetwork._get_link(test_url)
    assert result == expected


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
    ('https://www.gifdeliverynetwork.com/maturenexthippopotamus', '9bec0a9e4163a43781368ed5d70471df'),
    ('https://www.gifdeliverynetwork.com/regalshoddyhorsechestnutleafminer', '8afb4e2c090a87140230f2352bf8beba'),
))
def test_download_resource(test_url: str, expected_hash: str):
    mock_submission = Mock()
    mock_submission.url = test_url
    test_site = GifDeliveryNetwork(mock_submission)
    resources = test_site.find_resources()
    assert len(resources) == 1
    assert isinstance(resources[0], Resource)
    resources[0].download(120)
    assert resources[0].hash.hexdigest() == expected_hash
135
bdfr/tests/site_downloaders/test_imgur.py
Normal file
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import Mock

import pytest

from bdfr.exceptions import SiteDownloaderError
from bdfr.resource import Resource
from bdfr.site_downloaders.imgur import Imgur


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_gen_dict', 'expected_image_dict'), (
    (
        'https://imgur.com/a/xWZsDDP',
        {'num_images': '1', 'id': 'xWZsDDP', 'hash': 'xWZsDDP'},
        [
            {'hash': 'ypa8YfS', 'title': '', 'ext': '.png', 'animated': False}
        ]
    ),
    (
        'https://imgur.com/gallery/IjJJdlC',
        {'num_images': 1, 'id': 384898055, 'hash': 'IjJJdlC'},
        [
            {'hash': 'CbbScDt',
             'description': 'watch when he gets it',
             'ext': '.gif',
             'animated': True,
             'has_sound': False
             }
        ],
    ),
    (
        'https://imgur.com/a/dcc84Gt',
        {'num_images': '4', 'id': 'dcc84Gt', 'hash': 'dcc84Gt'},
        [
            {'hash': 'ylx0Kle', 'ext': '.jpg', 'title': ''},
            {'hash': 'TdYfKbK', 'ext': '.jpg', 'title': ''},
            {'hash': 'pCxGbe8', 'ext': '.jpg', 'title': ''},
            {'hash': 'TSAkikk', 'ext': '.jpg', 'title': ''},
        ]
    ),
    (
        'https://m.imgur.com/a/py3RW0j',
        {'num_images': '1', 'id': 'py3RW0j', 'hash': 'py3RW0j', },
        [
            {'hash': 'K24eQmK', 'has_sound': False, 'ext': '.jpg'}
        ],
    ),
))
def test_get_data_album(test_url: str, expected_gen_dict: dict, expected_image_dict: list[dict]):
    result = Imgur._get_data(test_url)
    assert all([result.get(key) == expected_gen_dict[key] for key in expected_gen_dict.keys()])

    # Check that all the keys from the test dict are correct in at least one of the album entries
    assert any([all([image.get(key) == image_dict[key] for key in image_dict.keys()])
                for image_dict in expected_image_dict for image in result['album_images']['images']])


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_image_dict'), (
    (
        'https://i.imgur.com/dLk3FGY.gifv',
        {'hash': 'dLk3FGY', 'title': '', 'ext': '.mp4', 'animated': True}
    ),
    (
        'https://imgur.com/BuzvZwb.gifv',
        {
            'hash': 'BuzvZwb',
            'title': '',
            'description': 'Akron Glass Works',
            'animated': True,
            'mimetype': 'video/mp4'
        },
    ),
))
def test_get_data_gif(test_url: str, expected_image_dict: dict):
    result = Imgur._get_data(test_url)
    assert all([result.get(key) == expected_image_dict[key] for key in expected_image_dict.keys()])


@pytest.mark.parametrize('test_extension', (
    '.gif',
    '.png',
    '.jpg',
    '.mp4'
))
def test_imgur_extension_validation_good(test_extension: str):
    result = Imgur._validate_extension(test_extension)
    assert result == test_extension


@pytest.mark.parametrize('test_extension', (
    '.jpeg',
    'bad',
    '.avi',
    '.test',
    '.flac',
))
def test_imgur_extension_validation_bad(test_extension: str):
    with pytest.raises(SiteDownloaderError):
        Imgur._validate_extension(test_extension)


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hashes'), (
    (
        'https://imgur.com/a/xWZsDDP',
        ('f551d6e6b0fef2ce909767338612e31b',)
    ),
    (
        'https://imgur.com/gallery/IjJJdlC',
        ('7227d4312a9779b74302724a0cfa9081',),
    ),
    (
        'https://imgur.com/a/dcc84Gt',
        (
            'cf1158e1de5c3c8993461383b96610cf',
            '28d6b791a2daef8aa363bf5a3198535d',
            '248ef8f2a6d03eeb2a80d0123dbaf9b6',
            '029c475ce01b58fdf1269d8771d33913',
        ),
    ),
))
def test_find_resources(test_url: str, expected_hashes: list[str]):
    mock_download = Mock()
    mock_download.url = test_url
    downloader = Imgur(mock_download)
    results = downloader.find_resources()
    assert all([isinstance(res, Resource) for res in results])
    [res.download(120) for res in results]
    hashes = set([res.hash.hexdigest() for res in results])
    assert len(results) == len(expected_hashes)
    assert hashes == set(expected_hashes)
37
bdfr/tests/site_downloaders/test_redgifs.py
Normal file
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import Mock

import pytest

from bdfr.resource import Resource
from bdfr.site_downloaders.redgifs import Redgifs


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected'), (
    ('https://redgifs.com/watch/frighteningvictorioussalamander',
     'https://thumbs2.redgifs.com/FrighteningVictoriousSalamander.mp4'),
    ('https://redgifs.com/watch/springgreendecisivetaruca',
     'https://thumbs2.redgifs.com/SpringgreenDecisiveTaruca.mp4'),
))
def test_get_link(test_url: str, expected: str):
    result = Redgifs._get_link(test_url)
    assert result == expected


@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
    ('https://redgifs.com/watch/frighteningvictorioussalamander', '4007c35d9e1f4b67091b5f12cffda00a'),
    ('https://redgifs.com/watch/springgreendecisivetaruca', '8dac487ac49a1f18cc1b4dabe23f0869'),
))
def test_download_resource(test_url: str, expected_hash: str):
    mock_submission = Mock()
    mock_submission.url = test_url
    test_site = Redgifs(mock_submission)
    resources = test_site.find_resources()
    assert len(resources) == 1
    assert isinstance(resources[0], Resource)
    resources[0].download(120)
    assert resources[0].hash.hexdigest() == expected_hash
24
bdfr/tests/site_downloaders/test_self_post.py
Normal file
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# coding=utf-8

import praw
import pytest

from bdfr.resource import Resource
from bdfr.site_downloaders.self_post import SelfPost


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hash'), (
    ('ltmivt', '7d2c9e4e989e5cf2dca2e55a06b1c4f6'),
    ('ltoaan', '221606386b614d6780c2585a59bd333f'),
    ('d3sc8o', 'c1ff2b6bd3f6b91381dcd18dfc4ca35f'),
))
def test_find_resource(test_submission_id: str, expected_hash: str, reddit_instance: praw.Reddit):
    submission = reddit_instance.submission(id=test_submission_id)
    downloader = SelfPost(submission)
    results = downloader.find_resources()
    assert len(results) == 1
    assert isinstance(results[0], Resource)
    assert results[0].hash.hexdigest() == expected_hash
23
bdfr/tests/site_downloaders/test_vreddit.py
Normal file
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# coding=utf-8

import praw
import pytest

from bdfr.resource import Resource
from bdfr.site_downloaders.vreddit import VReddit


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_hash'), (
    ('lu8l8g', '93a15642d2f364ae39f00c6d1be354ff'),
))
def test_find_resources(test_submission_id: str, expected_hash: str, reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    downloader = VReddit(test_submission)
    resources = downloader.find_resources()
    assert len(resources) == 1
    assert isinstance(resources[0], Resource)
    resources[0].download(120)
    assert resources[0].hash.hexdigest() == expected_hash
26
bdfr/tests/site_downloaders/test_youtube.py
Normal file
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import MagicMock

import pytest

from bdfr.resource import Resource
from bdfr.site_downloaders.youtube import Youtube


@pytest.mark.online
@pytest.mark.slow
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
    ('https://www.youtube.com/watch?v=uSm2VDgRIUs', '3c79a62898028987f94161e0abccbddf'),
    ('https://www.youtube.com/watch?v=m-tKnjFwleU', '30314930d853afff8ebc7d8c36a5b833'),
))
def test_find_resources(test_url: str, expected_hash: str):
    test_submission = MagicMock()
    test_submission.url = test_url
    downloader = Youtube(test_submission)
    resources = downloader.find_resources()
    assert len(resources) == 1
    assert isinstance(resources[0], Resource)
    resources[0].download(120)
    assert resources[0].hash.hexdigest() == expected_hash
57
bdfr/tests/test_archiver.py
Normal file
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# coding=utf-8

from pathlib import Path
from unittest.mock import MagicMock

import praw
import pytest

from bdfr.archive_entry.submission_archive_entry import SubmissionArchiveEntry
from bdfr.archiver import Archiver

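# Each test calls the unbound Archiver method with a MagicMock standing in for
# 'self', so only the serialisation path is exercised; the assertion checks
# that the entry would have been written to disk.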
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
    'm3reby',
))
def test_write_submission_json(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
    archiver_mock = MagicMock()
    test_path = Path(tmp_path, 'test.json')
    test_submission = reddit_instance.submission(id=test_submission_id)
    archiver_mock.file_name_formatter.format_path.return_value = test_path
    test_entry = SubmissionArchiveEntry(test_submission)
    Archiver._write_entry_json(archiver_mock, test_entry)
    archiver_mock._write_content_to_disk.assert_called_once()


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
    'm3reby',
))
def test_write_submission_xml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
    archiver_mock = MagicMock()
    test_path = Path(tmp_path, 'test.xml')
    test_submission = reddit_instance.submission(id=test_submission_id)
    archiver_mock.file_name_formatter.format_path.return_value = test_path
    test_entry = SubmissionArchiveEntry(test_submission)
    Archiver._write_entry_xml(archiver_mock, test_entry)
    archiver_mock._write_content_to_disk.assert_called_once()


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
    'm3reby',
))
def test_write_submission_yaml(test_submission_id: str, tmp_path: Path, reddit_instance: praw.Reddit):
    archiver_mock = MagicMock()
    archiver_mock.download_directory = tmp_path
    test_path = Path(tmp_path, 'test.yaml')
    test_submission = reddit_instance.submission(id=test_submission_id)
    archiver_mock.file_name_formatter.format_path.return_value = test_path
    test_entry = SubmissionArchiveEntry(test_submission)
    Archiver._write_entry_yaml(archiver_mock, test_entry)
    archiver_mock._write_content_to_disk.assert_called_once()
24
bdfr/tests/test_configuration.py
Normal file
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# coding=utf-8

from unittest.mock import MagicMock

import pytest

from bdfr.configuration import Configuration


@pytest.mark.parametrize('arg_dict', (
    {'directory': 'test_dir'},
    {
        'directory': 'test_dir',
        'no_dupes': True,
    },
))
def test_process_click_context(arg_dict: dict):
    test_config = Configuration()
    test_context = MagicMock()
    test_context.params = arg_dict
    test_config.process_click_arguments(test_context)
    test_config = vars(test_config)
    assert all([test_config[arg] == arg_dict[arg] for arg in arg_dict.keys()])
58
bdfr/tests/test_download_filter.py
Normal file
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# coding=utf-8

import pytest

from bdfr.download_filter import DownloadFilter


@pytest.fixture()
def download_filter() -> DownloadFilter:
    return DownloadFilter(['mp4', 'mp3'], ['test.com', 'reddit.com'])


@pytest.mark.parametrize(('test_url', 'expected'), (
    ('test.mp4', False),
    ('test.avi', True),
    ('test.random.mp3', False),
))
def test_filter_extension(test_url: str, expected: bool, download_filter: DownloadFilter):
    result = download_filter._check_extension(test_url)
    assert result == expected


@pytest.mark.parametrize(('test_url', 'expected'), (
    ('test.mp4', True),
    ('http://reddit.com/test.mp4', False),
    ('http://reddit.com/test.gif', False),
    ('https://www.example.com/test.mp4', True),
    ('https://www.example.com/test.png', True),
))
def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter):
    result = download_filter._check_domain(test_url)
    assert result == expected


@pytest.mark.parametrize(('test_url', 'expected'), (
    ('test.mp4', False),
    ('test.gif', True),
    ('https://www.example.com/test.mp4', False),
    ('https://www.example.com/test.png', True),
    ('http://reddit.com/test.mp4', False),
    ('http://reddit.com/test.gif', False),
))
def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter):
    result = download_filter.check_url(test_url)
    assert result == expected


@pytest.mark.parametrize('test_url', (
    'test.mp3',
    'test.mp4',
    'http://reddit.com/test.mp4',
    't',
))
def test_filter_empty_filter(test_url: str):
    download_filter = DownloadFilter()
    result = download_filter.check_url(test_url)
    assert result is True
430
bdfr/tests/test_downloader.py
Normal file
@@ -0,0 +1,430 @@
#!/usr/bin/env python3
# coding=utf-8

import re
from pathlib import Path
from typing import Iterator
from unittest.mock import MagicMock

import praw
import praw.models
import pytest

from bdfr.__main__ import setup_logging
from bdfr.configuration import Configuration
from bdfr.download_filter import DownloadFilter
from bdfr.downloader import RedditDownloader, RedditTypes
from bdfr.exceptions import BulkDownloaderException
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.site_authenticator import SiteAuthenticator


@pytest.fixture()
def args() -> Configuration:
    args = Configuration()
    return args

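# The downloader fixture is a MagicMock with the real static helpers bound onto
# it, so RedditDownloader methods can be exercised without constructing a full
# downloader (which would require a working config and Reddit session).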
@pytest.fixture()
|
||||
def downloader_mock(args: Configuration):
|
||||
downloader_mock = MagicMock()
|
||||
downloader_mock.args = args
|
||||
downloader_mock._sanitise_subreddit_name = RedditDownloader._sanitise_subreddit_name
|
||||
downloader_mock._split_args_input = RedditDownloader._split_args_input
|
||||
downloader_mock.master_hash_list = {}
|
||||
return downloader_mock
|
||||
|
||||
|
||||
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]):
|
||||
results = [sub for res in results for sub in res]
|
||||
assert all([isinstance(res, praw.models.Submission) for res in results])
|
||||
if result_limit is not None:
|
||||
assert len(results) == result_limit
|
||||
return results
|
||||
|
||||
|
||||
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
|
||||
downloader_mock.args.directory = tmp_path / 'test'
|
||||
downloader_mock.config_directories.user_config_dir = tmp_path
|
||||
RedditDownloader._determine_directories(downloader_mock)
|
||||
assert Path(tmp_path / 'test').exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
|
||||
([], []),
|
||||
(['.test'], ['test.com'],),
|
||||
))
|
||||
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
|
||||
downloader_mock.args.skip = skip_extensions
|
||||
downloader_mock.args.skip_domain = skip_domains
|
||||
result = RedditDownloader._create_download_filter(downloader_mock)
|
||||
|
||||
assert isinstance(result, DownloadFilter)
|
||||
assert result.excluded_domains == skip_domains
|
||||
assert result.excluded_extensions == skip_extensions
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_time', 'expected'), (
|
||||
('all', 'all'),
|
||||
('hour', 'hour'),
|
||||
('day', 'day'),
|
||||
('week', 'week'),
|
||||
('random', 'all'),
|
||||
('', 'all'),
|
||||
))
|
||||
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
|
||||
downloader_mock.args.time = test_time
|
||||
result = RedditDownloader._create_time_filter(downloader_mock)
|
||||
|
||||
assert isinstance(result, RedditTypes.TimeType)
|
||||
assert result.name.lower() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_sort', 'expected'), (
|
||||
('', 'hot'),
|
||||
('hot', 'hot'),
|
||||
('controversial', 'controversial'),
|
||||
('new', 'new'),
|
||||
))
|
||||
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
|
||||
downloader_mock.args.sort = test_sort
|
||||
result = RedditDownloader._create_sort_filter(downloader_mock)
|
||||
|
||||
assert isinstance(result, RedditTypes.SortType)
|
||||
assert result.name.lower() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
|
||||
('{POSTID}', '{SUBREDDIT}'),
|
||||
('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
|
||||
('{POSTID}', 'test'),
|
||||
('{POSTID}', ''),
|
||||
('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
|
||||
))
|
||||
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
|
||||
downloader_mock.args.file_scheme = test_file_scheme
|
||||
downloader_mock.args.folder_scheme = test_folder_scheme
|
||||
result = RedditDownloader._create_file_name_formatter(downloader_mock)
|
||||
|
||||
assert isinstance(result, FileNameFormatter)
|
||||
assert result.file_format_string == test_file_scheme
|
||||
assert result.directory_format_string == test_folder_scheme.split('/')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
|
||||
('', ''),
|
||||
('', '{SUBREDDIT}'),
|
||||
('test', '{SUBREDDIT}'),
|
||||
))
|
||||
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
|
||||
downloader_mock.args.file_scheme = test_file_scheme
|
||||
downloader_mock.args.folder_scheme = test_folder_scheme
|
||||
with pytest.raises(BulkDownloaderException):
|
||||
RedditDownloader._create_file_name_formatter(downloader_mock)
|
||||
|
||||
|
||||
def test_create_authenticator(downloader_mock: MagicMock):
|
||||
result = RedditDownloader._create_authenticator(downloader_mock)
|
||||
assert isinstance(result, SiteAuthenticator)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize('test_submission_ids', (
|
||||
('lvpf4l',),
|
||||
('lvpf4l', 'lvqnsn'),
|
||||
('lvpf4l', 'lvqnsn', 'lvl9kd'),
|
||||
))
|
||||
def test_get_submissions_from_link(
|
||||
test_submission_ids: list[str],
|
||||
reddit_instance: praw.Reddit,
|
||||
downloader_mock: MagicMock):
|
||||
downloader_mock.args.link = test_submission_ids
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
results = RedditDownloader._get_submissions_from_link(downloader_mock)
|
||||
assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
|
||||
assert len(results[0]) == len(test_submission_ids)
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize(('test_subreddits', 'limit'), (
|
||||
(('Futurology',), 10),
|
||||
(('Futurology', 'Mindustry, Python'), 10),
|
||||
(('Futurology',), 20),
|
||||
(('Futurology', 'Python'), 10),
|
||||
(('Futurology',), 100),
|
||||
(('Futurology',), 0),
|
||||
))
|
||||
def test_get_subreddit_normal(
|
||||
test_subreddits: list[str],
|
||||
limit: int,
|
||||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||
test_subreddits = downloader_mock._split_args_input(test_subreddits)
|
||||
results = assert_all_results_are_submissions(
|
||||
(limit * len(test_subreddits)) if limit else None, results)
|
||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), (
|
||||
(('Python',), 'scraper', 10),
|
||||
(('Python',), '', 10),
|
||||
(('Python',), 'djsdsgewef', 0),
|
||||
))
|
||||
def test_get_subreddit_search(
|
||||
test_subreddits: list[str],
|
||||
search_term: str,
|
||||
limit: int,
|
||||
downloader_mock: MagicMock,
|
||||
reddit_instance: praw.Reddit):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.search = search_term
|
||||
downloader_mock.args.subreddit = test_subreddits
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
results = RedditDownloader._get_subreddits(downloader_mock)
|
||||
results = assert_all_results_are_submissions(
|
||||
(limit * len(test_subreddits)) if limit else None, results)
|
||||
assert all([res.subreddit.display_name in test_subreddits for res in results])
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
|
||||
('helen_darten', ('cuteanimalpics',), 10),
|
||||
('korfor', ('chess',), 100),
|
||||
))
|
||||
# Good sources at https://www.reddit.com/r/multihub/
|
||||
def test_get_multireddits_public(
|
||||
test_user: str,
|
||||
test_multireddits: list[str],
|
||||
limit: int,
|
||||
reddit_instance: praw.Reddit,
|
||||
downloader_mock: MagicMock):
|
||||
downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
|
||||
downloader_mock.sort_filter = RedditTypes.SortType.HOT
|
||||
downloader_mock.args.limit = limit
|
||||
downloader_mock.args.multireddit = test_multireddits
|
||||
downloader_mock.args.user = test_user
|
||||
downloader_mock.reddit_instance = reddit_instance
|
||||
    results = RedditDownloader._get_multireddits(downloader_mock)
    assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
    ('danigirl3694', 10),
    ('danigirl3694', 50),
    ('CapitanHam', None),
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
    downloader_mock.args.limit = limit
    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.args.submitted = True
    downloader_mock.args.user = test_user
    downloader_mock.authenticated = False
    downloader_mock.reddit_instance = reddit_instance
    results = RedditDownloader._get_user_data(downloader_mock)
    results = assert_all_results_are_submissions(limit, results)
    assert all([res.author.name == test_user for res in results])


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.parametrize('test_flag', (
    'upvoted',
    'saved',
))
def test_get_user_authenticated_lists(
        test_flag: str,
        downloader_mock: MagicMock,
        authenticated_reddit_instance: praw.Reddit,
):
    downloader_mock.args.__dict__[test_flag] = True
    downloader_mock.reddit_instance = authenticated_reddit_instance
    downloader_mock.args.user = 'me'
    downloader_mock.args.limit = 10
    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    RedditDownloader._resolve_user_name(downloader_mock)
    results = RedditDownloader._get_user_data(downloader_mock)
    assert_all_results_are_submissions(10, results)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), (
    ('ljyy27', 4),
))
def test_download_submission(
        test_submission_id: str,
        expected_files_len: int,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
        tmp_path: Path):
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.download_filter.check_url.return_value = True
    downloader_mock.args.folder_scheme = ''
    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
    downloader_mock.download_directory = tmp_path
    submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
    RedditDownloader._download_submission(downloader_mock, submission)
    folder_contents = list(tmp_path.iterdir())
    assert len(folder_contents) == expected_files_len


@pytest.mark.online
@pytest.mark.reddit
def test_download_submission_file_exists(
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
        tmp_path: Path,
        capsys: pytest.CaptureFixture
):
    setup_logging(3)
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.download_filter.check_url.return_value = True
    downloader_mock.args.folder_scheme = ''
    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
    downloader_mock.download_directory = tmp_path
    submission = downloader_mock.reddit_instance.submission(id='m1hqw6')
    Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch()
    RedditDownloader._download_submission(downloader_mock, submission)
    folder_contents = list(tmp_path.iterdir())
    output = capsys.readouterr()
    assert len(folder_contents) == 1
    assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png already exists' in output.out


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'test_hash'), (
    ('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'),
))
def test_download_submission_hash_exists(
        test_submission_id: str,
        test_hash: str,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
        tmp_path: Path,
        capsys: pytest.CaptureFixture
):
    setup_logging(3)
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.download_filter.check_url.return_value = True
    downloader_mock.args.folder_scheme = ''
    downloader_mock.args.no_dupes = True
    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
    downloader_mock.download_directory = tmp_path
    downloader_mock.master_hash_list = {test_hash: None}
    submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
    RedditDownloader._download_submission(downloader_mock, submission)
    folder_contents = list(tmp_path.iterdir())
    output = capsys.readouterr()
    assert len(folder_contents) == 0
    assert re.search(r'Resource hash .*? downloaded elsewhere', output.out)

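
The no-dupes path exercised above keys each download off a master hash list. A minimal sketch of the idea, assuming an MD5-keyed dict as the test setup suggests (names here are illustrative, not BDFR's internal API):

import hashlib


def is_duplicate(content: bytes, master_hash_list: dict) -> bool:
    # Identical bytes hash to the same digest, so a hit means the
    # resource was already downloaded elsewhere and can be skipped.
    digest = hashlib.md5(content).hexdigest()
    if digest in master_hash_list:
        return True
    master_hash_list[digest] = None  # record the first sighting
    return False
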

@pytest.mark.parametrize(('test_name', 'expected'), (
    ('Mindustry', 'Mindustry'),
    ('Futurology', 'Futurology'),
    ('r/Mindustry', 'Mindustry'),
    ('TrollXChromosomes', 'TrollXChromosomes'),
    ('r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
    ('https://www.reddit.com/r/Futurology', 'Futurology'),
))
def test_sanitise_subreddit_name(test_name: str, expected: str):
    result = RedditDownloader._sanitise_subreddit_name(test_name)
    assert result == expected

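
The sanitiser accepts bare names, 'r/'-prefixed names, and full URLs alike. A hedged sketch of one way to implement it; the regex is an assumption, not necessarily BDFR's exact pattern:

import re


def sanitise_subreddit_name(name: str) -> str:
    # Strip an optional reddit.com URL or 'r/' prefix and any trailing
    # slash, leaving only the subreddit name itself.
    match = re.search(r'(?:reddit\.com/r/|r/)?([\w\-]+)/?$', name)
    if match is None:
        raise ValueError(f'Could not parse a subreddit name from {name!r}')
    return match.group(1)
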

def test_search_existing_files():
    results = RedditDownloader.scan_existing_files(Path('.'))
    assert len(results.keys()) >= 40


@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
    (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'})
))
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
    results = RedditDownloader._split_args_input(test_subreddit_entries)
    assert results == expected

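
As the cases above show, _split_args_input flattens comma- and semicolon-separated CLI values into a single set. A minimal re-implementation sketch:

import re


def split_args_input(entries: list[str]) -> set[str]:
    results = set()
    for entry in entries:
        # Split on commas and semicolons, tolerating surrounding whitespace.
        results.update(part.strip() for part in re.split(r'[,;]', entry) if part.strip())
    return results
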

@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_id', (
    'm1hqw6',
))
def test_mark_hard_link(
        test_submission_id: str,
        downloader_mock: MagicMock,
        tmp_path: Path,
        reddit_instance: praw.Reddit
):
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.args.make_hard_links = True
    downloader_mock.download_directory = tmp_path
    downloader_mock.args.folder_scheme = ''
    downloader_mock.args.file_scheme = '{POSTID}'
    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
    submission = downloader_mock.reddit_instance.submission(id=test_submission_id)
    original = Path(tmp_path, f'{test_submission_id}.png')

    RedditDownloader._download_submission(downloader_mock, submission)
    assert original.exists()

    downloader_mock.args.file_scheme = 'test2_{POSTID}'
    downloader_mock.file_name_formatter = RedditDownloader._create_file_name_formatter(downloader_mock)
    RedditDownloader._download_submission(downloader_mock, submission)
    test_file_1_stats = original.stat()
    test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino

    assert test_file_1_stats.st_nlink == 2
    assert test_file_1_stats.st_ino == test_file_2_inode

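
The hard-link assertions rest on POSIX stat semantics: two directory entries for the same inode share st_ino, and each extra entry bumps st_nlink. A standalone illustration, assuming a filesystem that supports hard links:

import os
from pathlib import Path


def demonstrate_hard_link(directory: Path):
    original = directory / 'first.bin'
    original.write_bytes(b'payload')
    link = directory / 'second.bin'
    os.link(original, link)  # a second name for the same inode
    assert original.stat().st_ino == link.stat().st_ino
    assert original.stat().st_nlink == 2
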

@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), (
    (('aaaaaa',), (), 1),
    (('aaaaaa',), ('aaaaaa',), 0),
    ((), ('aaaaaa',), 0),
    (('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1),
))
def test_excluded_ids(test_ids: tuple[str], test_excluded: tuple[str], expected_len: int, downloader_mock: MagicMock):
    downloader_mock.excluded_submission_ids = test_excluded
    test_submissions = []
    for test_id in test_ids:
        m = MagicMock()
        m.id = test_id
        test_submissions.append(m)
    downloader_mock.reddit_lists = [test_submissions]
    RedditDownloader.download(downloader_mock)
    assert downloader_mock._download_submission.call_count == expected_len

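
These cases drive the ID filter in the main download loop. In outline, with hypothetical names (the real loop lives inside RedditDownloader.download):

def download_all(submission_lists, excluded_ids, download_one):
    for submissions in submission_lists:
        for submission in submissions:
            if submission.id in excluded_ids:
                continue  # skip anything on the exclusion list
            download_one(submission)
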

def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
    test_file = tmp_path / 'test.txt'
    test_file.write_text('aaaaaa\nbbbbbb')
    downloader_mock.args.exclude_id_file = [test_file]
    results = RedditDownloader._read_excluded_ids(downloader_mock)
    assert results == {'aaaaaa', 'bbbbbb'}
307
bdfr/tests/test_file_name_formatter.py
Normal file
@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# coding=utf-8

from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock

import praw.models
import pytest

from bdfr.file_name_formatter import FileNameFormatter
from bdfr.resource import Resource


@pytest.fixture()
def submission() -> MagicMock:
    test = MagicMock()
    test.title = 'name'
    test.subreddit.display_name = 'randomreddit'
    test.author.name = 'person'
    test.id = '12345'
    test.score = 1000
    test.link_flair_text = 'test_flair'
    test.created_utc = 123456789
    test.__class__ = praw.models.Submission
    return test


@pytest.fixture(scope='session')
def reddit_submission(reddit_instance: praw.Reddit) -> praw.models.Submission:
    return reddit_instance.submission(id='lgilgt')


@pytest.mark.parametrize(('format_string', 'expected'), (
    ('{SUBREDDIT}', 'randomreddit'),
    ('{REDDITOR}', 'person'),
    ('{POSTID}', '12345'),
    ('{UPVOTES}', '1000'),
    ('{FLAIR}', 'test_flair'),
    ('{DATE}', '123456789'),
    ('{REDDITOR}_{TITLE}_{POSTID}', 'person_name_12345'),
    ('{RANDOM}', '{RANDOM}'),
))
def test_format_name_mock(format_string: str, expected: str, submission: MagicMock):
    result = FileNameFormatter._format_name(submission, format_string)
    assert result == expected

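
The {KEY} tokens behave like a small template language over submission attributes, with unrecognised tokens such as {RANDOM} passed through untouched. A hedged sketch of that substitution logic, working over a plain dict rather than a praw object:

def format_name(attributes: dict, format_string: str) -> str:
    # Replace each known {KEY} token with its value; unknown tokens survive.
    result = format_string
    for key, value in attributes.items():
        result = result.replace('{' + key.upper() + '}', str(value))
    return result


# e.g. format_name({'redditor': 'person', 'postid': '12345'}, '{REDDITOR}_{POSTID}')
# returns 'person_12345', while '{RANDOM}' comes back unchanged.
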

@pytest.mark.parametrize(('test_string', 'expected'), (
    ('', False),
    ('test', False),
    ('{POSTID}', True),
    ('POSTID', False),
    ('{POSTID}_test', True),
    ('test_{TITLE}', True),
    ('TITLE_POSTID', False),
))
def test_check_format_string_validity(test_string: str, expected: bool):
    result = FileNameFormatter.validate_string(test_string)
    assert result == expected

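
validate_string evidently requires at least one recognised {KEY} token before a scheme is accepted. One plausible check; the token list here is an assumption inferred from the keys used in these tests:

import re

KNOWN_KEYS = {'TITLE', 'SUBREDDIT', 'REDDITOR', 'POSTID', 'UPVOTES', 'FLAIR', 'DATE'}


def validate_string(test_string: str) -> bool:
    # Valid only when at least one {KEY} token names a known attribute.
    tokens = re.findall(r'{(\w+)}', test_string)
    return any(token in KNOWN_KEYS for token in tokens)
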

@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string', 'expected'), (
    ('{SUBREDDIT}', 'Mindustry'),
    ('{REDDITOR}', 'Gamer_player_boi'),
    ('{POSTID}', 'lgilgt'),
    ('{FLAIR}', 'Art'),
    ('{SUBREDDIT}_{TITLE}', 'Mindustry_Toxopid that is NOT humane >:('),
    ('{REDDITOR}_{TITLE}_{POSTID}', 'Gamer_player_boi_Toxopid that is NOT humane >:(_lgilgt')
))
def test_format_name_real(format_string: str, expected: str, reddit_submission: praw.models.Submission):
    result = FileNameFormatter._format_name(reddit_submission, format_string)
    assert result == expected


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'expected'), (
    (
        '{SUBREDDIT}',
        '{POSTID}',
        'test/Mindustry/lgilgt.png',
    ),
    (
        '{SUBREDDIT}',
        '{TITLE}_{POSTID}',
        'test/Mindustry/Toxopid that is NOT humane >:(_lgilgt.png',
    ),
    (
        '{SUBREDDIT}',
        '{REDDITOR}_{TITLE}_{POSTID}',
        'test/Mindustry/Gamer_player_boi_Toxopid that is NOT humane >:(_lgilgt.png',
    ),
))
def test_format_full(
        format_string_directory: str,
        format_string_file: str,
        expected: str,
        reddit_submission: praw.models.Submission):
    test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
    test_formatter = FileNameFormatter(format_string_file, format_string_directory)
    result = test_formatter.format_path(test_resource, Path('test'))
    assert str(result) == expected


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file'), (
    ('{SUBREDDIT}', '{POSTID}'),
    ('{SUBREDDIT}', '{UPVOTES}'),
    ('{SUBREDDIT}', '{UPVOTES}{POSTID}'),
))
def test_format_full_conform(
        format_string_directory: str,
        format_string_file: str,
        reddit_submission: praw.models.Submission):
    test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
    test_formatter = FileNameFormatter(format_string_file, format_string_directory)
    test_formatter.format_path(test_resource, Path('test'))


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'index', 'expected'), (
    ('{SUBREDDIT}', '{POSTID}', None, 'test/Mindustry/lgilgt.png'),
    ('{SUBREDDIT}', '{POSTID}', 1, 'test/Mindustry/lgilgt_1.png'),
    ('{SUBREDDIT}', '{POSTID}', 2, 'test/Mindustry/lgilgt_2.png'),
    ('{SUBREDDIT}', '{TITLE}_{POSTID}', 2, 'test/Mindustry/Toxopid that is NOT humane >:(_lgilgt_2.png'),
))
def test_format_full_with_index_suffix(
        format_string_directory: str,
        format_string_file: str,
        index: Optional[int],
        expected: str,
        reddit_submission: praw.models.Submission):
    test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png')
    test_formatter = FileNameFormatter(format_string_file, format_string_directory)
    result = test_formatter.format_path(test_resource, Path('test'), index)
    assert str(result) == expected


def test_format_multiple_resources():
    mocks = []
    for i in range(1, 5):
        new_mock = MagicMock()
        new_mock.url = 'https://example.com/test.png'
        new_mock.extension = '.png'
        new_mock.source_submission.title = 'test'
        new_mock.source_submission.__class__ = praw.models.Submission
        mocks.append(new_mock)
    test_formatter = FileNameFormatter('{TITLE}', '')
    results = test_formatter.format_resource_paths(mocks, Path('.'))
    results = set([str(res[0]) for res in results])
    assert results == {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'}


@pytest.mark.parametrize(('test_filename', 'test_ending'), (
    ('A' * 300, '.png'),
    ('A' * 300, '_1.png'),
    ('a' * 300, '_1000.jpeg'),
    ('😍💕✨' * 100, '_1.png'),
))
def test_limit_filename_length(test_filename: str, test_ending: str):
    result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
    assert len(result) <= 255
    assert len(result.encode('utf-8')) <= 255
    assert isinstance(result, str)

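
Both length assertions matter: common filesystems such as ext4 cap a file name at 255 bytes, not 255 characters, and multi-byte emoji make the two counts diverge. The next test additionally pins down that a trailing '_<postid>' marker survives shortening. A sketch of byte-aware truncation consistent with both tests; the six-character ID heuristic is an assumption, not necessarily BDFR's rule:

import re


def limit_file_name_length(filename: str, ending: str, max_bytes: int = 255) -> str:
    # Preserve a trailing '_<postid>' marker (if any) plus the ending;
    # trim the rest of the stem, whole characters at a time, until the
    # name fits in max_bytes of UTF-8.
    match = re.search(r'_\w{6}$', filename)
    suffix = match.group(0) if match else ''
    stem = filename[:len(filename) - len(suffix)]
    while len((stem + suffix + ending).encode('utf-8')) > max_bytes:
        stem = stem[:-1]
    return stem + suffix + ending
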

@pytest.mark.parametrize(('test_filename', 'test_ending', 'expected_end'), (
    ('test_aaaaaa', '_1.png', 'test_aaaaaa_1.png'),
    ('test_aataaa', '_1.png', 'test_aataaa_1.png'),
    ('test_abcdef', '_1.png', 'test_abcdef_1.png'),
    ('test_aaaaaa', '.png', 'test_aaaaaa.png'),
    ('test', '_1.png', 'test_1.png'),
    ('test_m1hqw6', '_1.png', 'test_m1hqw6_1.png'),
    ('A' * 300 + '_bbbccc', '.png', '_bbbccc.png'),
    ('A' * 300 + '_bbbccc', '_1000.jpeg', '_bbbccc_1000.jpeg'),
    ('😍💕✨' * 100 + '_aaa1aa', '_1.png', '_aaa1aa_1.png'),
))
def test_preserve_id_append_when_shortening(test_filename: str, test_ending: str, expected_end: str):
    result = FileNameFormatter._limit_file_name_length(test_filename, test_ending)
    assert len(result) <= 255
    assert len(result.encode('utf-8')) <= 255
    assert isinstance(result, str)
    assert result.endswith(expected_end)


def test_shorten_filenames(submission: MagicMock, tmp_path: Path):
    submission.title = 'A' * 300
    submission.author.name = 'test'
    submission.subreddit.display_name = 'test'
    submission.id = 'BBBBBB'
    test_resource = Resource(submission, 'www.example.com/empty', '.jpeg')
    test_formatter = FileNameFormatter('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}')
    result = test_formatter.format_path(test_resource, tmp_path)
    result.parent.mkdir(parents=True)
    result.touch()


@pytest.mark.parametrize(('test_string', 'expected'), (
    ('test', 'test'),
    ('test😍', 'test'),
    ('test.png', 'test.png'),
    ('test*', 'test'),
    ('test**', 'test'),
    ('test?*', 'test'),
    ('test_???.png', 'test_.png'),
    ('test_???😍.png', 'test_.png'),
))
def test_format_file_name_for_windows(test_string: str, expected: str):
    result = FileNameFormatter._format_for_windows(test_string)
    assert result == expected


@pytest.mark.parametrize(('test_string', 'expected'), (
    ('test', 'test'),
    ('test😍', 'test'),
    ('😍', ''),
))
def test_strip_emojis(test_string: str, expected: str):
    result = FileNameFormatter._strip_emojis(test_string)
    assert result == expected

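
A common way to get exactly the behaviour pinned down above is to drop code points outside the Basic Multilingual Plane, where emoji live; whether BDFR uses this rule or a Unicode-category check is an implementation detail, so treat this as illustrative:

def strip_emojis(text: str) -> str:
    # Keep BMP characters; drop astral-plane code points such as emoji.
    return ''.join(char for char in text if ord(char) <= 0xFFFF)
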

@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'expected'), (
    ('mfuteh', {
        'title': 'Why Do Interviewers Ask Linked List Questions?',
        'redditor': 'mjgardner',
    }),
))
def test_generate_dict_for_submission(test_submission_id: str, expected: dict, reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    result = FileNameFormatter._generate_name_dict_from_submission(test_submission)
    assert all([result.get(key) == expected[key] for key in expected.keys()])


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_comment_id', 'expected'), (
    ('gsq0yuw', {
        'title': 'Why Do Interviewers Ask Linked List Questions?',
        'redditor': 'Doctor-Dapper',
        'postid': 'gsq0yuw',
        'flair': '',
    }),
))
def test_generate_dict_for_comment(test_comment_id: str, expected: dict, reddit_instance: praw.Reddit):
    test_comment = reddit_instance.comment(id=test_comment_id)
    result = FileNameFormatter._generate_name_dict_from_comment(test_comment)
    assert all([result.get(key) == expected[key] for key in expected.keys()])


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme', 'test_comment_id', 'expected_name'), (
    ('{POSTID}', '', 'gsoubde', 'gsoubde.json'),
    ('{REDDITOR}_{POSTID}', '', 'gsoubde', 'DELETED_gsoubde.json'),
))
def test_format_archive_entry_comment(
        test_file_scheme: str,
        test_folder_scheme: str,
        test_comment_id: str,
        expected_name: str,
        tmp_path: Path,
        reddit_instance: praw.Reddit,
):
    test_comment = reddit_instance.comment(id=test_comment_id)
    test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme)
    test_entry = Resource(test_comment, '', '.json')
    result = test_formatter.format_path(test_entry, tmp_path)
    assert result.name == expected_name


@pytest.mark.parametrize(('test_folder_scheme', 'expected'), (
    ('{REDDITOR}/{SUBREDDIT}', 'person/randomreddit'),
    ('{POSTID}/{SUBREDDIT}/{REDDITOR}', '12345/randomreddit/person'),
))
def test_multilevel_folder_scheme(
        test_folder_scheme: str,
        expected: str,
        tmp_path: Path,
        submission: MagicMock,
):
    test_formatter = FileNameFormatter('{POSTID}', test_folder_scheme)
    test_resource = MagicMock()
    test_resource.source_submission = submission
    test_resource.extension = '.png'
    result = test_formatter.format_path(test_resource, tmp_path)
    result = result.relative_to(tmp_path)
    assert str(result.parent) == expected
    assert len(result.parents) == (len(expected.split('/')) + 1)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_submission_id', 'test_file_scheme', 'expected'), (
    ('mecwk7', '{TITLE}', 'My cat’s paws are so cute'),  # Unicode escape in title
))
def test_edge_case_names(test_submission_id: str, test_file_scheme: str, expected: str, reddit_instance: praw.Reddit):
    test_submission = reddit_instance.submission(id=test_submission_id)
    result = FileNameFormatter._format_name(test_submission, test_file_scheme)
    assert result == expected
299
bdfr/tests/test_integration.py
Normal file
@@ -0,0 +1,299 @@
#!/usr/bin/env python3
# coding=utf-8

import re
from pathlib import Path

import pytest
from click.testing import CliRunner

from bdfr.__main__ import cli


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['-s', 'Mindustry', '-L', 1],
    ['-s', 'r/Mindustry', '-L', 1],
    ['-s', 'r/mindustry', '-L', 1],
    ['-s', 'mindustry', '-L', 1],
    ['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 1],
    ['-s', 'r/TrollXChromosomes/', '-L', 1],
    ['-s', 'TrollXChromosomes/', '-L', 1],
    ['-s', 'trollxchromosomes', '-L', 1],
    ['-s', 'trollxchromosomes,mindustry,python', '-L', 1],
    ['-s', 'trollxchromosomes, mindustry, python', '-L', 1],
    ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day'],
    ['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new'],
    ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new'],
    ['-s', 'trollxchromosomes', '-L', 1, '--search', 'women'],
    ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--search', 'women'],
    ['-s', 'trollxchromosomes', '-L', 1, '--sort', 'new', '--search', 'women'],
    ['-s', 'trollxchromosomes', '-L', 1, '--time', 'day', '--sort', 'new', '--search', 'women'],
))
def test_cli_download_subreddits(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'Added submissions from subreddit ' in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['-l', 'm2601g'],
    ['-l', 'https://www.reddit.com/r/TrollXChromosomes/comments/m2601g/its_a_step_in_the_right_direction/'],
    ['-l', 'm3hxzd'],  # Really long title used to overflow filename limit
    ['-l', 'm3kua3'],  # Has a deleted user
    ['-l', 'm5bqkf'],  # Resource leading to a 404
))
def test_cli_download_links(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10],
    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--sort', 'rising'],
    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week'],
    ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week', '--sort', 'rising'],
))
def test_cli_download_multireddit(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'Added submissions from multireddit ' in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--user', 'helen_darten', '-m', 'xxyyzzqwerty', '-L', 10],
))
def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'Failed to get submissions for multireddit' in result.output
    assert 'received 404 HTTP response' in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--user', 'me', '--upvoted', '--authenticate', '-L', 10],
    ['--user', 'me', '--saved', '--authenticate', '-L', 10],
    ['--user', 'me', '--submitted', '--authenticate', '-L', 10],
    ['--user', 'djnish', '--submitted', '-L', 10],
    ['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'],
    ['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'],
    ['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial', '--time', 'month'],
))
def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'Downloaded submission ' in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--user', 'me', '-L', 10, '--folder-scheme', ''],
))
def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'To use "me" as a user, an authenticated Reddit instance must be used' in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--subreddit', 'python', '-L', 10, '--search-existing'],
))
def test_cli_download_search_existing(test_args: list[str], tmp_path: Path):
    Path(tmp_path, 'test.txt').touch()
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'Calculating hashes for' in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--subreddit', 'tumblr', '-L', '25', '--skip', 'png', '--skip', 'jpg'],
))
def test_cli_download_download_filters(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'Download filter removed submission' in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--subreddit', 'all', '-L', '100', '--sort', 'new'],
))
def test_cli_download_long(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['-l', 'gstd4hk'],
    ['-l', 'm2601g'],
))
def test_cli_archive_single(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert re.search(r'Writing entry .*? to file in .*? format', result.output)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--subreddit', 'Mindustry', '-L', 25],
    ['--subreddit', 'Mindustry', '-L', 25, '--format', 'xml'],
    ['--subreddit', 'Mindustry', '-L', 25, '--format', 'yaml'],
    ['--subreddit', 'Mindustry', '-L', 25, '--sort', 'new'],
    ['--subreddit', 'Mindustry', '-L', 25, '--time', 'day'],
    ['--subreddit', 'Mindustry', '-L', 25, '--time', 'day', '--sort', 'new'],
))
def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert re.search(r'Writing entry .*? to file in .*? format', result.output)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--user', 'me', '--authenticate', '--all-comments', '-L', '10'],
))
def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--subreddit', 'all', '-L', 100],
    ['--subreddit', 'all', '-L', 100, '--sort', 'new'],
))
def test_cli_archive_long(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['archive', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert re.search(r'Writing entry .*? to file in .*? format', result.output)


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--user', 'sdclhgsolgjeroij', '--submitted', '-L', 10],
    ['--user', 'me', '--upvoted', '-L', 10],
    ['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10],
))
def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.slow
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--time', 'random'],
    ['--sort', 'random'],
))
def test_cli_download_hard_fail(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code != 0


def test_cli_download_use_default_config(tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', '-vv', str(tmp_path)]
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['-l', 'm2601g', '--exclude-id', 'm2601g'],
))
def test_cli_download_links_exclusion(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'in exclusion list' in result.output
    assert 'Downloaded submission ' not in result.output


@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.skipif(not Path('test_config.cfg').exists(), reason='A test config file is required for integration tests')
@pytest.mark.parametrize('test_args', (
    ['--file-scheme', '{TITLE}'],
    ['--file-scheme', '{TITLE}_test_{SUBREDDIT}'],
))
def test_cli_file_scheme_warning(test_args: list[str], tmp_path: Path):
    runner = CliRunner()
    test_args = ['download', str(tmp_path), '-v', '--config', 'test_config.cfg'] + test_args
    result = runner.invoke(cli, test_args)
    assert result.exit_code == 0
    assert 'Some files might not be downloaded due to name conflicts' in result.output
71
bdfr/tests/test_oauth2.py
Normal file
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# coding=utf-8

import configparser
from pathlib import Path
from unittest.mock import MagicMock

import pytest

from bdfr.exceptions import BulkDownloaderException
from bdfr.oauth2 import OAuth2Authenticator, OAuth2TokenManager


@pytest.fixture()
def example_config() -> configparser.ConfigParser:
    out = configparser.ConfigParser()
    config_dict = {'DEFAULT': {'user_token': 'example'}}
    out.read_dict(config_dict)
    return out


@pytest.mark.online
@pytest.mark.parametrize('test_scopes', (
    {'history', },
    {'history', 'creddits'},
    {'account', 'flair'},
    {'*', },
))
def test_check_scopes(test_scopes: set[str]):
    OAuth2Authenticator._check_scopes(test_scopes)


@pytest.mark.parametrize(('test_scopes', 'expected'), (
    ('history', {'history', }),
    ('history creddits', {'history', 'creddits'}),
    ('history, creddits, account', {'history', 'creddits', 'account'}),
    ('history,creddits,account,flair', {'history', 'creddits', 'account', 'flair'}),
))
def test_split_scopes(test_scopes: str, expected: set[str]):
    result = OAuth2Authenticator.split_scopes(test_scopes)
    assert result == expected

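
split_scopes accepts scope lists separated by spaces, commas, or both. A minimal equivalent:

import re


def split_scopes(scopes: str) -> set[str]:
    # Treat any run of commas and/or spaces as a single separator.
    return {scope for scope in re.split(r'[, ]+', scopes) if scope}
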

@pytest.mark.online
@pytest.mark.parametrize('test_scopes', (
    {'random', },
    {'scope', 'another_scope'},
))
def test_check_scopes_bad(test_scopes: set[str]):
    with pytest.raises(BulkDownloaderException):
        OAuth2Authenticator._check_scopes(test_scopes)


def test_token_manager_read(example_config: configparser.ConfigParser):
    mock_authoriser = MagicMock()
    mock_authoriser.refresh_token = None
    test_manager = OAuth2TokenManager(example_config, MagicMock())
    test_manager.pre_refresh_callback(mock_authoriser)
    assert mock_authoriser.refresh_token == example_config.get('DEFAULT', 'user_token')


def test_token_manager_write(example_config: configparser.ConfigParser, tmp_path: Path):
    test_path = tmp_path / 'test.cfg'
    mock_authoriser = MagicMock()
    mock_authoriser.refresh_token = 'changed_token'
    test_manager = OAuth2TokenManager(example_config, test_path)
    test_manager.post_refresh_callback(mock_authoriser)
    assert example_config.get('DEFAULT', 'user_token') == 'changed_token'
    with open(test_path, 'r') as file:
        file_contents = file.read()
    assert 'user_token = changed_token' in file_contents
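
The two token-manager tests pin down the praw token-manager callback contract: pre_refresh seeds a missing refresh token from the config, post_refresh persists a changed one. A hedged sketch of such a manager (illustrative only; BDFR's OAuth2TokenManager subclasses praw's BaseTokenManager):

import configparser


class SketchTokenManager:
    def __init__(self, config: configparser.ConfigParser, config_path):
        self.config = config
        self.config_path = config_path

    def pre_refresh_callback(self, authorizer):
        # Seed the authorizer from the stored token when none is set yet.
        if authorizer.refresh_token is None:
            authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')

    def post_refresh_callback(self, authorizer):
        # Persist the (possibly rotated) token back to the config file.
        self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
        with open(self.config_path, 'w') as file:
            self.config.write(file)
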
32
bdfr/tests/test_resource.py
Normal file
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# coding=utf-8

import pytest
from unittest.mock import MagicMock

from bdfr.resource import Resource


@pytest.mark.parametrize(('test_url', 'expected'), (
    ('test.png', '.png'),
    ('another.mp4', '.mp4'),
    ('test.jpeg', '.jpeg'),
    ('http://www.random.com/resource.png', '.png'),
    ('https://www.resource.com/test/example.jpg', '.jpg'),
    ('hard.png.mp4', '.mp4'),
    ('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'),
))
def test_resource_get_extension(test_url: str, expected: str):
    test_resource = Resource(MagicMock(), test_url)
    result = test_resource._determine_extension()
    assert result == expected

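
Extension detection above keeps only the last suffix and ignores query strings. A sketch using the standard library that reproduces the observable behaviour; the real implementation may differ:

import os
from urllib.parse import urlparse


def determine_extension(url: str) -> str:
    path = urlparse(url).path          # drops '?width=237&...' query strings
    return os.path.splitext(path)[1]   # 'hard.png.mp4' -> '.mp4'
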

@pytest.mark.online
@pytest.mark.parametrize(('test_url', 'expected_hash'), (
    ('https://www.iana.org/_img/2013.1/iana-logo-header.svg', '426b3ac01d3584c820f3b7f5985d6623'),
))
def test_download_online_resource(test_url: str, expected_hash: str):
    test_resource = Resource(MagicMock(), test_url)
    test_resource.download(120)
    assert test_resource.hash.hexdigest() == expected_hash