Format according to the black standard
This commit is contained in:
@@ -17,7 +17,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuth2Authenticator:
|
||||
|
||||
def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
|
||||
self._check_scopes(wanted_scopes)
|
||||
self.scopes = wanted_scopes
|
||||
@@ -26,39 +25,41 @@ class OAuth2Authenticator:
|
||||
|
||||
@staticmethod
|
||||
def _check_scopes(wanted_scopes: set[str]):
|
||||
response = requests.get('https://www.reddit.com/api/v1/scopes.json',
|
||||
headers={'User-Agent': 'fetch-scopes test'})
|
||||
response = requests.get(
|
||||
"https://www.reddit.com/api/v1/scopes.json", headers={"User-Agent": "fetch-scopes test"}
|
||||
)
|
||||
known_scopes = [scope for scope, data in response.json().items()]
|
||||
known_scopes.append('*')
|
||||
known_scopes.append("*")
|
||||
for scope in wanted_scopes:
|
||||
if scope not in known_scopes:
|
||||
raise BulkDownloaderException(f'Scope {scope} is not known to reddit')
|
||||
raise BulkDownloaderException(f"Scope {scope} is not known to reddit")
|
||||
|
||||
@staticmethod
|
||||
def split_scopes(scopes: str) -> set[str]:
|
||||
scopes = re.split(r'[,: ]+', scopes)
|
||||
scopes = re.split(r"[,: ]+", scopes)
|
||||
return set(scopes)
|
||||
|
||||
def retrieve_new_token(self) -> str:
|
||||
reddit = praw.Reddit(
|
||||
redirect_uri='http://localhost:7634',
|
||||
user_agent='obtain_refresh_token for BDFR',
|
||||
redirect_uri="http://localhost:7634",
|
||||
user_agent="obtain_refresh_token for BDFR",
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret)
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
state = str(random.randint(0, 65000))
|
||||
url = reddit.auth.url(self.scopes, state, 'permanent')
|
||||
logger.warning('Authentication action required before the program can proceed')
|
||||
logger.warning(f'Authenticate at {url}')
|
||||
url = reddit.auth.url(self.scopes, state, "permanent")
|
||||
logger.warning("Authentication action required before the program can proceed")
|
||||
logger.warning(f"Authenticate at {url}")
|
||||
|
||||
client = self.receive_connection()
|
||||
data = client.recv(1024).decode('utf-8')
|
||||
param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&')
|
||||
params = {key: value for (key, value) in [token.split('=') for token in param_tokens]}
|
||||
data = client.recv(1024).decode("utf-8")
|
||||
param_tokens = data.split(" ", 2)[1].split("?", 1)[1].split("&")
|
||||
params = {key: value for (key, value) in [token.split("=") for token in param_tokens]}
|
||||
|
||||
if state != params['state']:
|
||||
if state != params["state"]:
|
||||
self.send_message(client)
|
||||
raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}')
|
||||
elif 'error' in params:
|
||||
elif "error" in params:
|
||||
self.send_message(client)
|
||||
raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
|
||||
|
||||
@@ -70,19 +71,19 @@ class OAuth2Authenticator:
|
||||
def receive_connection() -> socket.socket:
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind(('0.0.0.0', 7634))
|
||||
logger.log(9, 'Server listening on 0.0.0.0:7634')
|
||||
server.bind(("0.0.0.0", 7634))
|
||||
logger.log(9, "Server listening on 0.0.0.0:7634")
|
||||
|
||||
server.listen(1)
|
||||
client = server.accept()[0]
|
||||
server.close()
|
||||
logger.log(9, 'Server closed')
|
||||
logger.log(9, "Server closed")
|
||||
|
||||
return client
|
||||
|
||||
@staticmethod
|
||||
def send_message(client: socket.socket, message: str = ''):
|
||||
client.send(f'HTTP/1.1 200 OK\r\n\r\n{message}'.encode('utf-8'))
|
||||
def send_message(client: socket.socket, message: str = ""):
|
||||
client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode("utf-8"))
|
||||
client.close()
|
||||
|
||||
|
||||
@@ -94,14 +95,14 @@ class OAuth2TokenManager(praw.reddit.BaseTokenManager):
|
||||
|
||||
def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
|
||||
if authorizer.refresh_token is None:
|
||||
if self.config.has_option('DEFAULT', 'user_token'):
|
||||
authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
|
||||
logger.log(9, 'Loaded OAuth2 token for authoriser')
|
||||
if self.config.has_option("DEFAULT", "user_token"):
|
||||
authorizer.refresh_token = self.config.get("DEFAULT", "user_token")
|
||||
logger.log(9, "Loaded OAuth2 token for authoriser")
|
||||
else:
|
||||
raise RedditAuthenticationError('No auth token loaded in configuration')
|
||||
raise RedditAuthenticationError("No auth token loaded in configuration")
|
||||
|
||||
def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
|
||||
self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
|
||||
with open(self.config_location, 'w') as file:
|
||||
self.config.set("DEFAULT", "user_token", authorizer.refresh_token)
|
||||
with open(self.config_location, "w") as file:
|
||||
self.config.write(file, True)
|
||||
logger.log(9, f'Written OAuth2 token from authoriser to {self.config_location}')
|
||||
logger.log(9, f"Written OAuth2 token from authoriser to {self.config_location}")
|
||||
|
||||
Reference in New Issue
Block a user