Format according to the black standard

Author: Serene-Arc
Date: 2022-12-03 15:11:17 +10:00
parent 96cd7d7147
commit 0873a4a2b2
60 changed files with 2160 additions and 1790 deletions
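
For context, black normalizes string quotes to double quotes and re-wraps calls that exceed the configured line length, which is the dominant pattern in the hunks below. The following is a minimal, hypothetical sketch of that transformation using black's Python API; the sample line and the line length of 120 are assumptions inferred from the reformatted code, not taken from the repository's configuration.

import black

# Hypothetical sample mirroring the click.option declarations in the diff.
sample = "click.option('--no-dupes', is_flag=True, default=None)\n"

# line_length=120 is an assumption; black's default is 88.
formatted = black.format_str(sample, mode=black.Mode(line_length=120))
print(formatted)
# -> click.option("--no-dupes", is_flag=True, default=None)

On the command line, the equivalent would be something along the lines of: black --line-length 120 bdfr/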

@@ -13,53 +13,54 @@ from bdfr.downloader import RedditDownloader
logger = logging.getLogger()
_common_options = [
click.argument('directory', type=str),
click.option('--authenticate', is_flag=True, default=None),
click.option('--config', type=str, default=None),
click.option('--opts', type=str, default=None),
click.option('--disable-module', multiple=True, default=None, type=str),
click.option('--exclude-id', default=None, multiple=True),
click.option('--exclude-id-file', default=None, multiple=True),
click.option('--file-scheme', default=None, type=str),
click.option('--folder-scheme', default=None, type=str),
click.option('--ignore-user', type=str, multiple=True, default=None),
click.option('--include-id-file', multiple=True, default=None),
click.option('--log', type=str, default=None),
click.option('--saved', is_flag=True, default=None),
click.option('--search', default=None, type=str),
click.option('--submitted', is_flag=True, default=None),
click.option('--subscribed', is_flag=True, default=None),
click.option('--time-format', type=str, default=None),
click.option('--upvoted', is_flag=True, default=None),
click.option('-L', '--limit', default=None, type=int),
click.option('-l', '--link', multiple=True, default=None, type=str),
click.option('-m', '--multireddit', multiple=True, default=None, type=str),
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')),
default=None),
click.option('-s', '--subreddit', multiple=True, default=None, type=str),
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
click.option('-u', '--user', type=str, multiple=True, default=None),
click.option('-v', '--verbose', default=None, count=True),
click.argument("directory", type=str),
click.option("--authenticate", is_flag=True, default=None),
click.option("--config", type=str, default=None),
click.option("--opts", type=str, default=None),
click.option("--disable-module", multiple=True, default=None, type=str),
click.option("--exclude-id", default=None, multiple=True),
click.option("--exclude-id-file", default=None, multiple=True),
click.option("--file-scheme", default=None, type=str),
click.option("--folder-scheme", default=None, type=str),
click.option("--ignore-user", type=str, multiple=True, default=None),
click.option("--include-id-file", multiple=True, default=None),
click.option("--log", type=str, default=None),
click.option("--saved", is_flag=True, default=None),
click.option("--search", default=None, type=str),
click.option("--submitted", is_flag=True, default=None),
click.option("--subscribed", is_flag=True, default=None),
click.option("--time-format", type=str, default=None),
click.option("--upvoted", is_flag=True, default=None),
click.option("-L", "--limit", default=None, type=int),
click.option("-l", "--link", multiple=True, default=None, type=str),
click.option("-m", "--multireddit", multiple=True, default=None, type=str),
click.option(
"-S", "--sort", type=click.Choice(("hot", "top", "new", "controversial", "rising", "relevance")), default=None
),
click.option("-s", "--subreddit", multiple=True, default=None, type=str),
click.option("-t", "--time", type=click.Choice(("all", "hour", "day", "week", "month", "year")), default=None),
click.option("-u", "--user", type=str, multiple=True, default=None),
click.option("-v", "--verbose", default=None, count=True),
]
_downloader_options = [
click.option('--make-hard-links', is_flag=True, default=None),
click.option('--max-wait-time', type=int, default=None),
click.option('--no-dupes', is_flag=True, default=None),
click.option('--search-existing', is_flag=True, default=None),
click.option('--skip', default=None, multiple=True),
click.option('--skip-domain', default=None, multiple=True),
click.option('--skip-subreddit', default=None, multiple=True),
click.option('--min-score', type=int, default=None),
click.option('--max-score', type=int, default=None),
click.option('--min-score-ratio', type=float, default=None),
click.option('--max-score-ratio', type=float, default=None),
click.option("--make-hard-links", is_flag=True, default=None),
click.option("--max-wait-time", type=int, default=None),
click.option("--no-dupes", is_flag=True, default=None),
click.option("--search-existing", is_flag=True, default=None),
click.option("--skip", default=None, multiple=True),
click.option("--skip-domain", default=None, multiple=True),
click.option("--skip-subreddit", default=None, multiple=True),
click.option("--min-score", type=int, default=None),
click.option("--max-score", type=int, default=None),
click.option("--min-score-ratio", type=float, default=None),
click.option("--max-score-ratio", type=float, default=None),
]
_archiver_options = [
click.option('--all-comments', is_flag=True, default=None),
click.option('--comment-context', is_flag=True, default=None),
click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None),
click.option("--all-comments", is_flag=True, default=None),
click.option("--comment-context", is_flag=True, default=None),
click.option("-f", "--format", type=click.Choice(("xml", "json", "yaml")), default=None),
]
@@ -68,6 +69,7 @@ def _add_options(opts: list):
for opt in opts:
func = opt(func)
return func
return wrap
@@ -76,7 +78,7 @@ def cli():
pass
@cli.command('download')
@cli.command("download")
@_add_options(_common_options)
@_add_options(_downloader_options)
@click.pass_context
@@ -88,13 +90,13 @@ def cli_download(context: click.Context, **_):
reddit_downloader = RedditDownloader(config)
reddit_downloader.download()
except Exception:
logger.exception('Downloader exited unexpectedly')
logger.exception("Downloader exited unexpectedly")
raise
else:
logger.info('Program complete')
logger.info("Program complete")
@cli.command('archive')
@cli.command("archive")
@_add_options(_common_options)
@_add_options(_archiver_options)
@click.pass_context
@@ -106,13 +108,13 @@ def cli_archive(context: click.Context, **_):
reddit_archiver = Archiver(config)
reddit_archiver.download()
except Exception:
logger.exception('Archiver exited unexpectedly')
logger.exception("Archiver exited unexpectedly")
raise
else:
logger.info('Program complete')
logger.info("Program complete")
@cli.command('clone')
@cli.command("clone")
@_add_options(_common_options)
@_add_options(_archiver_options)
@_add_options(_downloader_options)
@@ -125,10 +127,10 @@ def cli_clone(context: click.Context, **_):
reddit_scraper = RedditCloner(config)
reddit_scraper.download()
except Exception:
logger.exception('Scraper exited unexpectedly')
logger.exception("Scraper exited unexpectedly")
raise
else:
logger.info('Program complete')
logger.info("Program complete")
def setup_logging(verbosity: int):
@@ -141,7 +143,7 @@ def setup_logging(verbosity: int):
stream = logging.StreamHandler(sys.stdout)
stream.addFilter(StreamExceptionFilter())
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s")
stream.setFormatter(formatter)
logger.addHandler(stream)
@@ -151,10 +153,10 @@ def setup_logging(verbosity: int):
stream.setLevel(logging.DEBUG)
else:
stream.setLevel(9)
logging.getLogger('praw').setLevel(logging.CRITICAL)
logging.getLogger('prawcore').setLevel(logging.CRITICAL)
logging.getLogger('urllib3').setLevel(logging.CRITICAL)
logging.getLogger("praw").setLevel(logging.CRITICAL)
logging.getLogger("prawcore").setLevel(logging.CRITICAL)
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
if __name__ == '__main__':
if __name__ == "__main__":
cli()

@@ -19,21 +19,21 @@ class BaseArchiveEntry(ABC):
@staticmethod
def _convert_comment_to_dict(in_comment: Comment) -> dict:
out_dict = {
'author': in_comment.author.name if in_comment.author else 'DELETED',
'id': in_comment.id,
'score': in_comment.score,
'subreddit': in_comment.subreddit.display_name,
'author_flair': in_comment.author_flair_text,
'submission': in_comment.submission.id,
'stickied': in_comment.stickied,
'body': in_comment.body,
'is_submitter': in_comment.is_submitter,
'distinguished': in_comment.distinguished,
'created_utc': in_comment.created_utc,
'parent_id': in_comment.parent_id,
'replies': [],
"author": in_comment.author.name if in_comment.author else "DELETED",
"id": in_comment.id,
"score": in_comment.score,
"subreddit": in_comment.subreddit.display_name,
"author_flair": in_comment.author_flair_text,
"submission": in_comment.submission.id,
"stickied": in_comment.stickied,
"body": in_comment.body,
"is_submitter": in_comment.is_submitter,
"distinguished": in_comment.distinguished,
"created_utc": in_comment.created_utc,
"parent_id": in_comment.parent_id,
"replies": [],
}
in_comment.replies.replace_more(limit=None)
for reply in in_comment.replies:
out_dict['replies'].append(BaseArchiveEntry._convert_comment_to_dict(reply))
out_dict["replies"].append(BaseArchiveEntry._convert_comment_to_dict(reply))
return out_dict

@@ -17,5 +17,5 @@ class CommentArchiveEntry(BaseArchiveEntry):
def compile(self) -> dict:
self.source.refresh()
self.post_details = self._convert_comment_to_dict(self.source)
self.post_details['submission_title'] = self.source.submission.title
self.post_details["submission_title"] = self.source.submission.title
return self.post_details

@@ -18,32 +18,32 @@ class SubmissionArchiveEntry(BaseArchiveEntry):
comments = self._get_comments()
self._get_post_details()
out = self.post_details
out['comments'] = comments
out["comments"] = comments
return out
def _get_post_details(self):
self.post_details = {
'title': self.source.title,
'name': self.source.name,
'url': self.source.url,
'selftext': self.source.selftext,
'score': self.source.score,
'upvote_ratio': self.source.upvote_ratio,
'permalink': self.source.permalink,
'id': self.source.id,
'author': self.source.author.name if self.source.author else 'DELETED',
'link_flair_text': self.source.link_flair_text,
'num_comments': self.source.num_comments,
'over_18': self.source.over_18,
'spoiler': self.source.spoiler,
'pinned': self.source.pinned,
'locked': self.source.locked,
'distinguished': self.source.distinguished,
'created_utc': self.source.created_utc,
"title": self.source.title,
"name": self.source.name,
"url": self.source.url,
"selftext": self.source.selftext,
"score": self.source.score,
"upvote_ratio": self.source.upvote_ratio,
"permalink": self.source.permalink,
"id": self.source.id,
"author": self.source.author.name if self.source.author else "DELETED",
"link_flair_text": self.source.link_flair_text,
"num_comments": self.source.num_comments,
"over_18": self.source.over_18,
"spoiler": self.source.spoiler,
"pinned": self.source.pinned,
"locked": self.source.locked,
"distinguished": self.source.distinguished,
"created_utc": self.source.created_utc,
}
def _get_comments(self) -> list[dict]:
logger.debug(f'Retrieving full comment tree for submission {self.source.id}')
logger.debug(f"Retrieving full comment tree for submission {self.source.id}")
comments = []
self.source.comments.replace_more(limit=None)
for top_level_comment in self.source.comments:

@@ -30,26 +30,28 @@ class Archiver(RedditConnector):
for generator in self.reddit_lists:
for submission in generator:
try:
if (submission.author and submission.author.name in self.args.ignore_user) or \
(submission.author is None and 'DELETED' in self.args.ignore_user):
if (submission.author and submission.author.name in self.args.ignore_user) or (
submission.author is None and "DELETED" in self.args.ignore_user
):
logger.debug(
f'Submission {submission.id} in {submission.subreddit.display_name} skipped'
f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user')
f"Submission {submission.id} in {submission.subreddit.display_name} skipped"
f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user'
)
continue
if submission.id in self.excluded_submission_ids:
logger.debug(f'Object {submission.id} in exclusion list, skipping')
logger.debug(f"Object {submission.id} in exclusion list, skipping")
continue
logger.debug(f'Attempting to archive submission {submission.id}')
logger.debug(f"Attempting to archive submission {submission.id}")
self.write_entry(submission)
except prawcore.PrawcoreException as e:
logger.error(f'Submission {submission.id} failed to be archived due to a PRAW exception: {e}')
logger.error(f"Submission {submission.id} failed to be archived due to a PRAW exception: {e}")
def get_submissions_from_link(self) -> list[list[praw.models.Submission]]:
supplied_submissions = []
for sub_id in self.args.link:
if len(sub_id) == 6:
supplied_submissions.append(self.reddit_instance.submission(id=sub_id))
elif re.match(r'^\w{7}$', sub_id):
elif re.match(r"^\w{7}$", sub_id):
supplied_submissions.append(self.reddit_instance.comment(id=sub_id))
else:
supplied_submissions.append(self.reddit_instance.submission(url=sub_id))
@@ -60,7 +62,7 @@ class Archiver(RedditConnector):
if self.args.user and self.args.all_comments:
sort = self.determine_sort_function()
for user in self.args.user:
logger.debug(f'Retrieving comments of user {user}')
logger.debug(f"Retrieving comments of user {user}")
results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit))
return results
@@ -71,43 +73,44 @@ class Archiver(RedditConnector):
elif isinstance(praw_item, praw.models.Comment):
return CommentArchiveEntry(praw_item)
else:
raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}')
raise ArchiverError(f"Factory failed to classify item of type {type(praw_item).__name__}")
def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]):
if self.args.comment_context and isinstance(praw_item, praw.models.Comment):
logger.debug(f'Converting comment {praw_item.id} to submission {praw_item.submission.id}')
logger.debug(f"Converting comment {praw_item.id} to submission {praw_item.submission.id}")
praw_item = praw_item.submission
archive_entry = self._pull_lever_entry_factory(praw_item)
if self.args.format == 'json':
if self.args.format == "json":
self._write_entry_json(archive_entry)
elif self.args.format == 'xml':
elif self.args.format == "xml":
self._write_entry_xml(archive_entry)
elif self.args.format == 'yaml':
elif self.args.format == "yaml":
self._write_entry_yaml(archive_entry)
else:
raise ArchiverError(f'Unknown format {self.args.format} given')
logger.info(f'Record for entry item {praw_item.id} written to disk')
raise ArchiverError(f"Unknown format {self.args.format} given")
logger.info(f"Record for entry item {praw_item.id} written to disk")
def _write_entry_json(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', lambda: None, '.json')
resource = Resource(entry.source, "", lambda: None, ".json")
content = json.dumps(entry.compile())
self._write_content_to_disk(resource, content)
def _write_entry_xml(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', lambda: None, '.xml')
content = dict2xml.dict2xml(entry.compile(), wrap='root')
resource = Resource(entry.source, "", lambda: None, ".xml")
content = dict2xml.dict2xml(entry.compile(), wrap="root")
self._write_content_to_disk(resource, content)
def _write_entry_yaml(self, entry: BaseArchiveEntry):
resource = Resource(entry.source, '', lambda: None, '.yaml')
resource = Resource(entry.source, "", lambda: None, ".yaml")
content = yaml.dump(entry.compile())
self._write_content_to_disk(resource, content)
def _write_content_to_disk(self, resource: Resource, content: str):
file_path = self.file_name_formatter.format_path(resource, self.download_directory)
file_path.parent.mkdir(exist_ok=True, parents=True)
with open(file_path, 'w', encoding="utf-8") as file:
with open(file_path, "w", encoding="utf-8") as file:
logger.debug(
f'Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}'
f' format at {file_path}')
f"Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}"
f" format at {file_path}"
)
file.write(content)

@@ -23,4 +23,4 @@ class RedditCloner(RedditDownloader, Archiver):
self._download_submission(submission)
self.write_entry(submission)
except prawcore.PrawcoreException as e:
logger.error(f'Submission {submission.id} failed to be cloned due to a PRAW exception: {e}')
logger.error(f"Submission {submission.id} failed to be cloned due to a PRAW exception: {e}")

@@ -1,28 +1,29 @@
#!/usr/bin/env python3
# coding=utf-8
import logging
from argparse import Namespace
from pathlib import Path
from typing import Optional
import logging
import click
import yaml
logger = logging.getLogger(__name__)
class Configuration(Namespace):
def __init__(self):
super(Configuration, self).__init__()
self.authenticate = False
self.config = None
self.opts: Optional[str] = None
self.directory: str = '.'
self.directory: str = "."
self.disable_module: list[str] = []
self.exclude_id = []
self.exclude_id_file = []
self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
self.folder_scheme: str = '{SUBREDDIT}'
self.file_scheme: str = "{REDDITOR}_{TITLE}_{POSTID}"
self.folder_scheme: str = "{SUBREDDIT}"
self.ignore_user = []
self.include_id_file = []
self.limit: Optional[int] = None
@@ -42,11 +43,11 @@ class Configuration(Namespace):
self.max_score = None
self.min_score_ratio = None
self.max_score_ratio = None
self.sort: str = 'hot'
self.sort: str = "hot"
self.submitted: bool = False
self.subscribed: bool = False
self.subreddit: list[str] = []
self.time: str = 'all'
self.time: str = "all"
self.time_format = None
self.upvoted: bool = False
self.user: list[str] = []
@@ -54,15 +55,15 @@ class Configuration(Namespace):
# Archiver-specific options
self.all_comments = False
self.format = 'json'
self.format = "json"
self.comment_context: bool = False
def process_click_arguments(self, context: click.Context):
if context.params.get('opts') is not None:
self.parse_yaml_options(context.params['opts'])
if context.params.get("opts") is not None:
self.parse_yaml_options(context.params["opts"])
for arg_key in context.params.keys():
if not hasattr(self, arg_key):
logger.warning(f'Ignoring an unknown CLI argument: {arg_key}')
logger.warning(f"Ignoring an unknown CLI argument: {arg_key}")
continue
val = context.params[arg_key]
if val is None or val == ():
@@ -73,16 +74,16 @@ class Configuration(Namespace):
def parse_yaml_options(self, file_path: str):
yaml_file_loc = Path(file_path)
if not yaml_file_loc.exists():
logger.error(f'No YAML file found at {yaml_file_loc}')
logger.error(f"No YAML file found at {yaml_file_loc}")
return
with yaml_file_loc.open() as file:
try:
opts = yaml.load(file, Loader=yaml.FullLoader)
except yaml.YAMLError as e:
logger.error(f'Could not parse YAML options file: {e}')
logger.error(f"Could not parse YAML options file: {e}")
return
for arg_key, val in opts.items():
if not hasattr(self, arg_key):
logger.warning(f'Ignoring an unknown YAML argument: {arg_key}')
logger.warning(f"Ignoring an unknown YAML argument: {arg_key}")
continue
setattr(self, arg_key, val)

@@ -41,18 +41,18 @@ class RedditTypes:
TOP = auto()
class TimeType(Enum):
ALL = 'all'
DAY = 'day'
HOUR = 'hour'
MONTH = 'month'
WEEK = 'week'
YEAR = 'year'
ALL = "all"
DAY = "day"
HOUR = "hour"
MONTH = "month"
WEEK = "week"
YEAR = "year"
class RedditConnector(metaclass=ABCMeta):
def __init__(self, args: Configuration):
self.args = args
self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
self.config_directories = appdirs.AppDirs("bdfr", "BDFR")
self.run_time = datetime.now().isoformat()
self._setup_internal_objects()
@@ -68,13 +68,13 @@ class RedditConnector(metaclass=ABCMeta):
self.parse_disabled_modules()
self.download_filter = self.create_download_filter()
logger.log(9, 'Created download filter')
logger.log(9, "Created download filter")
self.time_filter = self.create_time_filter()
logger.log(9, 'Created time filter')
logger.log(9, "Created time filter")
self.sort_filter = self.create_sort_filter()
logger.log(9, 'Created sort filter')
logger.log(9, "Created sort filter")
self.file_name_formatter = self.create_file_name_formatter()
logger.log(9, 'Create file name formatter')
logger.log(9, "Create file name formatter")
self.create_reddit_instance()
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
@@ -88,7 +88,7 @@ class RedditConnector(metaclass=ABCMeta):
self.master_hash_list = {}
self.authenticator = self.create_authenticator()
logger.log(9, 'Created site authenticator')
logger.log(9, "Created site authenticator")
self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit)
self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit}
@@ -96,18 +96,18 @@ class RedditConnector(metaclass=ABCMeta):
def read_config(self):
"""Read any cfg values that need to be processed"""
if self.args.max_wait_time is None:
self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time', fallback=120)
logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
self.args.max_wait_time = self.cfg_parser.getint("DEFAULT", "max_wait_time", fallback=120)
logger.debug(f"Setting maximum download wait time to {self.args.max_wait_time} seconds")
if self.args.time_format is None:
option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO')
if re.match(r'^[\s\'\"]*$', option):
option = 'ISO'
logger.debug(f'Setting datetime format string to {option}')
option = self.cfg_parser.get("DEFAULT", "time_format", fallback="ISO")
if re.match(r"^[\s\'\"]*$", option):
option = "ISO"
logger.debug(f"Setting datetime format string to {option}")
self.args.time_format = option
if not self.args.disable_module:
self.args.disable_module = [self.cfg_parser.get('DEFAULT', 'disabled_modules', fallback='')]
self.args.disable_module = [self.cfg_parser.get("DEFAULT", "disabled_modules", fallback="")]
# Update config on disk
with open(self.config_location, 'w') as file:
with open(self.config_location, "w") as file:
self.cfg_parser.write(file)
def parse_disabled_modules(self):
@@ -119,48 +119,48 @@ class RedditConnector(metaclass=ABCMeta):
def create_reddit_instance(self):
if self.args.authenticate:
logger.debug('Using authenticated Reddit instance')
if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
logger.log(9, 'Commencing OAuth2 authentication')
scopes = self.cfg_parser.get('DEFAULT', 'scopes', fallback='identity, history, read, save')
logger.debug("Using authenticated Reddit instance")
if not self.cfg_parser.has_option("DEFAULT", "user_token"):
logger.log(9, "Commencing OAuth2 authentication")
scopes = self.cfg_parser.get("DEFAULT", "scopes", fallback="identity, history, read, save")
scopes = OAuth2Authenticator.split_scopes(scopes)
oauth2_authenticator = OAuth2Authenticator(
scopes,
self.cfg_parser.get('DEFAULT', 'client_id'),
self.cfg_parser.get('DEFAULT', 'client_secret'),
self.cfg_parser.get("DEFAULT", "client_id"),
self.cfg_parser.get("DEFAULT", "client_secret"),
)
token = oauth2_authenticator.retrieve_new_token()
self.cfg_parser['DEFAULT']['user_token'] = token
with open(self.config_location, 'w') as file:
self.cfg_parser["DEFAULT"]["user_token"] = token
with open(self.config_location, "w") as file:
self.cfg_parser.write(file, True)
token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)
self.authenticated = True
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
client_id=self.cfg_parser.get("DEFAULT", "client_id"),
client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
user_agent=socket.gethostname(),
token_manager=token_manager,
)
else:
logger.debug('Using unauthenticated Reddit instance')
logger.debug("Using unauthenticated Reddit instance")
self.authenticated = False
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
client_id=self.cfg_parser.get("DEFAULT", "client_id"),
client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
user_agent=socket.gethostname(),
)
def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
master_list = []
master_list.extend(self.get_subreddits())
logger.log(9, 'Retrieved subreddits')
logger.log(9, "Retrieved subreddits")
master_list.extend(self.get_multireddits())
logger.log(9, 'Retrieved multireddits')
logger.log(9, "Retrieved multireddits")
master_list.extend(self.get_user_data())
logger.log(9, 'Retrieved user data')
logger.log(9, "Retrieved user data")
master_list.extend(self.get_submissions_from_link())
logger.log(9, 'Retrieved submissions for given links')
logger.log(9, "Retrieved submissions for given links")
return master_list
def determine_directories(self):
@@ -178,37 +178,37 @@ class RedditConnector(metaclass=ABCMeta):
self.config_location = cfg_path
return
possible_paths = [
Path('./config.cfg'),
Path('./default_config.cfg'),
Path(self.config_directory, 'config.cfg'),
Path(self.config_directory, 'default_config.cfg'),
Path("./config.cfg"),
Path("./default_config.cfg"),
Path(self.config_directory, "config.cfg"),
Path(self.config_directory, "default_config.cfg"),
]
self.config_location = None
for path in possible_paths:
if path.resolve().expanduser().exists():
self.config_location = path
logger.debug(f'Loading configuration from {path}')
logger.debug(f"Loading configuration from {path}")
break
if not self.config_location:
with importlib.resources.path('bdfr', 'default_config.cfg') as path:
with importlib.resources.path("bdfr", "default_config.cfg") as path:
self.config_location = path
shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
shutil.copy(self.config_location, Path(self.config_directory, "default_config.cfg"))
if not self.config_location:
raise errors.BulkDownloaderException('Could not find a configuration file to load')
raise errors.BulkDownloaderException("Could not find a configuration file to load")
self.cfg_parser.read(self.config_location)
def create_file_logger(self):
main_logger = logging.getLogger()
if self.args.log is None:
log_path = Path(self.config_directory, 'log_output.txt')
log_path = Path(self.config_directory, "log_output.txt")
else:
log_path = Path(self.args.log).resolve().expanduser()
if not log_path.parent.exists():
raise errors.BulkDownloaderException(f'Designated location for logfile does not exist')
backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
raise errors.BulkDownloaderException(f"Designated location for logfile does not exist")
backup_count = self.cfg_parser.getint("DEFAULT", "backup_log_count", fallback=3)
file_handler = logging.handlers.RotatingFileHandler(
log_path,
mode='a',
mode="a",
backupCount=backup_count,
)
if log_path.exists():
@@ -216,10 +216,11 @@ class RedditConnector(metaclass=ABCMeta):
file_handler.doRollover()
except PermissionError:
logger.critical(
'Cannot rollover logfile, make sure this is the only '
'BDFR process or specify alternate logfile location')
"Cannot rollover logfile, make sure this is the only "
"BDFR process or specify alternate logfile location"
)
raise
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s")
file_handler.setFormatter(formatter)
file_handler.setLevel(0)
@@ -227,16 +228,16 @@ class RedditConnector(metaclass=ABCMeta):
@staticmethod
def sanitise_subreddit_name(subreddit: str) -> str:
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$')
pattern = re.compile(r"^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$")
match = re.match(pattern, subreddit)
if not match:
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
raise errors.BulkDownloaderException(f"Could not find subreddit name in string {subreddit}")
return match.group(1)
@staticmethod
def split_args_input(entries: list[str]) -> set[str]:
all_entries = []
split_pattern = re.compile(r'[,;]\s?')
split_pattern = re.compile(r"[,;]\s?")
for entry in entries:
results = re.split(split_pattern, entry)
all_entries.extend([RedditConnector.sanitise_subreddit_name(name) for name in results])
@@ -251,13 +252,13 @@ class RedditConnector(metaclass=ABCMeta):
subscribed_subreddits = list(self.reddit_instance.user.subreddits(limit=None))
subscribed_subreddits = {s.display_name for s in subscribed_subreddits}
except prawcore.InsufficientScope:
logger.error('BDFR has insufficient scope to access subreddit lists')
logger.error("BDFR has insufficient scope to access subreddit lists")
else:
logger.error('Cannot find subscribed subreddits without an authenticated instance')
logger.error("Cannot find subscribed subreddits without an authenticated instance")
if self.args.subreddit or subscribed_subreddits:
for reddit in self.split_args_input(self.args.subreddit) | subscribed_subreddits:
if reddit == 'friends' and self.authenticated is False:
logger.error('Cannot read friends subreddit without an authenticated instance')
if reddit == "friends" and self.authenticated is False:
logger.error("Cannot read friends subreddit without an authenticated instance")
continue
try:
reddit = self.reddit_instance.subreddit(reddit)
@@ -267,26 +268,29 @@ class RedditConnector(metaclass=ABCMeta):
logger.error(e)
continue
if self.args.search:
out.append(reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
time_filter=self.time_filter.value,
))
out.append(
reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
time_filter=self.time_filter.value,
)
)
logger.debug(
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"'
)
else:
out.append(self.create_filtered_listing_generator(reddit))
logger.debug(f'Added submissions from subreddit {reddit}')
logger.debug(f"Added submissions from subreddit {reddit}")
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
logger.error(f"Failed to get submissions for subreddit {reddit}: {e}")
return out
def resolve_user_name(self, in_name: str) -> str:
if in_name == 'me':
if in_name == "me":
if self.authenticated:
resolved_name = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {resolved_name}')
logger.log(9, f"Resolved user to {resolved_name}")
return resolved_name
else:
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
@@ -318,7 +322,7 @@ class RedditConnector(metaclass=ABCMeta):
def get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
if len(self.args.user) != 1:
logger.error(f'Only 1 user can be supplied when retrieving from multireddits')
logger.error(f"Only 1 user can be supplied when retrieving from multireddits")
return []
out = []
for multi in self.split_args_input(self.args.multireddit):
@@ -327,9 +331,9 @@ class RedditConnector(metaclass=ABCMeta):
if not multi.subreddits:
raise errors.BulkDownloaderException
out.append(self.create_filtered_listing_generator(multi))
logger.debug(f'Added submissions from multireddit {multi}')
logger.debug(f"Added submissions from multireddit {multi}")
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
logger.error(f"Failed to get submissions for multireddit {multi}: {e}")
return out
else:
return []
@@ -344,7 +348,7 @@ class RedditConnector(metaclass=ABCMeta):
def get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if not self.args.user:
logger.warning('At least one user must be supplied to download user data')
logger.warning("At least one user must be supplied to download user data")
return []
generators = []
for user in self.args.user:
@@ -354,18 +358,20 @@ class RedditConnector(metaclass=ABCMeta):
logger.error(e)
continue
if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
generators.append(self.create_filtered_listing_generator(
self.reddit_instance.redditor(user).submissions,
))
logger.debug(f"Retrieving submitted posts of user {self.args.user}")
generators.append(
self.create_filtered_listing_generator(
self.reddit_instance.redditor(user).submissions,
)
)
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication')
logger.warning("Accessing user lists requires authentication")
else:
if self.args.upvoted:
logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
logger.debug(f"Retrieving upvoted posts of user {self.args.user}")
generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit))
if self.args.saved:
logger.debug(f'Retrieving saved posts of user {self.args.user}')
logger.debug(f"Retrieving saved posts of user {self.args.user}")
generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit))
return generators
else:
@@ -377,10 +383,10 @@ class RedditConnector(metaclass=ABCMeta):
if user.id:
return
except prawcore.exceptions.NotFound:
raise errors.BulkDownloaderException(f'Could not find user {name}')
raise errors.BulkDownloaderException(f"Could not find user {name}")
except AttributeError:
if hasattr(user, 'is_suspended'):
raise errors.BulkDownloaderException(f'User {name} is banned')
if hasattr(user, "is_suspended"):
raise errors.BulkDownloaderException(f"User {name} is banned")
def create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format)
@@ -409,7 +415,7 @@ class RedditConnector(metaclass=ABCMeta):
@staticmethod
def check_subreddit_status(subreddit: praw.models.Subreddit):
if subreddit.display_name in ('all', 'friends'):
if subreddit.display_name in ("all", "friends"):
return
try:
assert subreddit.id
@@ -418,7 +424,7 @@ class RedditConnector(metaclass=ABCMeta):
except prawcore.Redirect:
raise errors.BulkDownloaderException(f"Source {subreddit.display_name} does not exist")
except prawcore.Forbidden:
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')
raise errors.BulkDownloaderException(f"Source {subreddit.display_name} is private and cannot be scraped")
@staticmethod
def read_id_files(file_locations: list[str]) -> set[str]:
@@ -426,9 +432,9 @@ class RedditConnector(metaclass=ABCMeta):
for id_file in file_locations:
id_file = Path(id_file).resolve().expanduser()
if not id_file.exists():
logger.warning(f'ID file at {id_file} does not exist')
logger.warning(f"ID file at {id_file} does not exist")
continue
with id_file.open('r') as file:
with id_file.open("r") as file:
for line in file:
out.append(line.strip())
return set(out)

@@ -33,8 +33,8 @@ class DownloadFilter:
def _check_extension(self, resource_extension: str) -> bool:
if not self.excluded_extensions:
return True
combined_extensions = '|'.join(self.excluded_extensions)
pattern = re.compile(r'.*({})$'.format(combined_extensions))
combined_extensions = "|".join(self.excluded_extensions)
pattern = re.compile(r".*({})$".format(combined_extensions))
if re.match(pattern, resource_extension):
logger.log(9, f'Url "{resource_extension}" matched with "{pattern}"')
return False
@@ -44,8 +44,8 @@ class DownloadFilter:
def _check_domain(self, url: str) -> bool:
if not self.excluded_domains:
return True
combined_domains = '|'.join(self.excluded_domains)
pattern = re.compile(r'https?://.*({}).*'.format(combined_domains))
combined_domains = "|".join(self.excluded_domains)
pattern = re.compile(r"https?://.*({}).*".format(combined_domains))
if re.match(pattern, url):
logger.log(9, f'Url "{url}" matched with "{pattern}"')
return False

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
def _calc_hash(existing_file: Path):
chunk_size = 1024 * 1024
md5_hash = hashlib.md5()
with existing_file.open('rb') as file:
with existing_file.open("rb") as file:
chunk = file.read(chunk_size)
while chunk:
md5_hash.update(chunk)
@@ -46,28 +46,32 @@ class RedditDownloader(RedditConnector):
try:
self._download_submission(submission)
except prawcore.PrawcoreException as e:
logger.error(f'Submission {submission.id} failed to download due to a PRAW exception: {e}')
logger.error(f"Submission {submission.id} failed to download due to a PRAW exception: {e}")
def _download_submission(self, submission: praw.models.Submission):
if submission.id in self.excluded_submission_ids:
logger.debug(f'Object {submission.id} in exclusion list, skipping')
logger.debug(f"Object {submission.id} in exclusion list, skipping")
return
elif submission.subreddit.display_name.lower() in self.args.skip_subreddit:
logger.debug(f'Submission {submission.id} in {submission.subreddit.display_name} in skip list')
logger.debug(f"Submission {submission.id} in {submission.subreddit.display_name} in skip list")
return
elif (submission.author and submission.author.name in self.args.ignore_user) or \
(submission.author is None and 'DELETED' in self.args.ignore_user):
elif (submission.author and submission.author.name in self.args.ignore_user) or (
submission.author is None and "DELETED" in self.args.ignore_user
):
logger.debug(
f'Submission {submission.id} in {submission.subreddit.display_name} skipped'
f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user')
f"Submission {submission.id} in {submission.subreddit.display_name} skipped"
f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user'
)
return
elif self.args.min_score and submission.score < self.args.min_score:
logger.debug(
f"Submission {submission.id} filtered due to score {submission.score} < [{self.args.min_score}]")
f"Submission {submission.id} filtered due to score {submission.score} < [{self.args.min_score}]"
)
return
elif self.args.max_score and self.args.max_score < submission.score:
logger.debug(
f"Submission {submission.id} filtered due to score {submission.score} > [{self.args.max_score}]")
f"Submission {submission.id} filtered due to score {submission.score} > [{self.args.max_score}]"
)
return
elif (self.args.min_score_ratio and submission.upvote_ratio < self.args.min_score_ratio) or (
self.args.max_score_ratio and self.args.max_score_ratio < submission.upvote_ratio
@@ -75,47 +79,48 @@ class RedditDownloader(RedditConnector):
logger.debug(f"Submission {submission.id} filtered due to score ratio ({submission.upvote_ratio})")
return
elif not isinstance(submission, praw.models.Submission):
logger.warning(f'{submission.id} is not a submission')
logger.warning(f"{submission.id} is not a submission")
return
elif not self.download_filter.check_url(submission.url):
logger.debug(f'Submission {submission.id} filtered due to URL {submission.url}')
logger.debug(f"Submission {submission.id} filtered due to URL {submission.url}")
return
logger.debug(f'Attempting to download submission {submission.id}')
logger.debug(f"Attempting to download submission {submission.id}")
try:
downloader_class = DownloadFactory.pull_lever(submission.url)
downloader = downloader_class(submission)
logger.debug(f'Using {downloader_class.__name__} with url {submission.url}')
logger.debug(f"Using {downloader_class.__name__} with url {submission.url}")
except errors.NotADownloadableLinkError as e:
logger.error(f'Could not download submission {submission.id}: {e}')
logger.error(f"Could not download submission {submission.id}: {e}")
return
if downloader_class.__name__.lower() in self.args.disable_module:
logger.debug(f'Submission {submission.id} skipped due to disabled module {downloader_class.__name__}')
logger.debug(f"Submission {submission.id} skipped due to disabled module {downloader_class.__name__}")
return
try:
content = downloader.find_resources(self.authenticator)
except errors.SiteDownloaderError as e:
logger.error(f'Site {downloader_class.__name__} failed to download submission {submission.id}: {e}')
logger.error(f"Site {downloader_class.__name__} failed to download submission {submission.id}: {e}")
return
for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory):
if destination.exists():
logger.debug(f'File {destination} from submission {submission.id} already exists, continuing')
logger.debug(f"File {destination} from submission {submission.id} already exists, continuing")
continue
elif not self.download_filter.check_resource(res):
logger.debug(f'Download filter removed {submission.id} file with URL {submission.url}')
logger.debug(f"Download filter removed {submission.id} file with URL {submission.url}")
continue
try:
res.download({'max_wait_time': self.args.max_wait_time})
res.download({"max_wait_time": self.args.max_wait_time})
except errors.BulkDownloaderException as e:
logger.error(f'Failed to download resource {res.url} in submission {submission.id} '
f'with downloader {downloader_class.__name__}: {e}')
logger.error(
f"Failed to download resource {res.url} in submission {submission.id} "
f"with downloader {downloader_class.__name__}: {e}"
)
return
resource_hash = res.hash.hexdigest()
destination.parent.mkdir(parents=True, exist_ok=True)
if resource_hash in self.master_hash_list:
if self.args.no_dupes:
logger.info(
f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere')
logger.info(f"Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere")
return
elif self.args.make_hard_links:
try:
@@ -123,29 +128,30 @@ class RedditDownloader(RedditConnector):
except AttributeError:
self.master_hash_list[resource_hash].link_to(destination)
logger.info(
f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}'
f' in submission {submission.id}')
f"Hard link made linking {destination} to {self.master_hash_list[resource_hash]}"
f" in submission {submission.id}"
)
return
try:
with destination.open('wb') as file:
with destination.open("wb") as file:
file.write(res.content)
logger.debug(f'Written file to {destination}')
logger.debug(f"Written file to {destination}")
except OSError as e:
logger.exception(e)
logger.error(f'Failed to write file in submission {submission.id} to {destination}: {e}')
logger.error(f"Failed to write file in submission {submission.id} to {destination}: {e}")
return
creation_time = time.mktime(datetime.fromtimestamp(submission.created_utc).timetuple())
os.utime(destination, (creation_time, creation_time))
self.master_hash_list[resource_hash] = destination
logger.debug(f'Hash added to master list: {resource_hash}')
logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}')
logger.debug(f"Hash added to master list: {resource_hash}")
logger.info(f"Downloaded submission {submission.id} from {submission.subreddit.display_name}")
@staticmethod
def scan_existing_files(directory: Path) -> dict[str, Path]:
files = []
for (dirpath, dirnames, filenames) in os.walk(directory):
files.extend([Path(dirpath, file) for file in filenames])
logger.info(f'Calculating hashes for {len(files)} files')
logger.info(f"Calculating hashes for {len(files)} files")
pool = Pool(15)
results = pool.map(_calc_hash, files)

@@ -1,5 +1,6 @@
#!/usr/bin/env
class BulkDownloaderException(Exception):
pass

@@ -18,20 +18,20 @@ logger = logging.getLogger(__name__)
class FileNameFormatter:
key_terms = (
'date',
'flair',
'postid',
'redditor',
'subreddit',
'title',
'upvotes',
"date",
"flair",
"postid",
"redditor",
"subreddit",
"title",
"upvotes",
)
def __init__(self, file_format_string: str, directory_format_string: str, time_format_string: str):
if not self.validate_string(file_format_string):
raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string')
self.file_format_string = file_format_string
self.directory_format_string: list[str] = directory_format_string.split('/')
self.directory_format_string: list[str] = directory_format_string.split("/")
self.time_format_string = time_format_string
def _format_name(self, submission: Union[Comment, Submission], format_string: str) -> str:
@@ -40,108 +40,111 @@ class FileNameFormatter:
elif isinstance(submission, Comment):
attributes = self._generate_name_dict_from_comment(submission)
else:
raise BulkDownloaderException(f'Cannot name object {type(submission).__name__}')
raise BulkDownloaderException(f"Cannot name object {type(submission).__name__}")
result = format_string
for key in attributes.keys():
if re.search(fr'(?i).*{{{key}}}.*', result):
key_value = str(attributes.get(key, 'unknown'))
if re.search(rf"(?i).*{{{key}}}.*", result):
key_value = str(attributes.get(key, "unknown"))
key_value = FileNameFormatter._convert_unicode_escapes(key_value)
key_value = key_value.replace('\\', '\\\\')
result = re.sub(fr'(?i){{{key}}}', key_value, result)
key_value = key_value.replace("\\", "\\\\")
result = re.sub(rf"(?i){{{key}}}", key_value, result)
result = result.replace('/', '')
result = result.replace("/", "")
if platform.system() == 'Windows':
if platform.system() == "Windows":
result = FileNameFormatter._format_for_windows(result)
return result
@staticmethod
def _convert_unicode_escapes(in_string: str) -> str:
pattern = re.compile(r'(\\u\d{4})')
pattern = re.compile(r"(\\u\d{4})")
matches = re.search(pattern, in_string)
if matches:
for match in matches.groups():
converted_match = bytes(match, 'utf-8').decode('unicode-escape')
converted_match = bytes(match, "utf-8").decode("unicode-escape")
in_string = in_string.replace(match, converted_match)
return in_string
def _generate_name_dict_from_submission(self, submission: Submission) -> dict:
submission_attributes = {
'title': submission.title,
'subreddit': submission.subreddit.display_name,
'redditor': submission.author.name if submission.author else 'DELETED',
'postid': submission.id,
'upvotes': submission.score,
'flair': submission.link_flair_text,
'date': self._convert_timestamp(submission.created_utc),
"title": submission.title,
"subreddit": submission.subreddit.display_name,
"redditor": submission.author.name if submission.author else "DELETED",
"postid": submission.id,
"upvotes": submission.score,
"flair": submission.link_flair_text,
"date": self._convert_timestamp(submission.created_utc),
}
return submission_attributes
def _convert_timestamp(self, timestamp: float) -> str:
input_time = datetime.datetime.fromtimestamp(timestamp)
if self.time_format_string.upper().strip() == 'ISO':
if self.time_format_string.upper().strip() == "ISO":
return input_time.isoformat()
else:
return input_time.strftime(self.time_format_string)
def _generate_name_dict_from_comment(self, comment: Comment) -> dict:
comment_attributes = {
'title': comment.submission.title,
'subreddit': comment.subreddit.display_name,
'redditor': comment.author.name if comment.author else 'DELETED',
'postid': comment.id,
'upvotes': comment.score,
'flair': '',
'date': self._convert_timestamp(comment.created_utc),
"title": comment.submission.title,
"subreddit": comment.subreddit.display_name,
"redditor": comment.author.name if comment.author else "DELETED",
"postid": comment.id,
"upvotes": comment.score,
"flair": "",
"date": self._convert_timestamp(comment.created_utc),
}
return comment_attributes
def format_path(
self,
resource: Resource,
destination_directory: Path,
index: Optional[int] = None,
self,
resource: Resource,
destination_directory: Path,
index: Optional[int] = None,
) -> Path:
subfolder = Path(
destination_directory,
*[self._format_name(resource.source_submission, part) for part in self.directory_format_string],
)
index = f'_{index}' if index else ''
index = f"_{index}" if index else ""
if not resource.extension:
raise BulkDownloaderException(f'Resource from {resource.url} has no extension')
raise BulkDownloaderException(f"Resource from {resource.url} has no extension")
file_name = str(self._format_name(resource.source_submission, self.file_format_string))
file_name = re.sub(r'\n', ' ', file_name)
file_name = re.sub(r"\n", " ", file_name)
if not re.match(r'.*\.$', file_name) and not re.match(r'^\..*', resource.extension):
ending = index + '.' + resource.extension
if not re.match(r".*\.$", file_name) and not re.match(r"^\..*", resource.extension):
ending = index + "." + resource.extension
else:
ending = index + resource.extension
try:
file_path = self.limit_file_name_length(file_name, ending, subfolder)
except TypeError:
raise BulkDownloaderException(f'Could not determine path name: {subfolder}, {index}, {resource.extension}')
raise BulkDownloaderException(f"Could not determine path name: {subfolder}, {index}, {resource.extension}")
return file_path
@staticmethod
def limit_file_name_length(filename: str, ending: str, root: Path) -> Path:
root = root.resolve().expanduser()
possible_id = re.search(r'((?:_\w{6})?$)', filename)
possible_id = re.search(r"((?:_\w{6})?$)", filename)
if possible_id:
ending = possible_id.group(1) + ending
filename = filename[:possible_id.start()]
filename = filename[: possible_id.start()]
max_path = FileNameFormatter.find_max_path_length()
max_file_part_length_chars = 255 - len(ending)
max_file_part_length_bytes = 255 - len(ending.encode('utf-8'))
max_file_part_length_bytes = 255 - len(ending.encode("utf-8"))
max_path_length = max_path - len(ending) - len(str(root)) - 1
out = Path(root, filename + ending)
while any([len(filename) > max_file_part_length_chars,
len(filename.encode('utf-8')) > max_file_part_length_bytes,
len(str(out)) > max_path_length,
]):
while any(
[
len(filename) > max_file_part_length_chars,
len(filename.encode("utf-8")) > max_file_part_length_bytes,
len(str(out)) > max_path_length,
]
):
filename = filename[:-1]
out = Path(root, filename + ending)
@@ -150,44 +153,46 @@ class FileNameFormatter:
@staticmethod
def find_max_path_length() -> int:
try:
return int(subprocess.check_output(['getconf', 'PATH_MAX', '/']))
return int(subprocess.check_output(["getconf", "PATH_MAX", "/"]))
except (ValueError, subprocess.CalledProcessError, OSError):
if platform.system() == 'Windows':
if platform.system() == "Windows":
return 260
else:
return 4096
def format_resource_paths(
self,
resources: list[Resource],
destination_directory: Path,
self,
resources: list[Resource],
destination_directory: Path,
) -> list[tuple[Path, Resource]]:
out = []
if len(resources) == 1:
try:
out.append((self.format_path(resources[0], destination_directory, None), resources[0]))
except BulkDownloaderException as e:
logger.error(f'Could not generate file path for resource {resources[0].url}: {e}')
logger.exception('Could not generate file path')
logger.error(f"Could not generate file path for resource {resources[0].url}: {e}")
logger.exception("Could not generate file path")
else:
for i, res in enumerate(resources, start=1):
logger.log(9, f'Formatting filename with index {i}')
logger.log(9, f"Formatting filename with index {i}")
try:
out.append((self.format_path(res, destination_directory, i), res))
except BulkDownloaderException as e:
logger.error(f'Could not generate file path for resource {res.url}: {e}')
logger.exception('Could not generate file path')
logger.error(f"Could not generate file path for resource {res.url}: {e}")
logger.exception("Could not generate file path")
return out
@staticmethod
def validate_string(test_string: str) -> bool:
if not test_string:
return False
result = any([f'{{{key}}}' in test_string.lower() for key in FileNameFormatter.key_terms])
result = any([f"{{{key}}}" in test_string.lower() for key in FileNameFormatter.key_terms])
if result:
if 'POSTID' not in test_string:
logger.warning('Some files might not be downloaded due to name conflicts as filenames are'
' not guaranteed to be be unique without {POSTID}')
if "POSTID" not in test_string:
logger.warning(
"Some files might not be downloaded due to name conflicts as filenames are"
" not guaranteed to be be unique without {POSTID}"
)
return True
else:
return False
@@ -196,11 +201,11 @@ class FileNameFormatter:
def _format_for_windows(input_string: str) -> str:
invalid_characters = r'<>:"\/|?*'
for char in invalid_characters:
input_string = input_string.replace(char, '')
input_string = input_string.replace(char, "")
input_string = FileNameFormatter._strip_emojis(input_string)
return input_string
@staticmethod
def _strip_emojis(input_string: str) -> str:
result = input_string.encode('ascii', errors='ignore').decode('utf-8')
result = input_string.encode("ascii", errors="ignore").decode("utf-8")
return result

@@ -17,7 +17,6 @@ logger = logging.getLogger(__name__)
class OAuth2Authenticator:
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
self._check_scopes(wanted_scopes)
self.scopes = wanted_scopes
@@ -26,39 +25,41 @@ class OAuth2Authenticator:
@staticmethod
def _check_scopes(wanted_scopes: set[str]):
response = requests.get('https://www.reddit.com/api/v1/scopes.json',
headers={'User-Agent': 'fetch-scopes test'})
response = requests.get(
"https://www.reddit.com/api/v1/scopes.json", headers={"User-Agent": "fetch-scopes test"}
)
known_scopes = [scope for scope, data in response.json().items()]
known_scopes.append('*')
known_scopes.append("*")
for scope in wanted_scopes:
if scope not in known_scopes:
raise BulkDownloaderException(f'Scope {scope} is not known to reddit')
raise BulkDownloaderException(f"Scope {scope} is not known to reddit")
@staticmethod
def split_scopes(scopes: str) -> set[str]:
scopes = re.split(r'[,: ]+', scopes)
scopes = re.split(r"[,: ]+", scopes)
return set(scopes)
def retrieve_new_token(self) -> str:
reddit = praw.Reddit(
redirect_uri='http://localhost:7634',
user_agent='obtain_refresh_token for BDFR',
redirect_uri="http://localhost:7634",
user_agent="obtain_refresh_token for BDFR",
client_id=self.client_id,
client_secret=self.client_secret)
client_secret=self.client_secret,
)
state = str(random.randint(0, 65000))
url = reddit.auth.url(self.scopes, state, 'permanent')
logger.warning('Authentication action required before the program can proceed')
logger.warning(f'Authenticate at {url}')
url = reddit.auth.url(self.scopes, state, "permanent")
logger.warning("Authentication action required before the program can proceed")
logger.warning(f"Authenticate at {url}")
client = self.receive_connection()
data = client.recv(1024).decode('utf-8')
param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&')
params = {key: value for (key, value) in [token.split('=') for token in param_tokens]}
data = client.recv(1024).decode("utf-8")
param_tokens = data.split(" ", 2)[1].split("?", 1)[1].split("&")
params = {key: value for (key, value) in [token.split("=") for token in param_tokens]}
if state != params['state']:
if state != params["state"]:
self.send_message(client)
raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}')
elif 'error' in params:
elif "error" in params:
self.send_message(client)
raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
@@ -70,19 +71,19 @@ class OAuth2Authenticator:
def receive_connection() -> socket.socket:
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind(('0.0.0.0', 7634))
logger.log(9, 'Server listening on 0.0.0.0:7634')
server.bind(("0.0.0.0", 7634))
logger.log(9, "Server listening on 0.0.0.0:7634")
server.listen(1)
client = server.accept()[0]
server.close()
logger.log(9, 'Server closed')
logger.log(9, "Server closed")
return client
@staticmethod
def send_message(client: socket.socket, message: str = ''):
client.send(f'HTTP/1.1 200 OK\r\n\r\n{message}'.encode('utf-8'))
def send_message(client: socket.socket, message: str = ""):
client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode("utf-8"))
client.close()
@@ -94,14 +95,14 @@ class OAuth2TokenManager(praw.reddit.BaseTokenManager):
def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
if authorizer.refresh_token is None:
if self.config.has_option('DEFAULT', 'user_token'):
authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
logger.log(9, 'Loaded OAuth2 token for authoriser')
if self.config.has_option("DEFAULT", "user_token"):
authorizer.refresh_token = self.config.get("DEFAULT", "user_token")
logger.log(9, "Loaded OAuth2 token for authoriser")
else:
raise RedditAuthenticationError('No auth token loaded in configuration')
raise RedditAuthenticationError("No auth token loaded in configuration")
def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
with open(self.config_location, 'w') as file:
self.config.set("DEFAULT", "user_token", authorizer.refresh_token)
with open(self.config_location, "w") as file:
self.config.write(file, True)
logger.log(9, f'Written OAuth2 token from authoriser to {self.config_location}')
logger.log(9, f"Written OAuth2 token from authoriser to {self.config_location}")

View File

@@ -39,7 +39,7 @@ class Resource:
try:
content = self.download_function(download_parameters)
except requests.exceptions.ConnectionError as e:
raise BulkDownloaderException(f'Could not download resource: {e}')
raise BulkDownloaderException(f"Could not download resource: {e}")
except BulkDownloaderException:
raise
if content:
@@ -51,7 +51,7 @@ class Resource:
self.hash = hashlib.md5(self.content)
def _determine_extension(self) -> Optional[str]:
extension_pattern = re.compile(r'.*(\..{3,5})$')
extension_pattern = re.compile(r".*(\..{3,5})$")
stripped_url = urllib.parse.urlsplit(self.url).path
match = re.search(extension_pattern, stripped_url)
if match:
@@ -59,27 +59,28 @@ class Resource:
@staticmethod
def http_download(url: str, download_parameters: dict) -> Optional[bytes]:
headers = download_parameters.get('headers')
headers = download_parameters.get("headers")
current_wait_time = 60
if 'max_wait_time' in download_parameters:
max_wait_time = download_parameters['max_wait_time']
if "max_wait_time" in download_parameters:
max_wait_time = download_parameters["max_wait_time"]
else:
max_wait_time = 300
while True:
try:
response = requests.get(url, headers=headers)
if re.match(r'^2\d{2}', str(response.status_code)) and response.content:
if re.match(r"^2\d{2}", str(response.status_code)) and response.content:
return response.content
elif response.status_code in (408, 429):
raise requests.exceptions.ConnectionError(f'Response code {response.status_code}')
raise requests.exceptions.ConnectionError(f"Response code {response.status_code}")
else:
raise BulkDownloaderException(
f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
f"Unrecoverable error requesting resource: HTTP Code {response.status_code}"
)
except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e:
logger.warning(f'Error occurred downloading from {url}, waiting {current_wait_time} seconds: {e}')
logger.warning(f"Error occurred downloading from {url}, waiting {current_wait_time} seconds: {e}")
time.sleep(current_wait_time)
if current_wait_time < max_wait_time:
current_wait_time += 60
else:
logger.error(f'Max wait time exceeded for resource at url {url}')
logger.error(f"Max wait time exceeded for resource at url {url}")
raise

View File

@@ -31,7 +31,7 @@ class BaseDownloader(ABC):
res = requests.get(url, cookies=cookies, headers=headers)
except requests.exceptions.RequestException as e:
logger.exception(e)
raise SiteDownloaderError(f'Failed to get page {url}')
raise SiteDownloaderError(f"Failed to get page {url}")
if res.status_code != 200:
raise ResourceNotFound(f'Server responded with {res.status_code} to {url}')
raise ResourceNotFound(f"Server responded with {res.status_code} to {url}")
return res

View File

@@ -5,8 +5,8 @@ from typing import Optional
from praw.models import Submission
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader
logger = logging.getLogger(__name__)

View File

@@ -4,8 +4,8 @@ from typing import Optional
from praw.models import Submission
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.base_downloader import BaseDownloader

View File

@@ -26,62 +26,63 @@ class DownloadFactory:
@staticmethod
def pull_lever(url: str) -> Type[BaseDownloader]:
sanitised_url = DownloadFactory.sanitise_url(url)
if re.match(r'(i\.|m\.)?imgur', sanitised_url):
if re.match(r"(i\.|m\.)?imgur", sanitised_url):
return Imgur
elif re.match(r'(i\.)?(redgifs|gifdeliverynetwork)', sanitised_url):
elif re.match(r"(i\.)?(redgifs|gifdeliverynetwork)", sanitised_url):
return Redgifs
elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url) and \
not DownloadFactory.is_web_resource(sanitised_url):
elif re.match(r".*/.*\.\w{3,4}(\?[\w;&=]*)?$", sanitised_url) and not DownloadFactory.is_web_resource(
sanitised_url
):
return Direct
elif re.match(r'erome\.com.*', sanitised_url):
elif re.match(r"erome\.com.*", sanitised_url):
return Erome
elif re.match(r'delayforreddit\.com', sanitised_url):
elif re.match(r"delayforreddit\.com", sanitised_url):
return DelayForReddit
elif re.match(r'reddit\.com/gallery/.*', sanitised_url):
elif re.match(r"reddit\.com/gallery/.*", sanitised_url):
return Gallery
elif re.match(r'patreon\.com.*', sanitised_url):
elif re.match(r"patreon\.com.*", sanitised_url):
return Gallery
elif re.match(r'gfycat\.', sanitised_url):
elif re.match(r"gfycat\.", sanitised_url):
return Gfycat
elif re.match(r'reddit\.com/r/', sanitised_url):
elif re.match(r"reddit\.com/r/", sanitised_url):
return SelfPost
elif re.match(r'(m\.)?youtu\.?be', sanitised_url):
elif re.match(r"(m\.)?youtu\.?be", sanitised_url):
return Youtube
elif re.match(r'i\.redd\.it.*', sanitised_url):
elif re.match(r"i\.redd\.it.*", sanitised_url):
return Direct
elif re.match(r'v\.redd\.it.*', sanitised_url):
elif re.match(r"v\.redd\.it.*", sanitised_url):
return VReddit
elif re.match(r'pornhub\.com.*', sanitised_url):
elif re.match(r"pornhub\.com.*", sanitised_url):
return PornHub
elif re.match(r'vidble\.com', sanitised_url):
elif re.match(r"vidble\.com", sanitised_url):
return Vidble
elif YtdlpFallback.can_handle_link(sanitised_url):
return YtdlpFallback
else:
raise NotADownloadableLinkError(f'No downloader module exists for url {url}')
raise NotADownloadableLinkError(f"No downloader module exists for url {url}")
@staticmethod
def sanitise_url(url: str) -> str:
beginning_regex = re.compile(r'\s*(www\.?)?')
beginning_regex = re.compile(r"\s*(www\.?)?")
split_url = urllib.parse.urlsplit(url)
split_url = split_url.netloc + split_url.path
split_url = re.sub(beginning_regex, '', split_url)
split_url = re.sub(beginning_regex, "", split_url)
return split_url
@staticmethod
def is_web_resource(url: str) -> bool:
web_extensions = (
'asp',
'aspx',
'cfm',
'cfml',
'css',
'htm',
'html',
'js',
'php',
'php3',
'xhtml',
"asp",
"aspx",
"cfm",
"cfml",
"css",
"htm",
"html",
"js",
"php",
"php3",
"xhtml",
)
if re.match(rf'(?i).*/.*\.({"|".join(web_extensions)})$', url):
return True

View File

@@ -23,34 +23,34 @@ class Erome(BaseDownloader):
links = self._get_links(self.post.url)
if not links:
raise SiteDownloaderError('Erome parser could not find any links')
raise SiteDownloaderError("Erome parser could not find any links")
out = []
for link in links:
if not re.match(r'https?://.*', link):
link = 'https://' + link
if not re.match(r"https?://.*", link):
link = "https://" + link
out.append(Resource(self.post, link, self.erome_download(link)))
return out
@staticmethod
def _get_links(url: str) -> set[str]:
page = Erome.retrieve_url(url)
soup = bs4.BeautifulSoup(page.text, 'html.parser')
front_images = soup.find_all('img', attrs={'class': 'lasyload'})
out = [im.get('data-src') for im in front_images]
soup = bs4.BeautifulSoup(page.text, "html.parser")
front_images = soup.find_all("img", attrs={"class": "lasyload"})
out = [im.get("data-src") for im in front_images]
videos = soup.find_all('source')
out.extend([vid.get('src') for vid in videos])
videos = soup.find_all("source")
out.extend([vid.get("src") for vid in videos])
return set(out)
@staticmethod
def erome_download(url: str) -> Callable:
download_parameters = {
'headers': {
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'
' Chrome/88.0.4324.104 Safari/537.36',
'Referer': 'https://www.erome.com/',
"headers": {
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/88.0.4324.104 Safari/537.36",
"Referer": "https://www.erome.com/",
},
}
return lambda global_params: Resource.http_download(url, global_params | download_parameters)

View File

@@ -7,7 +7,6 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class BaseFallbackDownloader(BaseDownloader, ABC):
@staticmethod
@abstractmethod
def can_handle_link(url: str) -> bool:

View File

@@ -9,7 +9,9 @@ from praw.models import Submission
from bdfr.exceptions import NotADownloadableLinkError
from bdfr.resource import Resource
from bdfr.site_authenticator import SiteAuthenticator
from bdfr.site_downloaders.fallback_downloaders.fallback_downloader import BaseFallbackDownloader
from bdfr.site_downloaders.fallback_downloaders.fallback_downloader import (
BaseFallbackDownloader,
)
from bdfr.site_downloaders.youtube import Youtube
logger = logging.getLogger(__name__)
@@ -24,7 +26,7 @@ class YtdlpFallback(BaseFallbackDownloader, Youtube):
self.post,
self.post.url,
super()._download_video({}),
super().get_video_attributes(self.post.url)['ext'],
super().get_video_attributes(self.post.url)["ext"],
)
return [out]

View File

@@ -20,27 +20,27 @@ class Gallery(BaseDownloader):
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
try:
image_urls = self._get_links(self.post.gallery_data['items'])
image_urls = self._get_links(self.post.gallery_data["items"])
except (AttributeError, TypeError):
try:
image_urls = self._get_links(self.post.crosspost_parent_list[0]['gallery_data']['items'])
image_urls = self._get_links(self.post.crosspost_parent_list[0]["gallery_data"]["items"])
except (AttributeError, IndexError, TypeError, KeyError):
logger.error(f'Could not find gallery data in submission {self.post.id}')
logger.exception('Gallery image find failure')
raise SiteDownloaderError('No images found in Reddit gallery')
logger.error(f"Could not find gallery data in submission {self.post.id}")
logger.exception("Gallery image find failure")
raise SiteDownloaderError("No images found in Reddit gallery")
if not image_urls:
raise SiteDownloaderError('No images found in Reddit gallery')
raise SiteDownloaderError("No images found in Reddit gallery")
return [Resource(self.post, url, Resource.retry_download(url)) for url in image_urls]
@ staticmethod
@staticmethod
def _get_links(id_dict: list[dict]) -> list[str]:
out = []
for item in id_dict:
image_id = item['media_id']
possible_extensions = ('.jpg', '.png', '.gif', '.gifv', '.jpeg')
image_id = item["media_id"]
possible_extensions = (".jpg", ".png", ".gif", ".gifv", ".jpeg")
for extension in possible_extensions:
test_url = f'https://i.redd.it/{image_id}{extension}'
test_url = f"https://i.redd.it/{image_id}{extension}"
response = requests.head(test_url)
if response.status_code == 200:
out.append(test_url)

View File

@@ -22,21 +22,23 @@ class Gfycat(Redgifs):
@staticmethod
def _get_link(url: str) -> set[str]:
gfycat_id = re.match(r'.*/(.*?)/?$', url).group(1)
url = 'https://gfycat.com/' + gfycat_id
gfycat_id = re.match(r".*/(.*?)/?$", url).group(1)
url = "https://gfycat.com/" + gfycat_id
response = Gfycat.retrieve_url(url)
if re.search(r'(redgifs|gifdeliverynetwork)', response.url):
if re.search(r"(redgifs|gifdeliverynetwork)", response.url):
url = url.lower() # Fixes error with old gfycat/redgifs links
return Redgifs._get_link(url)
soup = BeautifulSoup(response.text, 'html.parser')
content = soup.find('script', attrs={'data-react-helmet': 'true', 'type': 'application/ld+json'})
soup = BeautifulSoup(response.text, "html.parser")
content = soup.find("script", attrs={"data-react-helmet": "true", "type": "application/ld+json"})
try:
out = json.loads(content.contents[0])['video']['contentUrl']
out = json.loads(content.contents[0])["video"]["contentUrl"]
except (IndexError, KeyError, AttributeError) as e:
raise SiteDownloaderError(f'Failed to download Gfycat link {url}: {e}')
raise SiteDownloaderError(f"Failed to download Gfycat link {url}: {e}")
except json.JSONDecodeError as e:
raise SiteDownloaderError(f'Did not receive valid JSON data: {e}')
return {out,}
raise SiteDownloaderError(f"Did not receive valid JSON data: {e}")
return {
out,
}

View File

@@ -14,7 +14,6 @@ from bdfr.site_downloaders.base_downloader import BaseDownloader
class Imgur(BaseDownloader):
def __init__(self, post: Submission):
super().__init__(post)
self.raw_data = {}
@@ -23,63 +22,63 @@ class Imgur(BaseDownloader):
self.raw_data = self._get_data(self.post.url)
out = []
if 'album_images' in self.raw_data:
images = self.raw_data['album_images']
for image in images['images']:
if "album_images" in self.raw_data:
images = self.raw_data["album_images"]
for image in images["images"]:
out.append(self._compute_image_url(image))
else:
out.append(self._compute_image_url(self.raw_data))
return out
def _compute_image_url(self, image: dict) -> Resource:
ext = self._validate_extension(image['ext'])
if image.get('prefer_video', False):
ext = '.mp4'
ext = self._validate_extension(image["ext"])
if image.get("prefer_video", False):
ext = ".mp4"
image_url = 'https://i.imgur.com/' + image['hash'] + ext
image_url = "https://i.imgur.com/" + image["hash"] + ext
return Resource(self.post, image_url, Resource.retry_download(image_url))
@staticmethod
def _get_data(link: str) -> dict:
try:
imgur_id = re.match(r'.*/(.*?)(\..{0,})?$', link).group(1)
gallery = 'a/' if re.search(r'.*/(.*?)(gallery/|a/)', link) else ''
link = f'https://imgur.com/{gallery}{imgur_id}'
imgur_id = re.match(r".*/(.*?)(\..{0,})?$", link).group(1)
gallery = "a/" if re.search(r".*/(.*?)(gallery/|a/)", link) else ""
link = f"https://imgur.com/{gallery}{imgur_id}"
except AttributeError:
raise SiteDownloaderError(f'Could not extract Imgur ID from {link}')
raise SiteDownloaderError(f"Could not extract Imgur ID from {link}")
res = Imgur.retrieve_url(link, cookies={'over18': '1', 'postpagebeta': '0'})
res = Imgur.retrieve_url(link, cookies={"over18": "1", "postpagebeta": "0"})
soup = bs4.BeautifulSoup(res.text, 'html.parser')
scripts = soup.find_all('script', attrs={'type': 'text/javascript'})
scripts = [script.string.replace('\n', '') for script in scripts if script.string]
soup = bs4.BeautifulSoup(res.text, "html.parser")
scripts = soup.find_all("script", attrs={"type": "text/javascript"})
scripts = [script.string.replace("\n", "") for script in scripts if script.string]
script_regex = re.compile(r'\s*\(function\(widgetFactory\)\s*{\s*widgetFactory\.mergeConfig\(\'gallery\'')
script_regex = re.compile(r"\s*\(function\(widgetFactory\)\s*{\s*widgetFactory\.mergeConfig\(\'gallery\'")
chosen_script = list(filter(lambda s: re.search(script_regex, s), scripts))
if len(chosen_script) != 1:
raise SiteDownloaderError(f'Could not read page source from {link}')
raise SiteDownloaderError(f"Could not read page source from {link}")
chosen_script = chosen_script[0]
outer_regex = re.compile(r'widgetFactory\.mergeConfig\(\'gallery\', ({.*})\);')
inner_regex = re.compile(r'image\s*:(.*),\s*group')
outer_regex = re.compile(r"widgetFactory\.mergeConfig\(\'gallery\', ({.*})\);")
inner_regex = re.compile(r"image\s*:(.*),\s*group")
try:
image_dict = re.search(outer_regex, chosen_script).group(1)
image_dict = re.search(inner_regex, image_dict).group(1)
except AttributeError:
raise SiteDownloaderError(f'Could not find image dictionary in page source')
raise SiteDownloaderError(f"Could not find image dictionary in page source")
try:
image_dict = json.loads(image_dict)
except json.JSONDecodeError as e:
raise SiteDownloaderError(f'Could not parse received dict as JSON: {e}')
raise SiteDownloaderError(f"Could not parse received dict as JSON: {e}")
return image_dict
@staticmethod
def _validate_extension(extension_suffix: str) -> str:
extension_suffix = re.sub(r'\?.*', '', extension_suffix)
possible_extensions = ('.jpg', '.png', '.mp4', '.gif')
extension_suffix = re.sub(r"\?.*", "", extension_suffix)
possible_extensions = (".jpg", ".png", ".mp4", ".gif")
selection = [ext for ext in possible_extensions if ext == extension_suffix]
if len(selection) == 1:
return selection[0]

View File

@@ -20,11 +20,11 @@ class PornHub(Youtube):
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
ytdl_options = {
'format': 'best',
'nooverwrites': True,
"format": "best",
"nooverwrites": True,
}
if video_attributes := super().get_video_attributes(self.post.url):
extension = video_attributes['ext']
extension = video_attributes["ext"]
else:
raise SiteDownloaderError()

View File

@@ -2,9 +2,9 @@
import json
import re
import requests
from typing import Optional
import requests
from praw.models import Submission
from bdfr.exceptions import SiteDownloaderError
@@ -24,52 +24,53 @@ class Redgifs(BaseDownloader):
@staticmethod
def _get_link(url: str) -> set[str]:
try:
redgif_id = re.match(r'.*/(.*?)(\..{0,})?$', url).group(1)
redgif_id = re.match(r".*/(.*?)(\..{0,})?$", url).group(1)
except AttributeError:
raise SiteDownloaderError(f'Could not extract Redgifs ID from {url}')
raise SiteDownloaderError(f"Could not extract Redgifs ID from {url}")
auth_token = json.loads(Redgifs.retrieve_url('https://api.redgifs.com/v2/auth/temporary').text)['token']
auth_token = json.loads(Redgifs.retrieve_url("https://api.redgifs.com/v2/auth/temporary").text)["token"]
if not auth_token:
raise SiteDownloaderError('Unable to retrieve Redgifs API token')
raise SiteDownloaderError("Unable to retrieve Redgifs API token")
headers = {
'referer': 'https://www.redgifs.com/',
'origin': 'https://www.redgifs.com',
'content-type': 'application/json',
'Authorization': f'Bearer {auth_token}',
"referer": "https://www.redgifs.com/",
"origin": "https://www.redgifs.com",
"content-type": "application/json",
"Authorization": f"Bearer {auth_token}",
}
content = Redgifs.retrieve_url(f'https://api.redgifs.com/v2/gifs/{redgif_id}', headers=headers)
content = Redgifs.retrieve_url(f"https://api.redgifs.com/v2/gifs/{redgif_id}", headers=headers)
if content is None:
raise SiteDownloaderError('Could not read the page source')
raise SiteDownloaderError("Could not read the page source")
try:
response_json = json.loads(content.text)
except json.JSONDecodeError as e:
raise SiteDownloaderError(f'Received data was not valid JSON: {e}')
raise SiteDownloaderError(f"Received data was not valid JSON: {e}")
out = set()
try:
if response_json['gif']['type'] == 1: # type 1 is a video
if requests.get(response_json['gif']['urls']['hd'], headers=headers).ok:
out.add(response_json['gif']['urls']['hd'])
if response_json["gif"]["type"] == 1: # type 1 is a video
if requests.get(response_json["gif"]["urls"]["hd"], headers=headers).ok:
out.add(response_json["gif"]["urls"]["hd"])
else:
out.add(response_json['gif']['urls']['sd'])
elif response_json['gif']['type'] == 2: # type 2 is an image
if response_json['gif']['gallery']:
out.add(response_json["gif"]["urls"]["sd"])
elif response_json["gif"]["type"] == 2: # type 2 is an image
if response_json["gif"]["gallery"]:
content = Redgifs.retrieve_url(
f'https://api.redgifs.com/v2/gallery/{response_json["gif"]["gallery"]}')
f'https://api.redgifs.com/v2/gallery/{response_json["gif"]["gallery"]}'
)
response_json = json.loads(content.text)
out = {p['urls']['hd'] for p in response_json['gifs']}
out = {p["urls"]["hd"] for p in response_json["gifs"]}
else:
out.add(response_json['gif']['urls']['hd'])
out.add(response_json["gif"]["urls"]["hd"])
else:
raise KeyError
except (KeyError, AttributeError):
raise SiteDownloaderError('Failed to find JSON data in page')
raise SiteDownloaderError("Failed to find JSON data in page")
# Update subdomain if old one is returned
out = {re.sub('thumbs2', 'thumbs3', link) for link in out}
out = {re.sub('thumbs3', 'thumbs4', link) for link in out}
out = {re.sub("thumbs2", "thumbs3", link) for link in out}
out = {re.sub("thumbs3", "thumbs4", link) for link in out}
return out

View File

@@ -17,27 +17,29 @@ class SelfPost(BaseDownloader):
super().__init__(post)
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
out = Resource(self.post, self.post.url, lambda: None, '.txt')
out.content = self.export_to_string().encode('utf-8')
out = Resource(self.post, self.post.url, lambda: None, ".txt")
out.content = self.export_to_string().encode("utf-8")
out.create_hash()
return [out]
def export_to_string(self) -> str:
"""Self posts are formatted here"""
content = ("## ["
+ self.post.fullname
+ "]("
+ self.post.url
+ ")\n"
+ self.post.selftext
+ "\n\n---\n\n"
+ "submitted to [r/"
+ self.post.subreddit.title
+ "](https://www.reddit.com/r/"
+ self.post.subreddit.title
+ ") by [u/"
+ (self.post.author.name if self.post.author else "DELETED")
+ "](https://www.reddit.com/user/"
+ (self.post.author.name if self.post.author else "DELETED")
+ ")")
content = (
"## ["
+ self.post.fullname
+ "]("
+ self.post.url
+ ")\n"
+ self.post.selftext
+ "\n\n---\n\n"
+ "submitted to [r/"
+ self.post.subreddit.title
+ "](https://www.reddit.com/r/"
+ self.post.subreddit.title
+ ") by [u/"
+ (self.post.author.name if self.post.author else "DELETED")
+ "](https://www.reddit.com/user/"
+ (self.post.author.name if self.post.author else "DELETED")
+ ")"
)
return content

View File

@@ -25,30 +25,30 @@ class Vidble(BaseDownloader):
try:
res = self.get_links(self.post.url)
except AttributeError:
raise SiteDownloaderError(f'Could not read page at {self.post.url}')
raise SiteDownloaderError(f"Could not read page at {self.post.url}")
if not res:
raise SiteDownloaderError(rf'No resources found at {self.post.url}')
raise SiteDownloaderError(rf"No resources found at {self.post.url}")
res = [Resource(self.post, r, Resource.retry_download(r)) for r in res]
return res
@staticmethod
def get_links(url: str) -> set[str]:
if not re.search(r'vidble.com/(show/|album/|watch\?v)', url):
url = re.sub(r'/(\w*?)$', r'/show/\1', url)
if not re.search(r"vidble.com/(show/|album/|watch\?v)", url):
url = re.sub(r"/(\w*?)$", r"/show/\1", url)
page = requests.get(url)
soup = bs4.BeautifulSoup(page.text, 'html.parser')
content_div = soup.find('div', attrs={'id': 'ContentPlaceHolder1_divContent'})
images = content_div.find_all('img')
images = [i.get('src') for i in images]
videos = content_div.find_all('source', attrs={'type': 'video/mp4'})
videos = [v.get('src') for v in videos]
soup = bs4.BeautifulSoup(page.text, "html.parser")
content_div = soup.find("div", attrs={"id": "ContentPlaceHolder1_divContent"})
images = content_div.find_all("img")
images = [i.get("src") for i in images]
videos = content_div.find_all("source", attrs={"type": "video/mp4"})
videos = [v.get("src") for v in videos]
resources = filter(None, itertools.chain(images, videos))
resources = ['https://www.vidble.com' + r for r in resources]
resources = ["https://www.vidble.com" + r for r in resources]
resources = [Vidble.change_med_url(r) for r in resources]
return set(resources)
@staticmethod
def change_med_url(url: str) -> str:
out = re.sub(r'_med(\..{3,4})$', r'\1', url)
out = re.sub(r"_med(\..{3,4})$", r"\1", url)
return out

View File

@@ -22,18 +22,18 @@ class VReddit(Youtube):
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
ytdl_options = {
'playlistend': 1,
'nooverwrites': True,
"playlistend": 1,
"nooverwrites": True,
}
download_function = self._download_video(ytdl_options)
extension = self.get_video_attributes(self.post.url)['ext']
extension = self.get_video_attributes(self.post.url)["ext"]
res = Resource(self.post, self.post.url, download_function, extension)
return [res]
@staticmethod
def get_video_attributes(url: str) -> dict:
result = VReddit.get_video_data(url)
if 'ext' in result:
if "ext" in result:
return result
else:
try:
@@ -41,4 +41,4 @@ class VReddit(Youtube):
return result
except Exception as e:
logger.exception(e)
raise NotADownloadableLinkError(f'Video info extraction failed for {url}')
raise NotADownloadableLinkError(f"Video info extraction failed for {url}")

View File

@@ -22,57 +22,62 @@ class Youtube(BaseDownloader):
def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
ytdl_options = {
'format': 'best',
'playlistend': 1,
'nooverwrites': True,
"format": "best",
"playlistend": 1,
"nooverwrites": True,
}
download_function = self._download_video(ytdl_options)
extension = self.get_video_attributes(self.post.url)['ext']
extension = self.get_video_attributes(self.post.url)["ext"]
res = Resource(self.post, self.post.url, download_function, extension)
return [res]
def _download_video(self, ytdl_options: dict) -> Callable:
yt_logger = logging.getLogger('youtube-dl')
yt_logger = logging.getLogger("youtube-dl")
yt_logger.setLevel(logging.CRITICAL)
ytdl_options['quiet'] = True
ytdl_options['logger'] = yt_logger
ytdl_options["quiet"] = True
ytdl_options["logger"] = yt_logger
def download(_: dict) -> bytes:
with tempfile.TemporaryDirectory() as temp_dir:
download_path = Path(temp_dir).resolve()
ytdl_options['outtmpl'] = str(download_path) + '/' + 'test.%(ext)s'
ytdl_options["outtmpl"] = str(download_path) + "/" + "test.%(ext)s"
try:
with yt_dlp.YoutubeDL(ytdl_options) as ydl:
ydl.download([self.post.url])
except yt_dlp.DownloadError as e:
raise SiteDownloaderError(f'Youtube download failed: {e}')
raise SiteDownloaderError(f"Youtube download failed: {e}")
downloaded_files = list(download_path.iterdir())
if downloaded_files:
downloaded_file = downloaded_files[0]
else:
raise NotADownloadableLinkError(f"No media exists in the URL {self.post.url}")
with downloaded_file.open('rb') as file:
with downloaded_file.open("rb") as file:
content = file.read()
return content
return download
@staticmethod
def get_video_data(url: str) -> dict:
yt_logger = logging.getLogger('youtube-dl')
yt_logger = logging.getLogger("youtube-dl")
yt_logger.setLevel(logging.CRITICAL)
with yt_dlp.YoutubeDL({'logger': yt_logger, }) as ydl:
with yt_dlp.YoutubeDL(
{
"logger": yt_logger,
}
) as ydl:
try:
result = ydl.extract_info(url, download=False)
except Exception as e:
logger.exception(e)
raise NotADownloadableLinkError(f'Video info extraction failed for {url}')
raise NotADownloadableLinkError(f"Video info extraction failed for {url}")
return result
@staticmethod
def get_video_attributes(url: str) -> dict:
result = Youtube.get_video_data(url)
if 'ext' in result:
if "ext" in result:
return result
else:
raise NotADownloadableLinkError(f'Video info extraction failed for {url}')
raise NotADownloadableLinkError(f"Video info extraction failed for {url}")