Re-implement OAuth2

This commit is contained in:
Serene-Arc
2021-03-08 12:35:34 +10:00
committed by Ali Parlakci
parent 95876b3400
commit 36f516e3f0

View File

@@ -4,6 +4,7 @@
import argparse import argparse
import configparser import configparser
import logging import logging
import re
import socket import socket
from datetime import datetime from datetime import datetime
from enum import Enum, auto from enum import Enum, auto
@@ -18,6 +19,7 @@ import prawcore
import bulkredditdownloader.exceptions as errors import bulkredditdownloader.exceptions as errors
from bulkredditdownloader.download_filter import DownloadFilter from bulkredditdownloader.download_filter import DownloadFilter
from bulkredditdownloader.file_name_formatter import FileNameFormatter from bulkredditdownloader.file_name_formatter import FileNameFormatter
from bulkredditdownloader.oauth2 import OAuth2Authenticator, OAuth2TokenManager
from bulkredditdownloader.site_authenticator import SiteAuthenticator from bulkredditdownloader.site_authenticator import SiteAuthenticator
from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory from bulkredditdownloader.site_downloaders.download_factory import DownloadFactory
@@ -44,7 +46,7 @@ class RedditTypes:
class RedditDownloader: class RedditDownloader:
def __init__(self, args: argparse.Namespace): def __init__(self, args: argparse.Namespace):
self.args = args self.args = args
self.config_directories = appdirs.AppDirs('bulk_reddit_downloader') self.config_directories = appdirs.AppDirs('bulk_reddit_downloader', 'BDFR')
self.run_time = datetime.now().isoformat() self.run_time = datetime.now().isoformat()
self._setup_internal_objects() self._setup_internal_objects()
@@ -55,19 +57,31 @@ class RedditDownloader:
self.time_filter = self._create_time_filter() self.time_filter = self._create_time_filter()
self.sort_filter = self._create_sort_filter() self.sort_filter = self._create_sort_filter()
self.file_name_formatter = self._create_file_name_formatter() self.file_name_formatter = self._create_file_name_formatter()
self.authenticator = self._create_authenticator() # self.authenticator = self._create_authenticator()
self._resolve_user_name() self._resolve_user_name()
self._determine_directories() self._determine_directories()
self._create_file_logger() self._create_file_logger()
self.master_hash_list = [] self.master_hash_list = []
self._load_config() self._load_config()
if self.cfg_parser.has_option('DEFAULT', 'reddit_token'):
# TODO: implement OAuth2 authentication self._create_reddit_instance()
def _create_reddit_instance(self):
if self.args.authenticate:
if not self.cfg_parser.has_option('DEFAULT', 'user_token'):
scopes = self.cfg_parser.get('DEFAULT', 'scopes')
scopes = OAuth2Authenticator.split_scopes(scopes)
oauth2_authenticator = OAuth2Authenticator(scopes)
token = oauth2_authenticator.retrieve_new_token()
self.cfg_parser['DEFAULT']['user_token'] = token
token_manager = OAuth2TokenManager(self.cfg_parser)
self.authenticated = True self.authenticated = True
self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'), self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'),
user_agent=socket.gethostname(), user_agent=socket.gethostname(),
) token_manager=token_manager)
else: else:
self.authenticated = False self.authenticated = False
self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'), self.reddit_instance = praw.Reddit(client_id=self.cfg_parser.get('DEFAULT', 'client_id'),
@@ -92,10 +106,15 @@ class RedditDownloader:
def _load_config(self): def _load_config(self):
self.cfg_parser = configparser.ConfigParser() self.cfg_parser = configparser.ConfigParser()
if self.args.use_local_config and Path('./config.cfg').exists(): possible_paths = [Path('./config.cfg'),
self.cfg_parser.read(Path('./config.cfg')) Path(self.config_directory, 'config.cfg'),
else: Path('./default_config.cfg'),
self.cfg_parser.read(Path('./default_config.cfg').resolve()) ]
for path in possible_paths:
if path.resolve().expanduser().exists():
self.config_location = path
break
self.cfg_parser.read(self.config_location)
def _create_file_logger(self): def _create_file_logger(self):
main_logger = logging.getLogger() main_logger = logging.getLogger()