Fix time filters (#279)
This commit is contained in:
@@ -41,19 +41,20 @@ def _calc_hash(existing_file: Path):
|
||||
|
||||
class RedditTypes:
|
||||
class SortType(Enum):
|
||||
HOT = auto()
|
||||
RISING = auto()
|
||||
CONTROVERSIAL = auto()
|
||||
HOT = auto()
|
||||
NEW = auto()
|
||||
RELEVENCE = auto()
|
||||
RISING = auto()
|
||||
TOP = auto()
|
||||
|
||||
class TimeType(Enum):
|
||||
HOUR = auto()
|
||||
DAY = auto()
|
||||
WEEK = auto()
|
||||
MONTH = auto()
|
||||
YEAR = auto()
|
||||
ALL = auto()
|
||||
ALL = 'all'
|
||||
DAY = 'day'
|
||||
HOUR = 'hour'
|
||||
MONTH = 'month'
|
||||
WEEK = 'week'
|
||||
YEAR = 'year'
|
||||
|
||||
|
||||
class RedditDownloader:
|
||||
@@ -229,16 +230,16 @@ class RedditDownloader:
|
||||
try:
|
||||
reddit = self.reddit_instance.subreddit(reddit)
|
||||
if self.args.search:
|
||||
out.append(
|
||||
reddit.search(
|
||||
self.args.search,
|
||||
sort=self.sort_filter.name.lower(),
|
||||
limit=self.args.limit,
|
||||
))
|
||||
out.append(reddit.search(
|
||||
self.args.search,
|
||||
sort=self.sort_filter.name.lower(),
|
||||
limit=self.args.limit,
|
||||
time_filter=self.time_filter.value,
|
||||
))
|
||||
logger.debug(
|
||||
f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"')
|
||||
else:
|
||||
out.append(sort_function(reddit, limit=self.args.limit))
|
||||
out.append(self._create_filtered_listing_generator(reddit))
|
||||
logger.debug(f'Added submissions from subreddit {reddit}')
|
||||
except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e:
|
||||
logger.error(f'Failed to get submissions for subreddit {reddit}: {e}')
|
||||
@@ -271,6 +272,8 @@ class RedditDownloader:
|
||||
sort_function = praw.models.Subreddit.rising
|
||||
elif self.sort_filter is RedditTypes.SortType.CONTROVERSIAL:
|
||||
sort_function = praw.models.Subreddit.controversial
|
||||
elif self.sort_filter is RedditTypes.SortType.TOP:
|
||||
sort_function = praw.models.Subreddit.top
|
||||
else:
|
||||
sort_function = praw.models.Subreddit.hot
|
||||
return sort_function
|
||||
@@ -278,13 +281,12 @@ class RedditDownloader:
|
||||
def _get_multireddits(self) -> list[Iterator]:
|
||||
if self.args.multireddit:
|
||||
out = []
|
||||
sort_function = self._determine_sort_function()
|
||||
for multi in self._split_args_input(self.args.multireddit):
|
||||
try:
|
||||
multi = self.reddit_instance.multireddit(self.args.user, multi)
|
||||
if not multi.subreddits:
|
||||
raise errors.BulkDownloaderException
|
||||
out.append(sort_function(multi, limit=self.args.limit))
|
||||
out.append(self._create_filtered_listing_generator(multi))
|
||||
logger.debug(f'Added submissions from multireddit {multi}')
|
||||
except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e:
|
||||
logger.error(f'Failed to get submissions for multireddit {multi}: {e}')
|
||||
@@ -292,6 +294,13 @@ class RedditDownloader:
|
||||
else:
|
||||
return []
|
||||
|
||||
def _create_filtered_listing_generator(self, reddit_source) -> Iterator:
|
||||
sort_function = self._determine_sort_function()
|
||||
if self.sort_filter in (RedditTypes.SortType.TOP, RedditTypes.SortType.CONTROVERSIAL):
|
||||
return sort_function(reddit_source, limit=self.args.limit, time_filter=self.time_filter.value)
|
||||
else:
|
||||
return sort_function(reddit_source, limit=self.args.limit)
|
||||
|
||||
def _get_user_data(self) -> list[Iterator]:
|
||||
if any([self.args.submitted, self.args.upvoted, self.args.saved]):
|
||||
if self.args.user:
|
||||
@@ -299,14 +308,11 @@ class RedditDownloader:
|
||||
logger.error(f'User {self.args.user} does not exist')
|
||||
return []
|
||||
generators = []
|
||||
sort_function = self._determine_sort_function()
|
||||
if self.args.submitted:
|
||||
logger.debug(f'Retrieving submitted posts of user {self.args.user}')
|
||||
generators.append(
|
||||
sort_function(
|
||||
self.reddit_instance.redditor(self.args.user).submissions,
|
||||
limit=self.args.limit,
|
||||
))
|
||||
generators.append(self._create_filtered_listing_generator(
|
||||
self.reddit_instance.redditor(self.args.user).submissions,
|
||||
))
|
||||
if not self.authenticated and any((self.args.upvoted, self.args.saved)):
|
||||
logger.warning('Accessing user lists requires authentication')
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user