Format according to the black standard

This commit is contained in:
Serene-Arc
2022-12-03 15:11:17 +10:00
parent 96cd7d7147
commit 0873a4a2b2
60 changed files with 2160 additions and 1790 deletions

View File

@@ -41,18 +41,18 @@ class RedditTypes:
TOP = auto()
class TimeType(Enum):
ALL = 'all'
DAY = 'day'
HOUR = 'hour'
MONTH = 'month'
WEEK = 'week'
YEAR = 'year'
ALL = "all"
DAY = "day"
HOUR = "hour"
MONTH = "month"
WEEK = "week"
YEAR = "year"
class RedditConnector(metaclass=ABCMeta):
def __init__(self, args: Configuration):
self.args = args
self.config_directories = appdirs.AppDirs('bdfr', 'BDFR')
self.config_directories = appdirs.AppDirs("bdfr", "BDFR")
self.run_time = datetime.now().isoformat()
self._setup_internal_objects()
@@ -68,13 +68,13 @@ class RedditConnector(metaclass=ABCMeta):
self.parse_disabled_modules()
self.download_filter = self.create_download_filter()
logger.log(9, 'Created download filter')
logger.log(9, "Created download filter")
self.time_filter = self.create_time_filter()
logger.log(9, 'Created time filter')
logger.log(9, "Created time filter")
self.sort_filter = self.create_sort_filter()
logger.log(9, 'Created sort filter')
logger.log(9, "Created sort filter")
self.file_name_formatter = self.create_file_name_formatter()
logger.log(9, 'Create file name formatter')
logger.log(9, "Create file name formatter")
self.create_reddit_instance()
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
@@ -88,7 +88,7 @@ class RedditConnector(metaclass=ABCMeta):
self.master_hash_list = {}
self.authenticator = self.create_authenticator()
logger.log(9, 'Created site authenticator')
logger.log(9, "Created site authenticator")
self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit)
self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit}
@@ -96,18 +96,18 @@ class RedditConnector(metaclass=ABCMeta):
def read_config(self):
"""Read any cfg values that need to be processed"""
if self.args.max_wait_time is None:
self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time', fallback=120)
logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds')
self.args.max_wait_time = self.cfg_parser.getint("DEFAULT", "max_wait_time", fallback=120)
logger.debug(f"Setting maximum download wait time to {self.args.max_wait_time} seconds")
if self.args.time_format is None:
option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO')
if re.match(r'^[\s\'\"]*$', option):
option = 'ISO'
logger.debug(f'Setting datetime format string to {option}')
option = self.cfg_parser.get("DEFAULT", "time_format", fallback="ISO")
if re.match(r"^[\s\'\"]*$", option):
option = "ISO"
logger.debug(f"Setting datetime format string to {option}")
self.args.time_format = option
if not self.args.disable_module:
self.args.disable_module = [self.cfg_parser.get('DEFAULT', 'disabled_modules', fallback='')]
self.args.disable_module = [self.cfg_parser.get("DEFAULT", "disabled_modules", fallback="")]
# Update config on disk
with open(self.config_location, 'w') as file:
with open(self.config_location, "w") as file:
self.cfg_parser.write(file)
def parse_disabled_modules(self):
@@ -119,48 +119,48 @@ class RedditConnector(metaclass=ABCMeta):
def create_reddit_instance(self):
if self.args.authenticate:
logger.debug('Using authenticated Reddit instance')
if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
logger.log(9, 'Commencing OAuth2 authentication')
scopes = self.cfg_parser.get('DEFAULT', 'scopes', fallback='identity, history, read, save')
logger.debug("Using authenticated Reddit instance")
if not self.cfg_parser.has_option("DEFAULT", "user_token"):
logger.log(9, "Commencing OAuth2 authentication")
scopes = self.cfg_parser.get("DEFAULT", "scopes", fallback="identity, history, read, save")
scopes = OAuth2Authenticator.split_scopes(scopes)
oauth2_authenticator = OAuth2Authenticator(
scopes,
self.cfg_parser.get('DEFAULT', 'client_id'),
self.cfg_parser.get('DEFAULT', 'client_secret'),
self.cfg_parser.get("DEFAULT", "client_id"),
self.cfg_parser.get("DEFAULT", "client_secret"),
)
token = oauth2_authenticator.retrieve_new_token()
self.cfg_parser['DEFAULT']['user_token'] = token
with open(self.config_location, 'w') as file:
self.cfg_parser["DEFAULT"]["user_token"] = token
with open(self.config_location, "w") as file:
self.cfg_parser.write(file, True)
token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location)
self.authenticated = True
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
client_id=self.cfg_parser.get("DEFAULT", "client_id"),
client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
user_agent=socket.gethostname(),
token_manager=token_manager,
)
else:
logger.debug('Using unauthenticated Reddit instance')
logger.debug("Using unauthenticated Reddit instance")
self.authenticated = False
self.reddit_instance = praw.Reddit(
client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
client_id=self.cfg_parser.get("DEFAULT", "client_id"),
client_secret=self.cfg_parser.get("DEFAULT", "client_secret"),
user_agent=socket.gethostname(),
)
def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
master_list = []
master_list.extend(self.get_subreddits())
logger.log(9, 'Retrieved subreddits')
logger.log(9, "Retrieved subreddits")
master_list.extend(self.get_multireddits())
logger.log(9, 'Retrieved multireddits')
logger.log(9, "Retrieved multireddits")
master_list.extend(self.get_user_data())
logger.log(9, 'Retrieved user data')
logger.log(9, "Retrieved user data")
master_list.extend(self.get_submissions_from_link())
logger.log(9, 'Retrieved submissions for given links')
logger.log(9, "Retrieved submissions for given links")
return master_list
def determine_directories(self):
@@ -178,37 +178,37 @@ class RedditConnector(metaclass=ABCMeta):
self.config_location = cfg_path
return
possible_paths = [
Path('./config.cfg'),
Path('./default_config.cfg'),
Path(self.config_directory, 'config.cfg'),
Path(self.config_directory, 'default_config.cfg'),
Path("./config.cfg"),
Path("./default_config.cfg"),
Path(self.config_directory, "config.cfg"),
Path(self.config_directory, "default_config.cfg"),
]
self.config_location = None
for path in possible_paths:
if path.resolve().expanduser().exists():
self.config_location = path
logger.debug(f'Loading configuration from {path}')
logger.debug(f"Loading configuration from {path}")
break
if not self.config_location:
with importlib.resources.path('bdfr', 'default_config.cfg') as path:
with importlib.resources.path("bdfr", "default_config.cfg") as path:
self.config_location = path
shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg'))
shutil.copy(self.config_location, Path(self.config_directory, "default_config.cfg"))
if not self.config_location:
raise errors.BulkDownloaderException('Could not find a configuration file to load')
raise errors.BulkDownloaderException("Could not find a configuration file to load")
self.cfg_parser.read(self.config_location)
def create_file_logger(self):
main_logger = logging.getLogger()
if self.args.log is None:
log_path = Path(self.config_directory, 'log_output.txt')
log_path = Path(self.config_directory, "log_output.txt")
else:
log_path = Path(self.args.log).resolve().expanduser()
if not log_path.parent.exists():
raise errors.BulkDownloaderException(f'Designated location for logfile does not exist')
backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3)
raise errors.BulkDownloaderException(f"Designated location for logfile does not exist")
backup_count = self.cfg_parser.getint("DEFAULT", "backup_log_count", fallback=3)
file_handler = logging.handlers.RotatingFileHandler(
log_path,
mode='a',
mode="a",
backupCount=backup_count,
)
if log_path.exists():
@@ -216,10 +216,11 @@ class RedditConnector(metaclass=ABCMeta):
file_handler.doRollover()
except PermissionError:
logger.critical(
'Cannot rollover logfile, make sure this is the only '
'BDFR process or specify alternate logfile location')
"Cannot rollover logfile, make sure this is the only "
"BDFR process or specify alternate logfile location"
)
raise
formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s')
formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s")
file_handler.setFormatter(formatter)
file_handler.setLevel(0)
@@ -227,16 +228,16 @@ class RedditConnector(metaclass=ABCMeta):
@staticmethod
def sanitise_subreddit_name(subreddit: str) -> str:
pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$')
pattern = re.compile(r"^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$")
match = re.match(pattern, subreddit)
if not match:
raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}')
raise errors.BulkDownloaderException(f"Could not find subreddit name in string {subreddit}")
return match.group(1)
@staticmethod
def split_args_input(entries: list[str]) -> set[str]:
all_entries = []
split_pattern = re.compile(r'[,;]\s?')
split_pattern = re.compile(r"[,;]\s?")
for entry in entries:
results = re.split(split_pattern, entry)
all_entries.extend([RedditConnector.sanitise_subreddit_name(name) for name in results])
@@ -251,13 +252,13 @@ class RedditConnector(metaclass=ABCMeta):
subscribed_subreddits = list(self.reddit_instance.user.subreddits(limit=None))
subscribed_subreddits = {s.display_name for s in subscribed_subreddits}
except prawcore.InsufficientScope:
logger.error('BDFR has insufficient scope to access subreddit lists')
logger.error("BDFR has insufficient scope to access subreddit lists")
else:
logger.error('Cannot find subscribed subreddits without an authenticated instance')
logger.error("Cannot find subscribed subreddits without an authenticated instance")
if self.args.subreddit or subscribed_subreddits:
for reddit in self.split_args_input(self.args.subreddit) | subscribed_subreddits:
if reddit == 'friends' and self.authenticated is False:
logger.error('Cannot read friends subreddit without an authenticated instance')
if reddit == "friends" and self.authenticated is False:
logger.error("Cannot read friends subreddit without an authenticated instance")
continue
try:
reddit = self.reddit_instance.subreddit(reddit)
@@ -267,26 +268,29 @@ class RedditConnector(metaclass=ABCMeta):
logger.error(e)
continue
if self.args.search:
out.append(reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
time_filter=self.time_filter.value,
))
out.append(
reddit.search(
self.args.search,
sort=self.sort_filter.name.lower(),
limit=self.args.limit,
time_filter=self.time_filter.value,
)
)
logger.debug(
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"'
)
else:
out.append(self.create_filtered_listing_generator(reddit))
logger.debug(f'Added submissions from subreddit {reddit}')
logger.debug(f"Added submissions from subreddit {reddit}")
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
logger.error(f"Failed to get submissions for subreddit {reddit}: {e}")
return out
def resolve_user_name(self, in_name: str) -> str:
if in_name == 'me':
if in_name == "me":
if self.authenticated:
resolved_name = self.reddit_instance.user.me().name
logger.log(9, f'Resolved user to {resolved_name}')
logger.log(9, f"Resolved user to {resolved_name}")
return resolved_name
else:
logger.warning('To use "me" as a user, an authenticated Reddit instance must be used')
@@ -318,7 +322,7 @@ class RedditConnector(metaclass=ABCMeta):
def get_multireddits(self) -> list[Iterator]:
if self.args.multireddit:
if len(self.args.user) != 1:
logger.error(f'Only 1 user can be supplied when retrieving from multireddits')
logger.error(f"Only 1 user can be supplied when retrieving from multireddits")
return []
out = []
for multi in self.split_args_input(self.args.multireddit):
@@ -327,9 +331,9 @@ class RedditConnector(metaclass=ABCMeta):
if not multi.subreddits:
raise errors.BulkDownloaderException
out.append(self.create_filtered_listing_generator(multi))
logger.debug(f'Added submissions from multireddit {multi}')
logger.debug(f"Added submissions from multireddit {multi}")
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
logger.error(f"Failed to get submissions for multireddit {multi}: {e}")
return out
else:
return []
@@ -344,7 +348,7 @@ class RedditConnector(metaclass=ABCMeta):
def get_user_data(self) -> list[Iterator]:
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
if not self.args.user:
logger.warning('At least one user must be supplied to download user data')
logger.warning("At least one user must be supplied to download user data")
return []
generators = []
for user in self.args.user:
@@ -354,18 +358,20 @@ class RedditConnector(metaclass=ABCMeta):
logger.error(e)
continue
if self.args.submitted:
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
generators.append(self.create_filtered_listing_generator(
self.reddit_instance.redditor(user).submissions,
))
logger.debug(f"Retrieving submitted posts of user {self.args.user}")
generators.append(
self.create_filtered_listing_generator(
self.reddit_instance.redditor(user).submissions,
)
)
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
logger.warning('Accessing user lists requires authentication')
logger.warning("Accessing user lists requires authentication")
else:
if self.args.upvoted:
logger.debug(f'Retrieving upvoted posts of user {self.args.user}')
logger.debug(f"Retrieving upvoted posts of user {self.args.user}")
generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit))
if self.args.saved:
logger.debug(f'Retrieving saved posts of user {self.args.user}')
logger.debug(f"Retrieving saved posts of user {self.args.user}")
generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit))
return generators
else:
@@ -377,10 +383,10 @@ class RedditConnector(metaclass=ABCMeta):
if user.id:
return
except prawcore.exceptions.NotFound:
raise errors.BulkDownloaderException(f'Could not find user {name}')
raise errors.BulkDownloaderException(f"Could not find user {name}")
except AttributeError:
if hasattr(user, 'is_suspended'):
raise errors.BulkDownloaderException(f'User {name} is banned')
if hasattr(user, "is_suspended"):
raise errors.BulkDownloaderException(f"User {name} is banned")
def create_file_name_formatter(self) -> FileNameFormatter:
return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format)
@@ -409,7 +415,7 @@ class RedditConnector(metaclass=ABCMeta):
@staticmethod
def check_subreddit_status(subreddit: praw.models.Subreddit):
if subreddit.display_name in ('all', 'friends'):
if subreddit.display_name in ("all", "friends"):
return
try:
assert subreddit.id
@@ -418,7 +424,7 @@ class RedditConnector(metaclass=ABCMeta):
except prawcore.Redirect:
raise errors.BulkDownloaderException(f"Source {subreddit.display_name} does not exist")
except prawcore.Forbidden:
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')
raise errors.BulkDownloaderException(f"Source {subreddit.display_name} is private and cannot be scraped")
@staticmethod
def read_id_files(file_locations: list[str]) -> set[str]:
@@ -426,9 +432,9 @@ class RedditConnector(metaclass=ABCMeta):
for id_file in file_locations:
id_file = Path(id_file).resolve().expanduser()
if not id_file.exists():
logger.warning(f'ID file at {id_file} does not exist')
logger.warning(f"ID file at {id_file} does not exist")
continue
with id_file.open('r') as file:
with id_file.open("r") as file:
for line in file:
out.append(line.strip())
return set(out)