#!/usr/bin/env python3
# coding=utf-8

import re
from typing import Optional
from urllib.parse import urlsplit


class DownloadFilter:
    """Decide whether a URL may be downloaded.

    A URL is rejected when its path ends in one of the excluded file
    extensions, or when it is an http(s) URL whose host is one of the
    excluded domains (or a subdomain of one).
    """

    def __init__(self, excluded_extensions: Optional[list[str]] = None,
                 excluded_domains: Optional[list[str]] = None):
        """Store the filter lists.

        Args:
            excluded_extensions: File extensions to reject, with or without a
                leading dot (``'mp4'`` or ``'.mp4'``). ``None`` or an empty
                list disables the extension check.
            excluded_domains: Domains to reject, e.g. ``'reddit.com'``.
                ``None`` or an empty list disables the domain check.
        """
        self.excluded_extensions = excluded_extensions
        self.excluded_domains = excluded_domains

    def check_url(self, url: str) -> bool:
        """Return whether a URL is allowed or not"""
        return self._check_extension(url) and self._check_domain(url)

    @staticmethod
    def _split_url(url: str):
        """Return ``urlsplit(url)``, or ``None`` when the URL cannot be parsed."""
        try:
            return urlsplit(url)
        except ValueError:
            # urlsplit raises on e.g. malformed IPv6 literals.
            return None

    def _check_extension(self, url: str) -> bool:
        """Return False when the URL's path ends in an excluded extension.

        The comparison is case-insensitive, requires a literal dot before the
        extension, and ignores any query string or fragment.
        """
        stripped = (ext.lstrip('.') for ext in self.excluded_extensions or ())
        # Escape so that user-supplied entries are matched literally, and drop
        # empty entries, which would otherwise match every URL.
        escaped = [re.escape(ext) for ext in stripped if ext]
        if not escaped:
            return True
        parts = self._split_url(url)
        # Fall back to the raw string when the URL cannot be parsed.
        path = url if parts is None else parts.path
        pattern = re.compile(r'\.(?:{})\Z'.format('|'.join(escaped)), re.IGNORECASE)
        return pattern.search(path) is None

    def _check_domain(self, url: str) -> bool:
        """Return False when an http(s) URL points at an excluded domain.

        Only the host part of the URL is inspected, so a domain name that
        merely appears in the path or query string does not reject the URL.
        A host matches when it equals an excluded domain or is a subdomain
        of it. URLs without an http or https scheme are always allowed.
        """
        normalised = (domain.strip('.').lower() for domain in self.excluded_domains or ())
        domains = [domain for domain in normalised if domain]
        if not domains:
            return True
        parts = self._split_url(url)
        if parts is None or parts.scheme not in ('http', 'https'):
            return True
        # ``hostname`` is already lower-cased and stripped of port/userinfo.
        host = (parts.hostname or '').rstrip('.')
        if not host:
            return True
        return not any(host == domain or host.endswith('.' + domain) for domain in domains)
#!/usr/bin/env python3
# coding=utf-8

import pytest

from bulkredditdownloader.download_filter import DownloadFilter


@pytest.fixture()
def download_filter() -> DownloadFilter:
    """Provide a filter that excludes two extensions and two domains."""
    excluded_extensions = ['mp4', 'mp3']
    excluded_domains = ['test.com', 'reddit.com']
    return DownloadFilter(excluded_extensions, excluded_domains)


@pytest.mark.parametrize(('test_url', 'expected'), [
    ('test.mp4', False),
    ('test.avi', True),
    ('test.random.mp3', False),
])
def test_filter_extension(test_url: str, expected: bool, download_filter: DownloadFilter):
    """Exercise the extension check on its own."""
    assert download_filter._check_extension(test_url) == expected


@pytest.mark.parametrize(('test_url', 'expected'), [
    ('test.mp4', True),
    ('http://reddit.com/test.mp4', False),
    ('http://reddit.com/test.gif', False),
    ('https://www.example.com/test.mp4', True),
    ('https://www.example.com/test.png', True),
])
def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter):
    """Exercise the domain check on its own."""
    assert download_filter._check_domain(test_url) == expected


@pytest.mark.parametrize(('test_url', 'expected'), [
    ('test.mp4', False),
    ('test.gif', True),
    ('https://www.example.com/test.mp4', False),
    ('https://www.example.com/test.png', True),
    ('http://reddit.com/test.mp4', False),
    ('http://reddit.com/test.gif', False),
])
def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter):
    """Exercise the combined extension and domain check."""
    assert download_filter.check_url(test_url) == expected


@pytest.mark.parametrize('test_url', [
    'test.mp3',
    'test.mp4',
    'http://reddit.com/test.mp4',
    't',
])
def test_filter_empty_filter(test_url: str):
    """A filter built with no exclusions allows every URL."""
    unrestricted = DownloadFilter()
    assert unrestricted.check_url(test_url) is True