From 36f516e3f0433bdebe7e86e90814d369dba7ae57 Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Mon, 8 Mar 2021 12:35:34 +1000 Subject: [PATCH] Re-implement OAuth2 --- bulkredditdownloader/downloader.py | 37 ++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/bulkredditdownloader/downloader.py b/bulkredditdownloader/downloader.py index 196dae4..d32572d 100644 --- a/bulkredditdownloader/downloader.py +++ b/bulkredditdownloader/downloader.py @@ -4,6 +4,7 @@ import argparse import configparser import logging +import re import socket from datetime import datetime from enum import Enum, auto @@ -18,6 +19,7 @@ import prawcore import bulkredditdownloader.exceptions as errors from bulkredditdownloader.download_filter import DownloadFilter from bulkredditdownloader.file_name_formatter import FileNameFormatter +from bulkredditdownloader.oauth2 import OAuth2Authenticator, OAuth2TokenManager from bulkredditdownloader.site_authenticator import SiteAuthenticator from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory @@ -44,7 +46,7 @@ class RedditTypes: class RedditDownloader: def __init__(self, args: argparse.Namespace): self.args = args - self.config_directories = appdirs.AppDirs('bulk_reddit_downloader') + self.config_directories = appdirs.AppDirs('bulk_reddit_downloader', 'BDFR') self.run_time = datetime.now().isoformat() self._setup_internal_objects() @@ -55,19 +57,31 @@ class RedditDownloader: self.time_filter = self._create_time_filter() self.sort_filter = self._create_sort_filter() self.file_name_formatter = self._create_file_name_formatter() - self.authenticator = self._create_authenticator() + # self.authenticator = self._create_authenticator() + self._resolve_user_name() self._determine_directories() self._create_file_logger() self.master_hash_list = [] self._load_config() - if self.cfg_parser.has_option('DEFAULT', 'reddit_token'): - # TODO: implement OAuth2 authentication + + self._create_reddit_instance() + + def _create_reddit_instance(self): + if self.args.authenticate: + if not self.cfg_parser.has_option('DEFAULT', 'user_token'): + scopes = self.cfg_parser.get('DEFAULT', 'scopes') + scopes = OAuth2Authenticator.split_scopes(scopes) + oauth2_authenticator = OAuth2Authenticator(scopes) + token = oauth2_authenticator.retrieve_new_token() + self.cfg_parser['DEFAULT']['user_token'] = token + token_manager = OAuth2TokenManager(self.cfg_parser) + self.authenticated = True self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'), client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), user_agent=socket.gethostname(), - ) + token_manager=token_manager) else: self.authenticated = False self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'), @@ -92,10 +106,15 @@ class RedditDownloader: def _load_config(self): self.cfg_parser = configparser.ConfigParser() - if self.args.use_local_config and Path('./config.cfg').exists(): - self.cfg_parser.read(Path('./config.cfg')) - else: - self.cfg_parser.read(Path('./default_config.cfg').resolve()) + possible_paths = [Path('./config.cfg'), + Path(self.config_directory, 'config.cfg'), + Path('./default_config.cfg'), + ] + for path in possible_paths: + if path.resolve().expanduser().exists(): + self.config_location = path + break + self.cfg_parser.read(self.config_location) def _create_file_logger(self): main_logger = logging.getLogger()