diff --git a/movienightbot/db/controllers.py b/movienightbot/db/controllers.py index 8bcc19a..48f73bf 100644 --- a/movienightbot/db/controllers.py +++ b/movienightbot/db/controllers.py @@ -1,4 +1,5 @@ import logging +import random from collections import defaultdict from string import ascii_lowercase from typing import Any, Optional, Union @@ -129,15 +130,21 @@ def update_imdb_id(self, server_id: int, movie_name: str, imdb_id: str): .execute() ) + def _weighted_movie_selection(self, server_id: int, num_movies: int) -> list[int]: + # weighted score so it's random but has the movie avg score taken into account + return ( + Movie.select(Movie, ((Movie.total_score / Movie.num_votes_entered) * pw.fn.Random()).alias('weight_score')) + .order_by(pw.SQL("weight_score").desc()) + .where( + (Movie.server == server_id) + & Movie.watched_on.is_null(True) + & (Movie.num_votes_entered > 0) + ) + .limit(num_movies) + ) + def get_random_movies(self, server_id: int, num_movies: int, genres: Optional[list[str]] = None) -> list[Movie]: - if genres is None: - return ( - Movie.select() - .order_by(pw.fn.Random()) - .where((Movie.server == server_id) & Movie.watched_on.is_null(True)) - .limit(num_movies) - ) - else: + if genres is not None: return ( Movie.select() .join(MovieGenre) @@ -148,6 +155,40 @@ def get_random_movies(self, server_id: int, num_movies: int, genres: Optional[li .limit(num_movies) ) + # calculate the new vs old split + counts_by_votes_entered = ( + Movie.select(Movie.num_votes_entered, pw.fn.COUNT(Movie.num_votes_entered).alias("count_value")) + .group_by(Movie.num_votes_entered) + .where(Movie.server == server_id) + .order_by(Movie.num_votes_entered.asc()) + ) + if counts_by_votes_entered[0].num_votes_entered != 0: + # We have some weird case where every movie has been in at least one vote + # so short circuit and just choose at random weighted + return self._weighted_movie_selection(server_id, num_movies) + + total_count = sum(c.count_value for c in counts_by_votes_entered) + new_count = counts_by_votes_entered[0].count_value + + # prevent errors by limiting the total number if db has too few options + if total_count < num_movies: + num_movies = total_count + + # calculate the actual split counts, making sure at least one new movie is in the split + new_split = max(1, round((new_count / total_count) * num_movies)) + old_split = num_movies - new_split + + # query the movies based on the split + new_movies = (Movie.select() + .order_by(pw.fn.Random()) + .where((Movie.server == server_id) + & Movie.watched_on.is_null(True) + & (Movie.num_votes_entered == 0)) + .limit(new_split)) + + # Shuffle so it isn't ordered new -> old always + return random.shuffle(new_movies + self._weighted_movie_selection(server_id, old_split)) + def movie_score_weightings(server_id: int) -> dict[int, float]: num_votes_allowed = ServerController().get_by_id(server_id).num_votes_per_user