From 7c2b7b0e83ed4002812131a1033712a55ce097d6 Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Mon, 8 Mar 2021 12:34:03 +1000 Subject: [PATCH] Move scope regex parsing --- bulkredditdownloader/oauth2.py | 8 +++++++- bulkredditdownloader/tests/test_oauth2.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/bulkredditdownloader/oauth2.py b/bulkredditdownloader/oauth2.py index 67444d8..c34b54a 100644 --- a/bulkredditdownloader/oauth2.py +++ b/bulkredditdownloader/oauth2.py @@ -4,6 +4,7 @@ import configparser import logging import random +import re import socket import praw @@ -21,7 +22,7 @@ class OAuth2Authenticator: self.scopes = wanted_scopes @staticmethod - def _check_scopes(wanted_scopes: list[str]): + def _check_scopes(wanted_scopes: set[str]): response = requests.get('https://www.reddit.com/api/v1/scopes.json', headers={'User-Agent': 'fetch-scopes test'}) known_scopes = [scope for scope, data in response.json().items()] @@ -30,6 +31,11 @@ class OAuth2Authenticator: if scope not in known_scopes: raise BulkDownloaderException(f'Scope {scope} is not known to reddit') + @staticmethod + def split_scopes(scopes: str) -> set[str]: + scopes = re.split(r'[,: ]+', scopes) + return set(scopes) + def retrieve_new_token(self) -> str: reddit = praw.Reddit(redirect_uri='http://localhost:8080', user_agent='obtain_refresh_token for BDFR') state = str(random.randint(0, 65000)) diff --git a/bulkredditdownloader/tests/test_oauth2.py b/bulkredditdownloader/tests/test_oauth2.py index a80d7a7..c5a63b7 100644 --- a/bulkredditdownloader/tests/test_oauth2.py +++ b/bulkredditdownloader/tests/test_oauth2.py @@ -4,7 +4,6 @@ import configparser from unittest.mock import MagicMock -import praw import pytest from bulkredditdownloader.exceptions import BulkDownloaderException @@ -30,6 +29,17 @@ def test_check_scopes(test_scopes: list[str]): OAuth2Authenticator._check_scopes(test_scopes) +@pytest.mark.parametrize(('test_scopes', 'expected'), ( + ('history', {'history', }), + ('history creddits', {'history', 'creddits'}), + ('history, creddits, account', {'history', 'creddits', 'account'}), + ('history,creddits,account,flair', {'history', 'creddits', 'account', 'flair'}), +)) +def test_split_scopes(test_scopes: str, expected: set[str]): + result = OAuth2Authenticator.split_scopes(test_scopes) + assert result == expected + + @pytest.mark.online @pytest.mark.parametrize('test_scopes', ( ('random',),