Move args to instance variable
This commit is contained in:
@@ -40,21 +40,22 @@ class RedditTypes:
|
|||||||
|
|
||||||
class RedditDownloader:
|
class RedditDownloader:
|
||||||
def __init__(self, args: argparse.Namespace):
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
self.config_directories = appdirs.AppDirs('bulk_reddit_downloader')
|
self.config_directories = appdirs.AppDirs('bulk_reddit_downloader')
|
||||||
self.run_time = datetime.now().isoformat()
|
self.run_time = datetime.now().isoformat()
|
||||||
self._setup_internal_objects(args)
|
self._setup_internal_objects()
|
||||||
|
|
||||||
self.reddit_lists = self._retrieve_reddit_lists(args)
|
self.reddit_lists = self._retrieve_reddit_lists()
|
||||||
|
|
||||||
def _setup_internal_objects(self, args: argparse.Namespace):
|
def _setup_internal_objects(self):
|
||||||
self.download_filter = RedditDownloader._create_download_filter(args)
|
self.download_filter = self._create_download_filter()
|
||||||
self.time_filter = RedditDownloader._create_time_filter(args)
|
self.time_filter = self._create_time_filter()
|
||||||
self.sort_filter = RedditDownloader._create_sort_filter(args)
|
self.sort_filter = self._create_sort_filter()
|
||||||
self.file_name_formatter = RedditDownloader._create_file_name_formatter(args)
|
self.file_name_formatter = self._create_file_name_formatter()
|
||||||
self._determine_directories(args)
|
self._determine_directories()
|
||||||
self._create_file_logger()
|
self._create_file_logger()
|
||||||
self.master_hash_list = []
|
self.master_hash_list = []
|
||||||
self._load_config(args)
|
self._load_config()
|
||||||
if self.cfg_parser.has_option('DEFAULT', 'username') and self.cfg_parser.has_option('DEFAULT', 'password'):
|
if self.cfg_parser.has_option('DEFAULT', 'username') and self.cfg_parser.has_option('DEFAULT', 'password'):
|
||||||
self.authenticated = True
|
self.authenticated = True
|
||||||
|
|
||||||
@@ -69,21 +70,21 @@ class RedditDownloader:
|
|||||||
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
|
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
|
||||||
user_agent=socket.gethostname())
|
user_agent=socket.gethostname())
|
||||||
|
|
||||||
def _retrieve_reddit_lists(self, args: argparse.Namespace) -> list[praw.models.ListingGenerator]:
|
def _retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]:
|
||||||
master_list = []
|
master_list = []
|
||||||
master_list.extend(self._get_subreddits(args))
|
master_list.extend(self._get_subreddits())
|
||||||
master_list.extend(self._get_multireddits(args))
|
master_list.extend(self._get_multireddits())
|
||||||
master_list.extend(self._get_user_data(args))
|
master_list.extend(self._get_user_data())
|
||||||
return master_list
|
return master_list
|
||||||
|
|
||||||
def _determine_directories(self, args: argparse.Namespace):
|
def _determine_directories(self):
|
||||||
self.download_directory = Path(args.directory)
|
self.download_directory = Path(self.args.directory)
|
||||||
self.logfile_directory = self.download_directory / 'LOG_FILES'
|
self.logfile_directory = self.download_directory / 'LOG_FILES'
|
||||||
self.config_directory = self.config_directories.user_config_dir
|
self.config_directory = self.config_directories.user_config_dir
|
||||||
|
|
||||||
def _load_config(self, args: argparse.Namespace):
|
def _load_config(self):
|
||||||
self.cfg_parser = configparser.ConfigParser()
|
self.cfg_parser = configparser.ConfigParser()
|
||||||
if args.use_local_config and Path('./config.cfg').exists():
|
if self.args.use_local_config and Path('./config.cfg').exists():
|
||||||
self.cfg_parser.read(Path('./config.cfg'))
|
self.cfg_parser.read(Path('./config.cfg'))
|
||||||
else:
|
else:
|
||||||
self.cfg_parser.read(Path('./default_config.cfg').resolve())
|
self.cfg_parser.read(Path('./default_config.cfg').resolve())
|
||||||
@@ -97,11 +98,11 @@ class RedditDownloader:
|
|||||||
|
|
||||||
main_logger.addHandler(file_handler)
|
main_logger.addHandler(file_handler)
|
||||||
|
|
||||||
def _get_subreddits(self, args: argparse.Namespace) -> list[praw.models.ListingGenerator]:
|
def _get_subreddits(self) -> list[praw.models.ListingGenerator]:
|
||||||
if args.subreddit:
|
if self.args.subreddit:
|
||||||
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in args.subreddit]
|
subreddits = [self.reddit_instance.subreddit(chosen_subreddit) for chosen_subreddit in self.args.subreddit]
|
||||||
if args.search:
|
if self.args.search:
|
||||||
return [reddit.search(args.search, sort=self.sort_filter.name.lower()) for reddit in subreddits]
|
return [reddit.search(self.args.search, sort=self.sort_filter.name.lower()) for reddit in subreddits]
|
||||||
else:
|
else:
|
||||||
if self.sort_filter is RedditTypes.SortType.NEW:
|
if self.sort_filter is RedditTypes.SortType.NEW:
|
||||||
sort_function = praw.models.Subreddit.new
|
sort_function = praw.models.Subreddit.new
|
||||||
@@ -115,25 +116,25 @@ class RedditDownloader:
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _get_multireddits(self, args: argparse.Namespace) -> list[praw.models.ListingGenerator]:
|
def _get_multireddits(self) -> list[praw.models.ListingGenerator]:
|
||||||
if args.multireddit:
|
if self.args.multireddit:
|
||||||
if self.authenticated:
|
if self.authenticated:
|
||||||
return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in args.multireddit]
|
return [self.reddit_instance.multireddit(m_reddit_choice) for m_reddit_choice in self.args.multireddit]
|
||||||
else:
|
else:
|
||||||
raise RedditAuthenticationError('Accessing multireddits requires authentication')
|
raise RedditAuthenticationError('Accessing multireddits requires authentication')
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _get_user_data(self, args: argparse.Namespace) -> list[praw.models.ListingGenerator]:
|
def _get_user_data(self) -> list[praw.models.ListingGenerator]:
|
||||||
if any((args.upvoted, args.submitted, args.saved)):
|
if any((self.args.upvoted, self.args.submitted, self.args.saved)):
|
||||||
if self.authenticated:
|
if self.authenticated:
|
||||||
generators = []
|
generators = []
|
||||||
if args.upvoted:
|
if self.args.upvoted:
|
||||||
generators.append(self.reddit_instance.redditor(args.user).upvoted)
|
generators.append(self.reddit_instance.redditor(self.args.user).upvoted)
|
||||||
if args.submitted:
|
if self.args.submitted:
|
||||||
generators.append(self.reddit_instance.redditor(args.user).submissions)
|
generators.append(self.reddit_instance.redditor(self.args.user).submissions)
|
||||||
if args.saved:
|
if self.args.saved:
|
||||||
generators.append(self.reddit_instance.redditor(args.user).saved)
|
generators.append(self.reddit_instance.redditor(self.args.user).saved)
|
||||||
|
|
||||||
return generators
|
return generators
|
||||||
else:
|
else:
|
||||||
@@ -141,34 +142,30 @@ class RedditDownloader:
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
def _create_file_name_formatter(self) -> FileNameFormatter:
|
||||||
def _create_file_name_formatter(args: argparse.Namespace) -> FileNameFormatter:
|
return FileNameFormatter(self.args.set_filename, self.args.set_folderpath)
|
||||||
return FileNameFormatter(args.set_filename, args.set_folderpath)
|
|
||||||
|
|
||||||
@staticmethod
|
def _create_time_filter(self) -> RedditTypes.TimeType:
|
||||||
def _create_time_filter(args: argparse.Namespace) -> RedditTypes.TimeType:
|
|
||||||
try:
|
try:
|
||||||
return RedditTypes.TimeType[args.sort.upper()]
|
return RedditTypes.TimeType[self.args.sort.upper()]
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
return RedditTypes.TimeType.ALL
|
return RedditTypes.TimeType.ALL
|
||||||
|
|
||||||
@staticmethod
|
def _create_sort_filter(self) -> RedditTypes.SortType:
|
||||||
def _create_sort_filter(args: argparse.Namespace) -> RedditTypes.SortType:
|
|
||||||
try:
|
try:
|
||||||
return RedditTypes.SortType[args.time.upper()]
|
return RedditTypes.SortType[self.args.time.upper()]
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
return RedditTypes.SortType.HOT
|
return RedditTypes.SortType.HOT
|
||||||
|
|
||||||
@staticmethod
|
def _create_download_filter(self) -> DownloadFilter:
|
||||||
def _create_download_filter(args: argparse.Namespace) -> DownloadFilter:
|
|
||||||
formats = {
|
formats = {
|
||||||
"videos": [".mp4", ".webm"],
|
"videos": [".mp4", ".webm"],
|
||||||
"images": [".jpg", ".jpeg", ".png", ".bmp"],
|
"images": [".jpg", ".jpeg", ".png", ".bmp"],
|
||||||
"gifs": [".gif"],
|
"gifs": [".gif"],
|
||||||
"self": []
|
"self": []
|
||||||
}
|
}
|
||||||
excluded_extensions = [extension for ext_type in args.skip for extension in formats.get(ext_type, ())]
|
excluded_extensions = [extension for ext_type in self.args.skip for extension in formats.get(ext_type, ())]
|
||||||
return DownloadFilter(excluded_extensions, args.skip_domain)
|
return DownloadFilter(excluded_extensions, self.args.skip_domain)
|
||||||
|
|
||||||
def download(self):
|
def download(self):
|
||||||
for generator in self.reddit_lists:
|
for generator in self.reddit_lists:
|
||||||
|
|||||||
Reference in New Issue
Block a user