diff --git a/bulkredditdownloader/file_name_formatter.py b/bulkredditdownloader/file_name_formatter.py index 461947e..6f8fbd6 100644 --- a/bulkredditdownloader/file_name_formatter.py +++ b/bulkredditdownloader/file_name_formatter.py @@ -6,12 +6,19 @@ from pathlib import Path import praw.models +from bulkredditdownloader.errors import BulkDownloaderException from bulkredditdownloader.resource import Resource class FileNameFormatter: + key_terms = ('title', 'subreddit', 'redditor', 'postid', 'upvotes', 'flair', 'date') + def __init__(self, file_format_string: str, directory_format_string: str): + if not self.validate_string(file_format_string): + raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string') self.file_format_string = file_format_string + if not self.validate_string(directory_format_string): + raise BulkDownloaderException(f'"{directory_format_string}" is not a valid format string') self.directory_format_string = directory_format_string @staticmethod @@ -38,3 +45,9 @@ class FileNameFormatter: file_path = subfolder / (str(self._format_name(resource.source_submission, self.file_format_string)) + resource.extension) return file_path + + @staticmethod + def validate_string(test_string: str) -> bool: + if not test_string: + return False + return any([f'{{{key}}}' in test_string.lower() for key in FileNameFormatter.key_terms]) diff --git a/bulkredditdownloader/tests/test_file_name_formatter.py b/bulkredditdownloader/tests/test_file_name_formatter.py index 7d18dd2..eb679d3 100644 --- a/bulkredditdownloader/tests/test_file_name_formatter.py +++ b/bulkredditdownloader/tests/test_file_name_formatter.py @@ -42,6 +42,20 @@ def test_format_name_mock(format_string: str, expected: str, submission: Mock): assert result == expected +@pytest.mark.parametrize(('test_string', 'expected'), ( + ('', False), + ('test', False), + ('{POSTID}', True), + ('POSTID', False), + ('{POSTID}_test', True), + ('test_{TITLE}', True), + ('TITLE_POSTID', False), +)) +def test_check_format_string_validity(test_string: str, expected: bool): + result = FileNameFormatter.validate_string(test_string) + assert result == expected + + @pytest.mark.online @pytest.mark.reddit @pytest.mark.parametrize(('format_string', 'expected'),