From b37ff0714f899395145d140a8d4ec19636a22f5a Mon Sep 17 00:00:00 2001 From: Serene <33189705+Serene-Arc@users.noreply.github.com> Date: Sun, 18 Apr 2021 21:24:11 +1000 Subject: [PATCH] Fix time filters (#279) --- bdfr/downloader.py | 52 ++++++++------ bdfr/resource.py | 6 +- bdfr/site_downloaders/download_factory.py | 35 ++++++---- .../site_downloaders/test_download_factory.py | 11 +++ bdfr/tests/test_downloader.py | 70 +++++++++++++------ bdfr/tests/test_integration.py | 1 - bdfr/tests/test_resource.py | 8 ++- 7 files changed, 121 insertions(+), 62 deletions(-) diff --git a/bdfr/downloader.py b/bdfr/downloader.py index 4897831..c24b5cd 100644 --- a/bdfr/downloader.py +++ b/bdfr/downloader.py @@ -41,19 +41,20 @@ def _calc_hash(existing_file: Path): class RedditTypes: class SortType(Enum): - HOT = auto() - RISING = auto() CONTROVERSIAL = auto() + HOT = auto() NEW = auto() RELEVENCE = auto() + RISING = auto() + TOP = auto() class TimeType(Enum): - HOUR = auto() - DAY = auto() - WEEK = auto() - MONTH = auto() - YEAR = auto() - ALL = auto() + ALL = 'all' + DAY = 'day' + HOUR = 'hour' + MONTH = 'month' + WEEK = 'week' + YEAR = 'year' class RedditDownloader: @@ -229,16 +230,16 @@ class RedditDownloader: try: reddit = self.reddit_instance.subreddit(reddit) if self.args.search: - out.append( - reddit.search( - self.args.search, - sort=self.sort_filter.name.lower(), - limit=self.args.limit, - )) + out.append(reddit.search( + self.args.search, + sort=self.sort_filter.name.lower(), + limit=self.args.limit, + time_filter=self.time_filter.value, + )) logger.debug( f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"') else: - out.append(sort_function(reddit, limit=self.args.limit)) + out.append(self._create_filtered_listing_generator(reddit)) logger.debug(f'Added submissions from subreddit {reddit}') except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') @@ -271,6 +272,8 @@ class RedditDownloader: sort_function = praw.models.Subreddit.rising elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL: sort_function = praw.models.Subreddit.controversial + elif self.sort_filter is RedditTypes.SortType.TOP: + sort_function = praw.models.Subreddit.top else: sort_function = praw.models.Subreddit.hot return sort_function @@ -278,13 +281,12 @@ class RedditDownloader: def _get_multireddits(self) -> list[Iterator]: if self.args.multireddit: out = [] - sort_function = self._determine_sort_function() for multi in self._split_args_input(self.args.multireddit): try: multi = self.reddit_instance.multireddit(self.args.user, multi) if not multi.subreddits: raise errors.BulkDownloaderException - out.append(sort_function(multi, limit=self.args.limit)) + out.append(self._create_filtered_listing_generator(multi)) logger.debug(f'Added submissions from multireddit {multi}') except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e: logger.error(f'Failed to get submissions for multireddit {multi}: {e}') @@ -292,6 +294,13 @@ class RedditDownloader: else: return [] + def _create_filtered_listing_generator(self, reddit_source) -> Iterator: + sort_function = self._determine_sort_function() + if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL): + return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value) + else: + return sort_function(reddit_source, limit=self.args.limit) + def _get_user_data(self) -> list[Iterator]: if any([self.args.submitted, self.args.upvoted, self.args.saved]): if self.args.user: @@ -299,14 +308,11 @@ class RedditDownloader: logger.error(f'User {self.args.user} does not exist') return [] generators = [] - sort_function = self._determine_sort_function() if self.args.submitted: logger.debug(f'Retrieving submitted posts of user {self.args.user}') - generators.append( - sort_function( - self.reddit_instance.redditor(self.args.user).submissions, - limit=self.args.limit, - )) + generators.append(self._create_filtered_listing_generator( + self.reddit_instance.redditor(self.args.user).submissions, + )) if not self.authenticated and any((self.args.upvoted, self.args.saved)): logger.warning('Accessing user lists requires authentication') else: diff --git a/bdfr/resource.py b/bdfr/resource.py index be6aaaf..966f5ba 100644 --- a/bdfr/resource.py +++ b/bdfr/resource.py @@ -6,6 +6,7 @@ import logging import re import time from typing import Optional +import urllib.parse import _hashlib import requests @@ -64,7 +65,8 @@ class Resource: self.hash = hashlib.md5(self.content) def _determine_extension(self) -> Optional[str]: - extension_pattern = re.compile(r'.*(\..{3,5})(?:\?.*)?(?:#.*)?$') - match = re.search(extension_pattern, self.url) + extension_pattern = re.compile(r'.*(\..{3,5})$') + stripped_url = urllib.parse.urlsplit(self.url).path + match = re.search(extension_pattern, stripped_url) if match: return match.group(1) diff --git a/bdfr/site_downloaders/download_factory.py b/bdfr/site_downloaders/download_factory.py index 8a39413..4bd6225 100644 --- a/bdfr/site_downloaders/download_factory.py +++ b/bdfr/site_downloaders/download_factory.py @@ -2,6 +2,7 @@ # coding=utf-8 import re +import urllib.parse from typing import Type from bdfr.exceptions import NotADownloadableLinkError @@ -21,30 +22,38 @@ from bdfr.site_downloaders.youtube import Youtube class DownloadFactory: @staticmethod def pull_lever(url: str) -> Type[BaseDownloader]: - url_beginning = r'\s*(https?://(www\.)?)' - if re.match(url_beginning + r'(i\.)?imgur.*\.gifv$', url): + sanitised_url = DownloadFactory._sanitise_url(url) + if re.match(r'(i\.)?imgur.*\.gifv$', sanitised_url): return Imgur - elif re.match(url_beginning + r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', url): + elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url): return Direct - elif re.match(url_beginning + r'erome\.com.*', url): + elif re.match(r'erome\.com.*', sanitised_url): return Erome - elif re.match(url_beginning + r'reddit\.com/gallery/.*', url): + elif re.match(r'reddit\.com/gallery/.*', sanitised_url): return Gallery - elif re.match(url_beginning + r'gfycat\.', url): + elif re.match(r'gfycat\.', sanitised_url): return Gfycat - elif re.match(url_beginning + r'gifdeliverynetwork', url): + elif re.match(r'gifdeliverynetwork', sanitised_url): return GifDeliveryNetwork - elif re.match(url_beginning + r'(m\.)?imgur.*', url): + elif re.match(r'(m\.)?imgur.*', sanitised_url): return Imgur - elif re.match(url_beginning + r'redgifs.com', url): + elif re.match(r'redgifs.com', sanitised_url): return Redgifs - elif re.match(url_beginning + r'reddit\.com/r/', url): + elif re.match(r'reddit\.com/r/', sanitised_url): return SelfPost - elif re.match(url_beginning + r'v\.redd\.it', url): + elif re.match(r'v\.redd\.it', sanitised_url): return VReddit - elif re.match(url_beginning + r'(m\.)?youtu\.?be', url): + elif re.match(r'(m\.)?youtu\.?be', sanitised_url): return Youtube - elif re.match(url_beginning + r'i\.redd\.it.*', url): + elif re.match(r'i\.redd\.it.*', sanitised_url): return Direct else: raise NotADownloadableLinkError(f'No downloader module exists for url {url}') + + @staticmethod + def _sanitise_url(url: str) -> str: + beginning_regex = re.compile(r'\s*(www\.?)?') + split_url = urllib.parse.urlsplit(url) + split_url = split_url.netloc + split_url.path + split_url = re.sub(beginning_regex, '', split_url) + return split_url diff --git a/bdfr/tests/site_downloaders/test_download_factory.py b/bdfr/tests/site_downloaders/test_download_factory.py index 5d6260e..65625b7 100644 --- a/bdfr/tests/site_downloaders/test_download_factory.py +++ b/bdfr/tests/site_downloaders/test_download_factory.py @@ -58,3 +58,14 @@ def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownlo def test_factory_lever_bad(test_url: str): with pytest.raises(NotADownloadableLinkError): DownloadFactory.pull_lever(test_url) + + +@pytest.mark.parametrize(('test_url', 'expected'), ( + ('www.test.com/test.png', 'test.com/test.png'), + ('www.test.com/test.png?test_value=random', 'test.com/test.png'), + ('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'), + ('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'), +)) +def test_sanitise_urll(test_url: str, expected: str): + result = DownloadFactory._sanitise_url(test_url) + assert result == expected diff --git a/bdfr/tests/test_downloader.py b/bdfr/tests/test_downloader.py index 9a4f051..0d609ef 100644 --- a/bdfr/tests/test_downloader.py +++ b/bdfr/tests/test_downloader.py @@ -148,54 +148,71 @@ def test_get_submissions_from_link( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddits', 'limit'), ( - (('Futurology',), 10), - (('Futurology', 'Mindustry, Python'), 10), - (('Futurology',), 20), - (('Futurology', 'Python'), 10), - (('Futurology',), 100), - (('Futurology',), 0), +@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), ( + (('Futurology',), 10, 'hot', 'all', 10), + (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30), + (('Futurology',), 20, 'hot', 'all', 20), + (('Futurology', 'Python'), 10, 'hot', 'all', 20), + (('Futurology',), 100, 'hot', 'all', 100), + (('Futurology',), 0, 'hot', 'all', 0), + (('Futurology',), 10, 'top', 'all', 10), + (('Futurology',), 10, 'top', 'week', 10), + (('Futurology',), 10, 'hot', 'week', 10), )) def test_get_subreddit_normal( test_subreddits: list[str], limit: int, + sort_type: str, + time_filter: str, + max_expected_len: int, downloader_mock: MagicMock, - reddit_instance: praw.Reddit): + reddit_instance: praw.Reddit, +): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.args.limit = limit + downloader_mock.args.sort = sort_type downloader_mock.args.subreddit = test_subreddits downloader_mock.reddit_instance = reddit_instance - downloader_mock.sort_filter = RedditTypes.SortType.HOT + downloader_mock.sort_filter = RedditDownloader._create_sort_filter(downloader_mock) results = RedditDownloader._get_subreddits(downloader_mock) test_subreddits = downloader_mock._split_args_input(test_subreddits) - results = assert_all_results_are_submissions( - (limit * len(test_subreddits)) if limit else None, results) + results = [sub for res1 in results for sub in res1] + assert all([isinstance(res1, praw.models.Submission) for res1 in results]) assert all([res.subreddit.display_name in test_subreddits for res in results]) + assert len(results) <= max_expected_len @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit'), ( - (('Python',), 'scraper', 10), - (('Python',), '', 10), - (('Python',), 'djsdsgewef', 0), +@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), ( + (('Python',), 'scraper', 10, 'all', 10), + (('Python',), '', 10, 'all', 10), + (('Python',), 'djsdsgewef', 10, 'all', 0), + (('Python',), 'scraper', 10, 'year', 10), + (('Python',), 'scraper', 10, 'hour', 1), )) def test_get_subreddit_search( test_subreddits: list[str], search_term: str, + time_filter: str, limit: int, + max_expected_len: int, downloader_mock: MagicMock, - reddit_instance: praw.Reddit): + reddit_instance: praw.Reddit, +): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.args.limit = limit downloader_mock.args.search = search_term downloader_mock.args.subreddit = test_subreddits downloader_mock.reddit_instance = reddit_instance downloader_mock.sort_filter = RedditTypes.SortType.HOT + downloader_mock.args.time = time_filter + downloader_mock.time_filter = RedditDownloader._create_time_filter(downloader_mock) results = RedditDownloader._get_subreddits(downloader_mock) - results = assert_all_results_are_submissions( - (limit * len(test_subreddits)) if limit else None, results) + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) assert all([res.subreddit.display_name in test_subreddits for res in results]) + assert len(results) <= max_expected_len @pytest.mark.online @@ -210,15 +227,23 @@ def test_get_multireddits_public( test_multireddits: list[str], limit: int, reddit_instance: praw.Reddit, - downloader_mock: MagicMock): + downloader_mock: MagicMock, +): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT downloader_mock.args.limit = limit downloader_mock.args.multireddit = test_multireddits downloader_mock.args.user = test_user downloader_mock.reddit_instance = reddit_instance + downloader_mock._create_filtered_listing_generator.return_value = \ + RedditDownloader._create_filtered_listing_generator( + downloader_mock, + reddit_instance.multireddit(test_user, test_multireddits[0]), + ) results = RedditDownloader._get_multireddits(downloader_mock) - assert_all_results_are_submissions((limit * len(test_multireddits)) if limit else None, results) + results = [sub for res in results for sub in res] + assert all([isinstance(res, praw.models.Submission) for res in results]) + assert len(results) == limit @pytest.mark.online @@ -236,6 +261,11 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic downloader_mock.args.user = test_user downloader_mock.authenticated = False downloader_mock.reddit_instance = reddit_instance + downloader_mock._create_filtered_listing_generator.return_value = \ + RedditDownloader._create_filtered_listing_generator( + downloader_mock, + reddit_instance.redditor(test_user).submissions, + ) results = RedditDownloader._get_user_data(downloader_mock) results = assert_all_results_are_submissions(limit, results) assert all([res.author.name == test_user for res in results]) diff --git a/bdfr/tests/test_integration.py b/bdfr/tests/test_integration.py index 47b5229..396025b 100644 --- a/bdfr/tests/test_integration.py +++ b/bdfr/tests/test_integration.py @@ -101,7 +101,6 @@ def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Pa ['--user', 'djnish', '--submitted', '-L', 10], ['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'], ['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'], - ['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial', '--time', 'month'], )) def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path): runner = CliRunner() diff --git a/bdfr/tests/test_resource.py b/bdfr/tests/test_resource.py index de6030b..272c457 100644 --- a/bdfr/tests/test_resource.py +++ b/bdfr/tests/test_resource.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # coding=utf-8 -import pytest from unittest.mock import MagicMock +import pytest + from bdfr.resource import Resource @@ -15,8 +16,9 @@ from bdfr.resource import Resource ('https://www.resource.com/test/example.jpg', '.jpg'), ('hard.png.mp4', '.mp4'), ('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'), - ('test.jpg#test','.jpg'), - ('test.jpg?width=247#test','.jpg'), + ('test.jpg#test', '.jpg'), + ('test.jpg?width=247#test', '.jpg'), + ('https://www.test.com/test/test2/example.png?random=test#thing', '.png'), )) def test_resource_get_extension(test_url: str, expected: str): test_resource = Resource(MagicMock(), test_url)