Add download filter class
This commit is contained in:
39
bulkredditdownloader/download_filter.py
Normal file
39
bulkredditdownloader/download_filter.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadFilter:
    """Decide whether a URL may be downloaded, based on exclusion lists.

    A URL is rejected when it ends with one of the excluded extensions, or
    when it starts with ``http://`` / ``https://`` and one of the excluded
    domains appears anywhere after that scheme. Every list entry is treated
    as literal text, never as a regular expression.
    """

    def __init__(self, excluded_extensions: list[str] = None, excluded_domains: list[str] = None):
        """Store the exclusion lists.

        Args:
            excluded_extensions: URL endings to reject (e.g. ``'mp4'``).
                ``None`` or an empty list disables extension filtering.
            excluded_domains: Domain strings to reject (e.g. ``'reddit.com'``).
                ``None`` or an empty list disables domain filtering.
        """
        self.excluded_extensions = excluded_extensions
        self.excluded_domains = excluded_domains

    def check_url(self, url: str) -> bool:
        """Return whether a URL is allowed or not"""
        return self._check_extension(url) and self._check_domain(url)

    def _check_extension(self, url: str) -> bool:
        """Return False if the URL ends with an excluded extension, else True."""
        if not self.excluded_extensions:
            return True
        # Escape every entry so regex metacharacters ('.', '+', '?', ...) are
        # matched literally instead of widening the pattern or raising re.error.
        combined_extensions = '|'.join(re.escape(extension) for extension in self.excluded_extensions)
        pattern = re.compile(r'.*({})$'.format(combined_extensions))
        return pattern.match(url) is None

    def _check_domain(self, url: str) -> bool:
        """Return False if an excluded domain appears in an http(s) URL, else True.

        URLs that do not start with ``http://`` or ``https://`` are never
        rejected by this check.
        """
        if not self.excluded_domains:
            return True
        # Escape every entry so that the '.' in 'test.com' only matches a real
        # dot and does not also reject hosts such as 'testxcom'.
        combined_domains = '|'.join(re.escape(domain) for domain in self.excluded_domains)
        pattern = re.compile(r'https?://.*({}).*'.format(combined_domains))
        return pattern.match(url) is None
|
||||||
54
bulkredditdownloader/tests/test_download_filter.py
Normal file
54
bulkredditdownloader/tests/test_download_filter.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from bulkredditdownloader.download_filter import DownloadFilter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def download_filter() -> DownloadFilter:
    """Provide a filter configured with two excluded extensions and two excluded domains."""
    excluded_extensions = ['mp4', 'mp3']
    excluded_domains = ['test.com', 'reddit.com']
    return DownloadFilter(excluded_extensions, excluded_domains)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ('test_url', 'expected'),
    (
        ('test.mp4', False),
        ('test.avi', True),
        ('test.random.mp3', False),
    ),
)
def test_filter_extension(test_url: str, expected: bool, download_filter: DownloadFilter):
    """The extension check rejects URLs ending in an excluded extension."""
    assert download_filter._check_extension(test_url) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ('test_url', 'expected'),
    (
        ('test.mp4', True),
        ('http://reddit.com/test.mp4', False),
        ('http://reddit.com/test.gif', False),
        ('https://www.example.com/test.mp4', True),
        ('https://www.example.com/test.png', True),
    ),
)
def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter):
    """The domain check rejects http(s) URLs on an excluded domain, whatever their extension."""
    assert download_filter._check_domain(test_url) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ('test_url', 'expected'),
    (
        ('test.mp4', False),
        ('test.gif', True),
        ('https://www.example.com/test.mp4', False),
        ('https://www.example.com/test.png', True),
        ('http://reddit.com/test.mp4', False),
        ('http://reddit.com/test.gif', False),
    ),
)
def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter):
    """The combined check rejects a URL when either the extension or the domain is excluded."""
    assert download_filter.check_url(test_url) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    'test_url',
    (
        'test.mp3',
        'test.mp4',
        'http://reddit.com/test.mp4',
        't',
    ),
)
def test_filter_empty_filter(test_url: str):
    """A filter built without any exclusions allows every URL."""
    unrestricted_filter = DownloadFilter()
    assert unrestricted_filter.check_url(test_url) is True
|
||||||
Reference in New Issue
Block a user