diff --git a/.gitignore b/.gitignore index 4c4a9be3..e4201265 100644 --- a/.gitignore +++ b/.gitignore @@ -266,3 +266,4 @@ target/ targets.json web/static/sw.js !web/pages/* +.vscode/settings.json diff --git a/classes/_state.py b/classes/_state.py new file mode 100644 index 00000000..3935e202 --- /dev/null +++ b/classes/_state.py @@ -0,0 +1,808 @@ +import asyncio +import copy +import datetime +import inspect +import logging + +import orjson + +from discord import utils +from discord.emoji import Emoji +from discord.enums import ChannelType, try_enum +from discord.invite import Invite +from discord.member import VoiceState +from discord.partial_emoji import PartialEmoji +from discord.raw_models import ( + RawBulkMessageDeleteEvent, + RawMessageDeleteEvent, + RawMessageUpdateEvent, + RawReactionActionEvent, + RawReactionClearEmojiEvent, + RawReactionClearEvent, +) +from discord.reaction import Reaction +from discord.role import Role +from discord.user import ClientUser, User + +from discord.channel import DMChannel, TextChannel, _channel_factory +from discord.guild import Guild +from discord.member import Member +from discord.message import Message + +log = logging.getLogger(__name__) + + +class State: + def __init__( + self, *, dispatch, handlers, hooks, http, loop, redis=None, shard_count=None, id, **options + ): + self.dispatch = dispatch + self.handlers = handlers + self.hooks = hooks + self.http = http + self.loop = loop + self.redis = redis + self.shard_count = shard_count + self.id = id + + self._ready_task = None + self._ready_state = None + self._ready_timeout = options.get("guild_ready_timeout", 2.0) + + self._voice_clients = {} + self._private_channels_by_user = {} + + self.allowed_mentions = options.get("allowed_mentions") + + self.parsers = {} + for attr, func in inspect.getmembers(self): + if attr.startswith("parse_"): + self.parsers[attr[6:].upper()] = func + + def _loads(self, value, decode): + if value is None: + return value + + if not decode: + return value.decode("utf-8") + + try: + return orjson.loads(value) + except orjson.JSONDecodeError: + return value.decode("utf-8") + + def _dumps(self, value): + if isinstance(value, (str, int, float)): + return value + return orjson.dumps(value).decode("utf-8") + + async def delete(self, key): + return await self.redis.delete(key) + + async def get(self, keys, decode=True): + results = [] + if isinstance(keys, (list, tuple)): + if len(keys) == 0: + return [] + results.extend([self._loads(x, decode) for x in await self.redis.mget(*keys)]) + else: + results.append(self._loads(await self.redis.get(keys), decode)) + + for index, value in enumerate(results): + if isinstance(value, dict): + value["_key"] = keys[index] if isinstance(keys, (list, tuple)) else keys + + results[index] = value + + if isinstance(keys, (list, tuple)): + return [x for x in results if x is not None] + + return results[0] + + async def expire(self, key, time): + return await self.redis.expire(key, time) + + async def set(self, key, value=None): + if isinstance(key, (list, tuple)): + return await self.redis.mset(*key) + + return await self.redis.set(key, self._dumps(value)) + + async def sadd(self, key, *value): + return await self.redis.sadd(key, *[self._dumps(x) for x in value]) + + async def srem(self, key, *value): + return await self.redis.srem(key, *[self._dumps(x) for x in value]) + + async def smembers(self, key, decode=True): + return [self._loads(x, decode) for x in await self.redis.smembers(key)] + + async def sismember(self, key, value): + return await self.redis.sismember(key, self._dumps(value)) + + async def scard(self, key): + return await self.redis.scard(key) + + async def _members(self, key, key_id=None): + key += "_keys" + + if key_id: + key += f":{key_id}" + + return [x.decode("utf-8") for x in await self.redis.smembers(key)] + + async def _members_get( + self, key, key_id=None, name=None, first=None, second=None, predicate=None + ): + for match in await self._members(key, key_id): + keys = match.split(":") + if name is None or keys[0] == str(name): + if first is None or (len(keys) >= 2 and keys[1] == str(first)): + if second is None or (len(keys) >= 3 and keys[2] == str(second)): + if predicate is None or predicate(match) is True: + return await self.get(match) + + return None + + async def _members_get_all( + self, key, key_id=None, name=None, first=None, second=None, predicate=None + ): + matches = [] + for match in await self._members(key, key_id): + keys = match.split(":") + if name is None or keys[0] == str(name): + if first is None or (len(keys) >= 1 and keys[1] == str(first)): + if second is None or (len(keys) >= 2 and keys[2] == str(second)): + if predicate is None or predicate(match) is True: + matches.append(match) + + return await self.get(matches) + + def _key_first(self, obj): + keys = obj["_key"].split(":") + return int(keys[1]) + + async def _users(self): + user_ids = set([x.split(":")[2] for x in await self._members("member")]) + return [User(state=self, data=x["user"]) for x in await self.get(user_ids)] + + async def _emojis(self): + results = await self._members_get_all("emoji") + emojis = [] + + for result in results: + guild = await self._get_guild(self._key_first(result)) + + if guild: + emojis.append(Emoji(guild=guild, state=self, data=result)) + + return emojis + + async def _guilds(self): + guilds = [Guild(state=self, data=x) for x in await self._members_get_all("guild")] + return [x for x in guilds if not x.unavailable] + + async def _private_channels(self): + return [] + + async def _messages(self): + messages = [] + for result in await self._members_get_all("message"): + channel = await self.get_channel(int(result["channel_id"])) + + if channel: + message = Message(channel=channel, state=self, data=result) + messages.append(message) + + return messages + + def process_chunk_requests(self, guild_id, nonce, members, complete): + return + + def call_handlers(self, key, *args, **kwargs): + try: + func = self.handlers[key] + except KeyError: + pass + else: + func(*args, **kwargs) + + async def call_hooks(self, key, *args, **kwargs): + try: + func = self.hooks[key] + except KeyError: + pass + else: + await func(*args, **kwargs) + + async def user(self): + result = await self.get("bot_user") + if result: + return ClientUser(state=self, data=result) + return None + + def self_id(self): + return self.id + + @property + def intents(self): + return + + @property + def voice_clients(self): + return + + def _get_voice_client(self, guild_id): + return + + def _add_voice_client(self, guild_id, voice): + return + + def _remove_voice_client(self, guild_id): + return + + def _update_references(self, ws): + return + + def store_user(self, data): + return User(state=self, data=data) + + async def get_user(self, user_id): + result = await self._members_get("member", second=user_id) + + if result: + return User(state=self, data=result["user"]) + + return None + + def store_emoji(self, guild, data): + return Emoji(guild=guild, state=self, data=data) + + async def guilds(self): + return await self._guilds() + + async def _get_guild(self, guild_id): + result = await self.get(f"guild:{guild_id}") + + if result: + guild = Guild(state=self, data=result) + if not guild.unavailable: + return guild + + return None + + def _add_guild(self, guild): + return + + def _remove_guild(self, guild): + return + + async def emojis(self): + return await self._emojis() + + async def get_emoji(self, emoji_id): + result = await self._members_get("emoji", second=emoji_id) + + if result: + guild = await self._get_guild(self._key_first(result)) + + if guild: + return Emoji(guild=guild, state=self, data=result) + + return None + + async def private_channels(self): + return await self._private_channels() + + async def _get_private_channel(self, channel_id): + result = await self._get_channel(channel_id) + if result and isinstance(result, DMChannel): + return result + + return None + + async def _get_private_channel_by_user(self, user_id): + return utils.find(lambda x: x.recipient.id == user_id, await self.private_channels()) + + def _add_private_channel(self, channel): + return + + def add_dm_channel(self, data): + return DMChannel(me=self.user, state=self, data=data) + + def _remove_private_channel(self, channel): + return + + async def _get_message(self, msg_id): + result = await self._members_get("message", second=msg_id) + + if result: + channel = await self.get_channel(self._key_first(result)) + + if channel: + result = Message(channel=channel, state=self, data=result) + + return result + + def _add_guild_from_data(self, guild): + return Guild(state=self, data=guild) + + def _guild_needs_chunking(self, guild): + return + + async def _get_guild_channel(self, channel_id): + result = await self._get_channel(channel_id) + if result and not isinstance(result, DMChannel): + return result + + return None + + async def chunker(self, guild_id, query="", limit=0, *, nonce=None): + return + + async def query_members(self, guild, query, limit, user_ids, cache): + return + + async def _delay_ready(self): + try: + while True: + try: + guild = await asyncio.wait_for( + self._ready_state.get(), timeout=self._ready_timeout + ) + except asyncio.TimeoutError: + break + else: + if guild.unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + try: + del self._ready_state + except AttributeError: + pass + except asyncio.CancelledError: + pass + else: + self.call_handlers("ready") + self.dispatch("ready") + finally: + self._ready_task = None + + async def parse_ready(self, data, old): + if self._ready_task is not None: + self._ready_task.cancel() + + self.dispatch("connect") + self._ready_state = asyncio.Queue() + self._ready_task = asyncio.ensure_future(self._delay_ready(), loop=self.loop) + + async def parse_resumed(self, data, old): + self.dispatch("resumed") + + async def parse_message_create(self, data, old): + channel = await self.get_channel(int(data["channel_id"])) + + if not channel and not data.get("guild_id"): + channel = DMChannel(me=await self.user(), state=self, data={"id": data["channel_id"]}) + + if channel: + message = self.create_message(channel=channel, data=data) + self.dispatch("message", message) + + async def parse_message_delete(self, data, old): + raw = RawMessageDeleteEvent(data) + + if old: + channel = await self.get_channel(int(data["channel_id"])) + if channel: + old = self.create_message(channel=channel, data=old) + raw.cached_message = old + self.dispatch("message_delete", old) + + self.dispatch("raw_message_delete", raw) + + async def parse_message_delete_bulk(self, data, old): + raw = RawBulkMessageDeleteEvent(data) + + if old: + messages = [] + for old_message in old: + channel = await self.get_channel(int(old_message["channel_id"])) + if channel: + messages.append(self.create_message(channel=channel, data=old_message)) + + raw.cached_messages = old + self.dispatch("bulk_message_delete", old) + + self.dispatch("raw_bulk_message_delete", raw) + + async def parse_message_update(self, data, old): + raw = RawMessageUpdateEvent(data) + + if old: + channel = await self.get_channel(int(data["channel_id"])) + if channel: + old = self.create_message(channel=channel, data=old) + raw.cached_message = old + new = copy.copy(old) + new._update(data) + self.dispatch("message_edit", old, new) + + self.dispatch("raw_message_edit", raw) + + async def parse_message_reaction_add(self, data, old): + emoji = PartialEmoji.with_state( + self, + id=utils._get_as_snowflake(data["emoji"], "id"), + animated=data["emoji"].get("animated", False), + name=data["emoji"]["name"], + ) + + raw = RawReactionActionEvent(data, emoji, "REACTION_ADD") + + member = data.get("member") + if member: + guild = await self._get_guild(raw.guild_id) + if guild: + raw.member = Member(guild=guild, state=self, data=member) + + self.dispatch("raw_reaction_add", raw) + + message = await self._get_message(raw.message_id) + if message: + reaction = Reaction( + message=message, data=data, emoji=await self._upgrade_partial_emoji(emoji) + ) + user = raw.member or await self._get_reaction_user(message.channel, raw.user_id) + + if user: + self.dispatch("reaction_add", reaction, user) + + async def parse_message_reaction_remove_all(self, data, old): + raw = RawReactionClearEvent(data) + self.dispatch("raw_reaction_clear", raw) + + message = await self._get_message(raw.message_id) + if message: + self.dispatch("reaction_clear", message, None) + + async def parse_message_reaction_remove(self, data, old): + emoji = PartialEmoji.with_state( + self, + id=utils._get_as_snowflake(data["emoji"], "id"), + name=data["emoji"]["name"], + ) + + raw = RawReactionActionEvent(data, emoji, "REACTION_REMOVE") + self.dispatch("raw_reaction_remove", raw) + + message = await self._get_message(raw.message_id) + if message: + reaction = Reaction( + message=message, data=data, emoji=await self._upgrade_partial_emoji(emoji) + ) + user = await self._get_reaction_user(message.channel, raw.user_id) + + if user: + self.dispatch("reaction_remove", reaction, user) + + async def parse_message_reaction_remove_emoji(self, data, old): + emoji = PartialEmoji.with_state( + self, + id=utils._get_as_snowflake(data["emoji"], "id"), + name=data["emoji"]["name"], + ) + + raw = RawReactionClearEmojiEvent(data, emoji) + self.dispatch("raw_reaction_clear_emoji", raw) + + message = await self._get_message(raw.message_id) + if message: + reaction = Reaction( + message=message, data=data, emoji=await self._upgrade_partial_emoji(emoji) + ) + self.dispatch("reaction_clear_emoji", reaction) + + async def parse_presence_update(self, data, old): + guild = await self._get_guild(utils._get_as_snowflake(data, "guild_id")) + + if not guild: + return + + old_member = None + member = await guild.get_member(int(data["user"]["id"])) + if member and old: + old_member = Member._copy(member) + user_update = old_member._presence_update(data=old, user=old["user"]) + + if user_update: + self.dispatch("user_update", user_update[1], user_update[0]) + + self.dispatch("member_update", old_member, member) + + async def parse_user_update(self, data, old): + return + + async def parse_invite_create(self, data, old): + invite = Invite.from_gateway(state=self, data=data) + self.dispatch("invite_create", invite) + + async def parse_invite_delete(self, data, old): + invite = Invite.from_gateway(state=self, data=data) + self.dispatch("invite_delete", invite) + + async def parse_channel_delete(self, data, old): + if old and old["guild_id"]: + guild = await self._get_guild(utils._get_as_snowflake(data, "guild_id")) + if guild: + factory, _ = _channel_factory(old["type"]) + channel = factory(guild=guild, state=self, data=old) + self.dispatch("guild_channel_delete", channel) + elif old: + channel = DMChannel(me=self.user, state=self, data=old) + self.dispatch("private_channel_delete", channel) + + async def parse_channel_update(self, data, old): + channel_type = try_enum(ChannelType, data.get("type")) + if old and channel_type is ChannelType.private: + channel = DMChannel(me=self.user, state=self, data=data) + old_channel = DMChannel(me=self.user, state=self, data=old) + self.dispatch("private_channel_update", old_channel, channel) + elif old: + guild = await self._get_guild(utils._get_as_snowflake(data, "guild_id")) + if guild: + factory, _ = _channel_factory(data["type"]) + channel = factory(guild=guild, state=self, data=data) + old_factory, _ = _channel_factory(old["type"]) + old_channel = old_factory(guild=guild, state=self, data=old) + self.dispatch("guild_channel_update", old_channel, channel) + + async def parse_channel_create(self, data, old): + factory, ch_type = _channel_factory(data["type"]) + if ch_type is ChannelType.private: + channel = DMChannel(me=self.user, data=data, state=self) + self.dispatch("private_channel_create", channel) + else: + guild = await self._get_guild(utils._get_as_snowflake(data, "guild_id")) + if guild: + channel = factory(guild=guild, state=self, data=data) + self.dispatch("guild_channel_create", channel) + + async def parse_channel_pins_update(self, data, old): + channel = await self.get_channel(int(data["channel_id"])) + last_pin = ( + utils.parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None + ) + + try: + channel.guild + except AttributeError: + self.dispatch("private_channel_pins_update", channel, last_pin) + else: + self.dispatch("guild_channel_pins_update", channel, last_pin) + + async def parse_channel_recipient_add(self, data, old): + return + + async def parse_channel_recipient_remove(self, data, old): + return + + async def parse_guild_member_add(self, data, old): + guild = await self._get_guild(int(data["guild_id"])) + if guild: + member = Member(guild=guild, data=data, state=self) + self.dispatch("member_join", member) + + async def parse_guild_member_remove(self, data, old): + if old: + guild = await self._get_guild(int(data["guild_id"])) + if guild: + member = Member(guild=guild, data=old, state=self) + self.dispatch("member_remove", member) + + async def parse_guild_member_update(self, data, old): + guild = await self._get_guild(int(data["guild_id"])) + if old and guild: + member = await guild.get_member(int(data["user"]["id"])) + if member: + old_member = Member._copy(member) + old_member._update(old) + user_update = old_member._update_inner_user(data["user"]) + + if user_update: + self.dispatch("user_update", user_update[1], user_update[0]) + + self.dispatch("member_update", old_member, member) + + async def parse_guild_emojis_update(self, data, old): + guild = await self._get_guild(int(data["guild_id"])) + if guild: + before_emojis = None + if old: + before_emojis = [self.store_emoji(guild, x) for x in old] + + after_emojis = tuple(map(lambda x: self.store_emoji(guild, x), data["emojis"])) + self.dispatch("guild_emojis_update", guild, before_emojis, after_emojis) + + def _get_create_guild(self, data): + return self._add_guild_from_data(data) + + async def chunk_guild(self, guild, *, wait=True, cache=None): + return + + async def _chunk_and_dispatch(self, guild, unavailable): + return + + async def parse_guild_create(self, data, old): + unavailable = data.get("unavailable") + + if unavailable is True: + return + + guild = self._get_create_guild(data) + try: + self._ready_state.put_nowait(guild) + except AttributeError: + if unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + + async def parse_guild_sync(self, data, old): + return + + async def parse_guild_update(self, data, old): + guild = await self._get_guild(int(data["id"])) + if guild: + old_guild = None + + if old: + old_guild = copy.copy(guild) + old_guild = old_guild._from_data(old) + + self.dispatch("guild_update", old_guild, guild) + + async def parse_guild_delete(self, data, old): + if old: + old = Guild(state=self, data=old) + if data.get("unavailable", False): + new = Guild(state=self, data=data) + self.dispatch("guild_unavailable", new) + else: + self.dispatch("guild_remove", old) + + async def parse_guild_ban_add(self, data, old): + guild = await self._get_guild(int(data["guild_id"])) + if guild: + user = self.store_user(data["user"]) + member = await guild.get_member(user.id) or user + self.dispatch("member_ban", guild, member) + + async def parse_guild_ban_remove(self, data, old): + guild = await self._get_guild(int(data["guild_id"])) + if guild: + self.dispatch("member_unban", guild, self.store_user(data["user"])) + + async def parse_guild_role_create(self, data, old): + guild = await self._get_guild(int(data["guild_id"])) + if guild: + role = Role(guild=guild, state=self, data=data["role"]) + self.dispatch("guild_role_create", role) + + async def parse_guild_role_delete(self, data, old): + if old: + guild = await self._get_guild(int(data["guild_id"])) + if guild: + role = Role(guild=guild, state=self, data=old) + self.dispatch("guild_role_delete", role) + + async def parse_guild_role_update(self, data, old): + if old: + guild = await self._get_guild(int(data["guild_id"])) + if guild: + role = Role(guild=guild, state=self, data=data["role"]) + old_role = Role(guild=guild, state=self, data=old) + self.dispatch("guild_role_update", old_role, role) + + async def parse_guild_members_chunk(self, data, old): + return + + async def parse_guild_integrations_update(self, data, old): + guild = await self._get_guild(int(data["guild_id"])) + if guild: + self.dispatch("guild_integrations_update", guild) + + async def parse_webhooks_update(self, data, old): + channel = await self._get_guild(int(data["channel_id"])) + if channel: + self.dispatch("webhooks_update", channel) + + async def parse_voice_state_update(self, data, old): + guild = await self._get_guild(utils._get_as_snowflake(data, "guild_id")) + if guild: + member = await guild.get_member(int(data["user_id"])) + if member: + channel = await self.get_channel(utils._get_as_snowflake(data, "channel_id")) + if channel: + before = None + after = VoiceState(data=data, channel=channel) + old_channel = await self.get_channel(old["channel_id"]) + + if old and old_channel: + before = VoiceState(data=data, channel=old_channel) + + self.dispatch("voice_state_update", member, before, after) + + def parse_voice_server_update(self, data, old): + return + + async def parse_typing_start(self, data, old): + channel = await self._get_guild_channel(int(data["channel_id"])) + if channel: + member = None + + if isinstance(channel, DMChannel): + member = channel.recipient + elif isinstance(channel, TextChannel): + guild = await self._get_guild(int(data["guild_id"])) + if guild: + member = await guild.get_member(utils._get_as_snowflake(data, "user_id")) + + if member: + self.dispatch( + "typing", + channel, + member, + datetime.datetime.utcfromtimestamp(data.get("timestamp")), + ) + + async def parse_relationship_add(self, data, old): + return + + async def parse_relationship_remove(self, data, old): + return + + async def _get_reaction_user(self, channel, user_id): + if isinstance(channel, TextChannel): + return await channel.guild.get_member(user_id) + return await self.get_user(user_id) + + async def get_reaction_emoji(self, data): + emoji_id = utils._get_as_snowflake(data, "id") + + if not emoji_id: + return data["name"] + + return await self.get_emoji(emoji_id) + + async def _upgrade_partial_emoji(self, emoji): + if not emoji.id: + return emoji.name + + return await self.get_emoji(emoji.id) + + async def _get_channel(self, channel_id): + result = await self.get(f"channel:{channel_id}") + + if result: + if result.get("guild_id"): + factory, _ = _channel_factory(result["type"]) + guild = await self._get_guild(result["guild_id"]) + + if guild: + return factory(guild=guild, state=self, data=result) + else: + return DMChannel(me=self.user, state=self, data=result) + + return None + + async def get_channel(self, channel_id): + if not channel_id: + return None + + return await self._get_channel(channel_id) + + def create_message(self, *, channel, data): + message = Message(state=self, channel=channel, data=data) + return message diff --git a/classes/bot.py b/classes/bot.py index 990e8045..59b9eec3 100644 --- a/classes/bot.py +++ b/classes/bot.py @@ -1,4 +1,5 @@ import asyncio +import discord import logging import re import sys @@ -6,7 +7,7 @@ import aio_pika import aiohttp -import aioredis +from redis import asyncio as aioredis import asyncpg import orjson @@ -22,11 +23,13 @@ from utils.config import Config from utils.prometheus import Prometheus -log = logging.getLogger(__name__) + +logger = logging.getLogger(__name__) + class ModMail(commands.AutoShardedBot): - def __init__(self, command_prefix=None, **kwargs): + def __init__(self, command_prefix=None, intents: discord.Intents=None, **kwargs): self.command_prefix = command_prefix self.extra_events = {} self._BotBase__cogs = {} @@ -43,10 +46,11 @@ def __init__(self, command_prefix=None, **kwargs): self.case_insensitive = True self.all_commands = _CaseInsensitiveDict() if self.case_insensitive else {} self.strip_after_prefix = False + self.ws = None self.loop = asyncio.get_event_loop() - self.http = HTTPClient(None, loop=self.loop) + self.http = HTTPClient(loop=self.loop) self._handlers = {"ready": self._handle_ready} self._hooks = {} @@ -61,7 +65,7 @@ def __init__(self, command_prefix=None, **kwargs): self._amqp_queue = None self.config = Config() - self.session = aiohttp.ClientSession(loop=self.loop) + self.session = aiohttp.ClientSession(loop=asyncio.get_event_loop()) self.http_uri = f"http://{self.config.BOT_API_HOST}:{self.config.BOT_API_PORT}" self.id = kwargs.get("bot_id") self.cluster = kwargs.get("cluster_id") @@ -90,58 +94,19 @@ def __init__(self, command_prefix=None, **kwargs): "premium", "snippet", ] - - @property - def state(self): - return self._connection + print(kwargs) + super().__init__(command_prefix=command_prefix, intents=intents, **kwargs) @property def user(self): return tools.create_fake_user(self.id) + + @property + def state(self): + return self._connection async def real_user(self): - return await self._connection.user() - - async def users(self): - return await self._connection._users() - - async def guilds(self): - return await self._connection.guilds() - - async def emojis(self): - return await self._connection.emojis() - - async def cached_messages(self): - return await self._connection._messages() - - async def private_channels(self): - return await self._connection.private_channels() - - async def shard_count(self): - return int(await self._connection.get("gateway_shards")) - - async def started(self): - return parse_time(str(await self._connection.get("gateway_started")).split(".")[0]) - - async def statuses(self): - return [Status(x) for x in await self._connection.get("gateway_statuses")] - - async def sessions(self): - return { - int(x): Session(y) for x, y in (await self._connection.get("gateway_sessions")).items() - } - - async def get_channel(self, channel_id): - return await self._connection.get_channel(channel_id) - - async def get_guild(self, guild_id): - return await self._connection._get_guild(guild_id) - - async def get_user(self, user_id): - return await self._connection.get_user(user_id) - - async def get_emoji(self, emoji_id): - return await self._connection.get_emoji(emoji_id) + return await self._connection.get_me() async def get_all_channels(self): pass @@ -150,6 +115,7 @@ async def get_all_members(self): pass async def receive_message(self, msg): + logger.info(msg) self.ws._dispatch("socket_raw_receive", msg) msg = orjson.loads(msg) self.ws._dispatch("socket_response", msg) @@ -165,7 +131,7 @@ async def receive_message(self, msg): try: func = self.ws._discord_parsers[event] except KeyError: - log.debug(f"Unknown event {event}.") + logger.debug(f"Unknown event {event}.") return if event not in self._enabled_events: @@ -195,7 +161,7 @@ async def on_http_request_end(self, _session, trace_config_ctx, params): elapsed = asyncio.get_event_loop().time() - trace_config_ctx.start if elapsed > 1: - log.warning(f"{params.method} {params.url} took {round(elapsed, 2)} seconds") + logger.warning(f"{params.method} {params.url} took {round(elapsed, 2)} seconds") route = str(params.url) route = re.sub(r"https:\/\/[a-z\.]+\/api\/v[0-9]+", "", route) @@ -214,15 +180,16 @@ async def on_http_request_end(self, _session, trace_config_ctx, params): ) async def start(self, worker=True): + logger.info("In Setup") trace_config = aiohttp.TraceConfig() trace_config.on_request_start.append(self.on_http_request_start) trace_config.on_request_end.append(self.on_http_request_end) - self.http._HTTPClient__session = aiohttp.ClientSession( - connector=self.http.connector, + self.http.__session = aiohttp.ClientSession( ws_response_class=DiscordClientWebSocketResponse, trace_configs=[trace_config], ) - self.http._token(self.config.BOT_TOKEN, bot=True) + self.http.token = self.config.BOT_TOKEN + self.pool = await asyncpg.create_pool( database=self.config.POSTGRES_DATABASE, @@ -233,14 +200,10 @@ async def start(self, worker=True): max_size=10, command_timeout=60, ) - - self._redis = await aioredis.create_redis_pool( - (self.config.REDIS_HOST, int(self.config.REDIS_PORT)), - password=self.config.REDIS_PASSWORD, - minsize=5, - maxsize=10, - loop=self.loop, + pool = aioredis.ConnectionPool.from_url( + f'redis://{self.config.REDIS_HOST}:{self.config.REDIS_PORT}' ) + self._redis = aioredis.Redis(connection_pool=pool) if worker: self._amqp = await aio_pika.connect_robust( @@ -251,20 +214,32 @@ async def start(self, worker=True): ) self._amqp_channel = await self._amqp.channel() self._amqp_queue = await self._amqp_channel.get_queue("gateway.recv") + + logger.info(self._cogs) + self.prom = Prometheus(self) - await self.prom.start() - + # await self.prom.start() # TODO: Fix Prometheus + while True: + try: + shards = int(await self._redis.get("gateway_shards")) + except TypeError: + await asyncio.sleep(2) + continue + break + logger.debug(f"Shards: {shards}") + self._connection = State( id=self.id, dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, - loop=self.loop, redis=self._redis, - shard_count=int(await self._redis.get("gateway_shards")), + shard_count=shards, + intents=self.intents, ) + self._connection.id = (await self._connection.get_me()).id self._connection._get_client = lambda: self self.ws = DiscordWebSocket(socket=None, loop=self.loop) @@ -273,19 +248,34 @@ async def start(self, worker=True): self.ws._discord_parsers = self._connection.parsers self.ws._dispatch = self.dispatch self.ws.call_hooks = self._connection.call_hooks + + logger.info('right here') if not worker: + # await self.login(self.config.BOT_TOKEN) + # await self.connect(reconnect=True) return - + + for extension in self._cogs: try: - self.load_extension("cogs." + extension) + await self.load_extension("cogs." + extension) + logger.info(f"Loaded extension {extension}.") except Exception: - log.error(f"Failed to load extension {extension}.", file=sys.stderr) - log.error(traceback.print_exc()) - + logger.error(f"Failed to load extension {extension}.") + logger.error(traceback.print_exc()) + async with self._amqp_queue.iterator() as queue_iter: async for message in queue_iter: + logger.info(message) async with message.process(ignore_processed=True): await self.receive_message(message.body) await message.ack() + + + + + + + + diff --git a/classes/channel.py b/classes/channel.py index d1586558..1df1ad2e 100644 --- a/classes/channel.py +++ b/classes/channel.py @@ -1,12 +1,12 @@ import logging from discord import channel, utils -from discord.channel import CategoryChannel, GroupChannel, StageChannel, StoreChannel, VoiceChannel +from discord.channel import CategoryChannel, GroupChannel, StageChannel, VoiceChannel from discord.enums import ChannelType, try_enum from discord.permissions import Permissions from classes.embed import Embed -from classes.invite import Invite +from discord.invite import Invite log = logging.getLogger(__name__) @@ -121,8 +121,6 @@ def _channel_factory(channel_type): return GroupChannel, value elif value is ChannelType.news: return TextChannel, value - elif value is ChannelType.store: - return StoreChannel, value elif value is ChannelType.stage_voice: return StageChannel, value else: diff --git a/classes/embed.py b/classes/embed.py index a9fe7fd5..f39190ca 100644 --- a/classes/embed.py +++ b/classes/embed.py @@ -21,16 +21,16 @@ def __init__(self, *args, **kwargs): super().__init__(**kwargs) - def set_author(self, name=discord.Embed.Empty, icon_url=discord.Embed.Empty, **kwargs): + def set_author(self, name=None, icon_url=None, **kwargs): super().set_author(name=name, icon_url=icon_url, **kwargs) - def set_footer(self, text=discord.Embed.Empty, icon_url=discord.Embed.Empty): + def set_footer(self, text=None, icon_url=None): super().set_footer(text=text, icon_url=icon_url) - def set_thumbnail(self, url=discord.Embed.Empty): + def set_thumbnail(self, url=None): super().set_thumbnail(url=url) - def add_field(self, name=discord.Embed.Empty, value=discord.Embed.Empty, inline=True): + def add_field(self, name=None, value=None, inline=True): super().add_field(name=name, value=value, inline=inline) diff --git a/classes/guild.py b/classes/guild.py index d80b0c0e..3aec64f1 100644 --- a/classes/guild.py +++ b/classes/guild.py @@ -8,7 +8,6 @@ ContentFilter, NotificationLevel, VerificationLevel, - VoiceRegion, try_enum, ) from discord.member import VoiceState @@ -23,8 +22,11 @@ class Guild(guild.Guild): def __init__(self, *, data, state): + super().__init__(state=state, data=data) + return self._state = state self._from_data(data) + def _add_channel(self, channel): return @@ -47,47 +49,46 @@ def _add_role(self, role): def _remove_role(self, role_id): return - def _from_data(self, guild): - member_count = guild.get("member_count", None) - if member_count is not None: - self._member_count = member_count - else: - self._member_count = 0 - - self.name = guild.get("name") - self.region = try_enum(VoiceRegion, guild.get("region")) - self.verification_level = try_enum(VerificationLevel, guild.get("verification_level")) - self.default_notifications = try_enum( - NotificationLevel, guild.get("default_message_notifications") - ) - self.explicit_content_filter = try_enum( - ContentFilter, guild.get("explicit_content_filter", 0) - ) - self.afk_timeout = guild.get("afk_timeout") - self.icon = guild.get("icon") - self.banner = guild.get("banner") - self.unavailable = guild.get("unavailable", False) - self.id = int(guild["id"]) - self.mfa_level = guild.get("mfa_level") - self.features = guild.get("features", []) - self.splash = guild.get("splash") - self._system_channel_id = utils._get_as_snowflake(guild, "system_channel_id") - self.description = guild.get("description") - self.max_presences = guild.get("max_presences") - self.max_members = guild.get("max_members") - self.max_video_channel_users = guild.get("max_video_channel_users") - self.premium_tier = guild.get("premium_tier", 0) - self.premium_subscription_count = guild.get("premium_subscription_count") or 0 - self._system_channel_flags = guild.get("system_channel_flags", 0) - self.preferred_locale = guild.get("preferred_locale") - self.discovery_splash = guild.get("discovery_splash") - self._rules_channel_id = utils._get_as_snowflake(guild, "rules_channel_id") - self._public_updates_channel_id = utils._get_as_snowflake( - guild, "public_updates_channel_id" - ) - self._large = None if member_count is None else self._member_count >= 250 - self.owner_id = utils._get_as_snowflake(guild, "owner_id") - self._afk_channel_id = utils._get_as_snowflake(guild, "afk_channel_id") + # def _from_data(self, guild): + # member_count = guild.get("member_count", None) + # if member_count is not None: + # self._member_count = member_count + # else: + # self._member_count = 0 + + # self.name = guild.get("name") + # self.verification_level = try_enum(VerificationLevel, guild.get("verification_level")) + # self.default_notifications = try_enum( + # NotificationLevel, guild.get("default_message_notifications") + # ) + # self.explicit_content_filter = try_enum( + # ContentFilter, guild.get("explicit_content_filter", 0) + # ) + # self.afk_timeout = guild.get("afk_timeout") + # self._icon = guild.get("icon") + # self._banner = guild.get("banner") + # self.unavailable = guild.get("unavailable", False) + # self.id = int(guild["id"]) + # self.mfa_level = guild.get("mfa_level") + # self.features = guild.get("features", []) + # self.splash = guild.get("splash") + # self._system_channel_id = utils._get_as_snowflake(guild, "system_channel_id") + # self.description = guild.get("description") + # self.max_presences = guild.get("max_presences") + # self.max_members = guild.get("max_members") + # self.max_video_channel_users = guild.get("max_video_channel_users") + # self.premium_tier = guild.get("premium_tier", 0) + # self.premium_subscription_count = guild.get("premium_subscription_count") or 0 + # self._system_channel_flags = guild.get("system_channel_flags", 0) + # self.preferred_locale = guild.get("preferred_locale") + # self.discovery_splash = guild.get("discovery_splash") + # self._rules_channel_id = utils._get_as_snowflake(guild, "rules_channel_id") + # self._public_updates_channel_id = utils._get_as_snowflake( + # guild, "public_updates_channel_id" + # ) + # self._large = None if member_count is None else self._member_count >= 250 + # self.owner_id = utils._get_as_snowflake(guild, "owner_id") + # self._afk_channel_id = utils._get_as_snowflake(guild, "afk_channel_id") async def create_text_channel( self, name, *, overwrites=None, category=None, reason=None, **options diff --git a/classes/message.py b/classes/message.py index 40d57ff1..aa1e7ebc 100644 --- a/classes/message.py +++ b/classes/message.py @@ -4,13 +4,14 @@ from discord import message, utils from discord.enums import MessageType, try_enum from discord.flags import MessageFlags +from discord.guild import Guild from discord.message import Attachment, MessageReference, flatten_handlers from discord.reaction import Reaction from classes.embed import Embed from classes.member import Member -log = logging.getLogger(__name__) +logger = logging.getLogger(__name__) @flatten_handlers @@ -33,6 +34,7 @@ def __init__(self, *, state, channel, data): self.tts = data["tts"] self.content = data["content"] self.nonce = data.get("nonce") + self.guild = utils. ref = copy.copy(data.get("message_reference")) self.reference = MessageReference.with_state(state, ref) if ref is not None else None @@ -47,6 +49,7 @@ def __init__(self, *, state, channel, data): try: author._update_from_message(self._data["member"]) except AttributeError: + logger.info(self) author = Member._from_message(message=self, data=self._data["member"]) self._member = author except KeyError: @@ -57,6 +60,8 @@ def __init__(self, *, state, channel, data): getattr(self, f"_handle_{handler}")(data[handler]) except KeyError: continue + + super().__init__(state=state, channel=channel, data=data) @property def author(self): diff --git a/classes/state.py b/classes/state.py index 5d6a4b3a..f22072f1 100644 --- a/classes/state.py +++ b/classes/state.py @@ -1,12 +1,44 @@ import asyncio import copy import datetime +import discord import inspect import logging - +from redis import asyncio as aioredis import orjson - +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Deque, + Dict, + Generic, + List, + Literal, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + overload, +) from discord import utils +from discord._types import ClientT +from discord.channel import * +from discord.channel import _channel_factory +from discord.emoji import Emoji +from discord.enums import ChannelType, Status, try_enum +from discord.guild import Guild +from discord.invite import Invite +from discord.member import Member +from discord.message import Message +from discord.partial_emoji import PartialEmoji +from discord.raw_models import * +from discord.role import Role +from discord.user import ClientUser, User +from discord import utils, state +from discord.channel import * from discord.emoji import Emoji from discord.enums import ChannelType, try_enum from discord.invite import Invite @@ -19,46 +51,45 @@ RawReactionActionEvent, RawReactionClearEmojiEvent, RawReactionClearEvent, + RawMemberRemoveEvent ) from discord.reaction import Reaction from discord.role import Role from discord.user import ClientUser, User -from classes.channel import DMChannel, TextChannel, _channel_factory +from classes.channel import DMChannel from classes.guild import Guild from classes.member import Member from classes.message import Message -log = logging.getLogger(__name__) +# if TYPE_CHECKING: +# from discord.types import gateway as gw +# from discord.types.user import PartialUser as PartialUserPayload +# from discord.types.user import User as UserPayload +# T = TypeVar('T') +# Channel = Union[GuildChannel, PrivateChannel, PartialMessageable] +logger = logging.getLogger(__name__) -class State: + + +# Overriding AutoShardedConnectionState +# Replace built in cache with redis cache +# Overrides methods related to: +# Users, Guilds, Channels, Messages + +class State(state.AutoShardedConnectionState): def __init__( - self, *, dispatch, handlers, hooks, http, loop, redis=None, shard_count=None, id, **options + self, *, dispatch, handlers, hooks, http, redis: aioredis.Redis = None, shard_count=None, id, **options ): - self.dispatch = dispatch - self.handlers = handlers - self.hooks = hooks - self.http = http - self.loop = loop + super().__init__(dispatch=dispatch, handlers=handlers, hooks=hooks, http=http, **options) + + # self.loop = loop self.redis = redis self.shard_count = shard_count - self.id = id - - self._ready_task = None - self._ready_state = None - self._ready_timeout = options.get("guild_ready_timeout", 2.0) - - self._voice_clients = {} - self._private_channels_by_user = {} - - self.allowed_mentions = options.get("allowed_mentions") - - self.parsers = {} - for attr, func in inspect.getmembers(self): - if attr.startswith("parse_"): - self.parsers[attr[6:].upper()] = func + + ######### REDIS METHODS ######### def _loads(self, value, decode): if value is None: return value @@ -91,12 +122,11 @@ async def get(self, keys, decode=True): for index, value in enumerate(results): if isinstance(value, dict): value["_key"] = keys[index] if isinstance(keys, (list, tuple)) else keys - results[index] = value if isinstance(keys, (list, tuple)): return [x for x in results if x is not None] - + return results[0] async def expire(self, key, time): @@ -104,7 +134,8 @@ async def expire(self, key, time): async def set(self, key, value=None): if isinstance(key, (list, tuple)): - return await self.redis.mset(*key) + kvd = dict(list(zip(key[::2], key[1::2]))) + return await self.redis.mset(kvd) return await self.redis.set(key, self._dumps(value)) @@ -122,33 +153,20 @@ async def sismember(self, key, value): async def scard(self, key): return await self.redis.scard(key) - - async def _members(self, key, key_id=None): + + def _members(self, key, key_id=None): key += "_keys" if key_id: key += f":{key_id}" - return [x.decode("utf-8") for x in await self.redis.smembers(key)] - - async def _members_get( - self, key, key_id=None, name=None, first=None, second=None, predicate=None - ): - for match in await self._members(key, key_id): - keys = match.split(":") - if name is None or keys[0] == str(name): - if first is None or (len(keys) >= 2 and keys[1] == str(first)): - if second is None or (len(keys) >= 3 and keys[2] == str(second)): - if predicate is None or predicate(match) is True: - return await self.get(match) - - return None - - async def _members_get_all( + return [x.decode("utf-8") for x in self.redis.smembers(key)] + + def _members_get_all( self, key, key_id=None, name=None, first=None, second=None, predicate=None ): matches = [] - for match in await self._members(key, key_id): + for match in self._members(key, key_id): keys = match.split(":") if name is None or keys[0] == str(name): if first is None or (len(keys) >= 1 and keys[1] == str(first)): @@ -156,12 +174,10 @@ async def _members_get_all( if predicate is None or predicate(match) is True: matches.append(match) - return await self.get(matches) - - def _key_first(self, obj): - keys = obj["_key"].split(":") - return int(keys[1]) - + return self.get(matches) + + ######### END REDIS METHODS ######### + async def _users(self): user_ids = set([x.split(":")[2] for x in await self._members("member")]) return [User(state=self, data=x["user"]) for x in await self.get(user_ids)] @@ -215,10 +231,11 @@ async def call_hooks(self, key, *args, **kwargs): else: await func(*args, **kwargs) - async def user(self): + async def get_me(self): result = await self.get("bot_user") if result: return ClientUser(state=self, data=result) + print("No bot_user in redis") return None def self_id(self): @@ -226,7 +243,7 @@ def self_id(self): @property def intents(self): - return + return discord.Intents.all() @property def voice_clients(self): @@ -348,7 +365,7 @@ async def _delay_ready(self): while True: try: guild = await asyncio.wait_for( - self._ready_state.get(), timeout=self._ready_timeout + await self._ready_state.get(), timeout=self._ready_timeout ) except asyncio.TimeoutError: break @@ -375,7 +392,7 @@ async def parse_ready(self, data, old): self.dispatch("connect") self._ready_state = asyncio.Queue() - self._ready_task = asyncio.ensure_future(self._delay_ready(), loop=self.loop) + self._ready_task = asyncio.ensure_future(await self._delay_ready(), loop=self.loop) async def parse_resumed(self, data, old): self.dispatch("resumed") @@ -384,7 +401,7 @@ async def parse_message_create(self, data, old): channel = await self.get_channel(int(data["channel_id"])) if not channel and not data.get("guild_id"): - channel = DMChannel(me=await self.user(), state=self, data={"id": data["channel_id"]}) + channel = DMChannel(me=await self.get_me(), state=self, data={"id": data["channel_id"]}) if channel: message = self.create_message(channel=channel, data=data) @@ -805,4 +822,5 @@ async def get_channel(self, channel_id): def create_message(self, *, channel, data): message = Message(state=self, channel=channel, data=data) - return message + # logger.info(message) + return message \ No newline at end of file diff --git a/cogs/admin.py b/cogs/admin.py index 1c6be4f4..4dbc97cd 100644 --- a/cogs/admin.py +++ b/cogs/admin.py @@ -130,5 +130,5 @@ async def restart(self, ctx): await self.bot.session.post(f"{self.bot.http_uri}/restart") -def setup(bot): - bot.add_cog(Admin(bot)) +async def setup(bot): + await bot.add_cog(Admin(bot)) diff --git a/cogs/configuration.py b/cogs/configuration.py index 9d226efe..49de3a45 100644 --- a/cogs/configuration.py +++ b/cogs/configuration.py @@ -470,5 +470,5 @@ async def viewconfig(self, ctx): await ctx.send(embed) -def setup(bot): - bot.add_cog(Configuration(bot)) +async def setup(bot): + await bot.add_cog(Configuration(bot)) diff --git a/cogs/core.py b/cogs/core.py index a46986f2..e832478d 100644 --- a/cogs/core.py +++ b/cogs/core.py @@ -346,5 +346,5 @@ async def viewblacklist(self, ctx): await tools.create_paginator(self.bot, ctx, all_pages) -def setup(bot): - bot.add_cog(Core(bot)) +async def setup(bot): + await bot.add_cog(Core(bot)) diff --git a/cogs/direct_message.py b/cogs/direct_message.py index a53193df..f98aa913 100644 --- a/cogs/direct_message.py +++ b/cogs/direct_message.py @@ -8,7 +8,7 @@ from discord.ext import commands from classes.embed import Embed, ErrorEmbed -from classes.message import Message +from discord.message import Message from utils import tools from utils.converters import GuildConverter @@ -454,5 +454,5 @@ async def confirmation(self, ctx): await ctx.send(Embed("Confirmation messages are enabled.")) -def setup(bot): - bot.add_cog(DirectMessageEvents(bot)) +async def setup(bot): + await bot.add_cog(DirectMessageEvents(bot)) diff --git a/cogs/error_handler.py b/cogs/error_handler.py index 11319bed..ff3e66a0 100644 --- a/cogs/error_handler.py +++ b/cogs/error_handler.py @@ -93,5 +93,5 @@ async def _on_command_error(self, ctx, error, bypass=False): pass -def setup(bot): - bot.add_cog(ErrorHandler(bot)) +async def setup(bot): + await bot.add_cog(ErrorHandler(bot)) diff --git a/cogs/events.py b/cogs/events.py index b1a2f824..0c756ebe 100644 --- a/cogs/events.py +++ b/cogs/events.py @@ -5,7 +5,7 @@ from discord.ext import commands -from classes.context import Context +from discord.ext.commands import Context from classes.embed import Embed, ErrorEmbed from utils import tools @@ -113,5 +113,5 @@ async def on_message(self, message): await self.bot.invoke(ctx) -def setup(bot): - bot.add_cog(Events(bot)) +async def setup(bot): + await bot.add_cog(Events(bot)) diff --git a/cogs/general.py b/cogs/general.py index 7f157222..7fc6fae1 100644 --- a/cogs/general.py +++ b/cogs/general.py @@ -24,7 +24,7 @@ def __init__(self, bot): usage="help [command]", aliases=["h", "commands"], ) - async def help(self, ctx, *, command: str = None): + async def help2(self, ctx, *, command: str = None): if command: command = self.bot.get_command(command.lower()) if not command: @@ -180,5 +180,5 @@ async def source(self, ctx): await ctx.send(Embed("GitHub Repository", "https://github.com/chamburr/modmail")) -def setup(bot): - bot.add_cog(General(bot)) +async def setup(bot): + await bot.add_cog(General(bot)) diff --git a/cogs/miscellaneous.py b/cogs/miscellaneous.py index 94127437..57efe8d6 100644 --- a/cogs/miscellaneous.py +++ b/cogs/miscellaneous.py @@ -99,5 +99,5 @@ async def serverinfo(self, ctx): await ctx.send(embed) -def setup(bot): - bot.add_cog(Miscellaneous(bot)) +async def setup(bot): + await bot.add_cog(Miscellaneous(bot)) diff --git a/cogs/modmail_channel.py b/cogs/modmail_channel.py index bc31b862..7dd156c8 100644 --- a/cogs/modmail_channel.py +++ b/cogs/modmail_channel.py @@ -87,12 +87,13 @@ async def send_mail_mod(self, message, prefix, anon=False, snippet=False): try: dm_message = await dm_channel.send(embed, files=files) - except discord.Forbidden: + except discord.Forbidden as e: await message.channel.send( ErrorEmbed( "The message could not be sent. The user might have disabled Direct Messages." ) ) + log.log(log.error, e) return embed.title = "Message Sent" @@ -118,5 +119,5 @@ async def send_mail_mod(self, message, prefix, anon=False, snippet=False): pass -def setup(bot): - bot.add_cog(ModMailEvents(bot)) +async def setup(bot): + await bot.add_cog(ModMailEvents(bot)) diff --git a/cogs/owner.py b/cogs/owner.py index 5cbd2224..4a982339 100644 --- a/cogs/owner.py +++ b/cogs/owner.py @@ -170,5 +170,5 @@ async def unbanserver(self, ctx, *, guild: int): await ctx.send(Embed("Successfully unbanned that server from the bot.")) -def setup(bot): - bot.add_cog(Owner(bot)) +async def setup(bot): + await bot.add_cog(Owner(bot)) diff --git a/cogs/premium.py b/cogs/premium.py index 8be51958..d57a5ae4 100644 --- a/cogs/premium.py +++ b/cogs/premium.py @@ -140,5 +140,5 @@ async def premiumremove(self, ctx, *, guild: int): await ctx.send(Embed("That server no longer has premium.")) -def setup(bot): - bot.add_cog(Premium(bot)) +async def setup(bot): + await bot.add_cog(Premium(bot)) diff --git a/cogs/snippet.py b/cogs/snippet.py index d920e138..37c4d03f 100644 --- a/cogs/snippet.py +++ b/cogs/snippet.py @@ -180,5 +180,5 @@ async def viewsnippet(self, ctx, *, name: str = None): await tools.create_paginator(self.bot, ctx, all_pages) -def setup(bot): - bot.add_cog(Snippet(bot)) +async def setup(bot): + await bot.add_cog(Snippet(bot)) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index f94094c4..75c13268 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -62,7 +62,7 @@ services: - DEFAULT_PREFIX=${DEFAULT_PREFIX} - DEFAULT_SERVER=${DEFAULT_SERVER} - BOT_CLUSTERS=1 - - MAIN_SERVER= + - MAIN_SERVER=538144211866746883 - OWNER_USERS=${OWNER_USERS} - ADMIN_USERS=${ADMIN_USERS} - PREMIUM1_ROLE=0 @@ -100,7 +100,7 @@ services: - RUST_LOG=info - BOT_TOKEN=${BOT_TOKEN} - SHARDS_START=0 - - SHARDS_END=1 + - SHARDS_END=0 - SHARDS_TOTAL=1 - SHARDS_CONCURRENCY=1 - SHARDS_WAIT=5 @@ -149,7 +149,10 @@ services: restart: unless-stopped rabbitmq: - image: rabbitmq:3-alpine + image: rabbitmq:3-management-alpine + ports: + - "5672:5672" + - "15672:15672" restart: unless-stopped volumes: diff --git a/main.py b/main.py index 83659364..63f31987 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,12 @@ import asyncio import json +import logging import os import signal import sys import time -from datetime import datetime +import datetime from pathlib import Path import aiohttp @@ -16,13 +17,17 @@ from classes.bot import ModMail from classes.embed import ErrorEmbed -from classes.message import Message +from discord.message import Message from utils import tools from utils.config import Config + VERSION = "3.3.2" +logger = logging.getLogger() + + class Instance: def __init__(self, instance_id, loop, main): self.id = instance_id @@ -46,34 +51,41 @@ async def read_stream(self, stream): if line: line = line.decode("utf-8")[:-1] - print(f"[Cluster {self.id}] {line}") + logger.info(f"[Cluster Report {self.id}] {line}") else: break async def start(self): if self.is_active: - print(f"[Cluster {self.id}] Already active.") + logger.info(f"[Cluster {self.id}] Already active.") return - + + logger.info(f"{sys.executable} \"{Path.cwd() / 'worker.py'}\" {self.id} {config.BOT_CLUSTERS} " + f"{self.main.bot.id} {VERSION}") + self._process = await asyncio.create_subprocess_shell( f"{sys.executable} \"{Path.cwd() / 'worker.py'}\" {self.id} {config.BOT_CLUSTERS} " f"{self.main.bot.id} {VERSION}", stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - preexec_fn=os.setsid, - limit=1024 * 256, ) + + # self._process = await asyncio.create_subprocess_shell( + # f"/usr/bin/local/python3 /app/worker.py {self.id} {config.BOT_CLUSTERS} {self.main.bot.id} {VERSION}", + # stdin=asyncio.subprocess.DEVNULL, + # stdout=asyncio.subprocess.PIPE, + # stderr=asyncio.subprocess.PIPE, + # ) self.status = "running" - print(f"[Cluster {self.id}] The cluster is starting.") + logger.info(f"[Cluster {self.id}] The cluster is starting.") stdout = self.loop.create_task(self.read_stream(self._process.stdout)) stderr = self.loop.create_task(self.read_stream(self._process.stderr)) await asyncio.wait([stdout, stderr]) - return self def kill(self): @@ -99,7 +111,7 @@ async def premium_updater(self): async with self.bot.pool.acquire() as conn: premium = await conn.fetch( "SELECT identifier, guild FROM premium WHERE expiry IS NOT NULL AND expiry<$1", - int(datetime.utcnow().timestamp() * 1000), + int(datetime.datetime.now(datetime.UTC).timestamp() * 1000), ) for row in premium: @@ -228,15 +240,15 @@ def __init__(self, loop): def dead_process_handler(self, result): instance = result.result() - print( + logger.info( f"[Cluster {instance.id}] The cluster exited with code {instance._process.returncode}." ) if instance._process.returncode in [0, -15]: - print(f"[Cluster {instance.id}] The cluster stopped gracefully.") + logger.info(f"[Cluster {instance.id}] The cluster stopped gracefully.") return - print(f"[Cluster {instance.id}] The cluster is restarting.") + logger.info(f"[Cluster {instance.id}] The cluster is restarting.") instance.loop.create_task(instance.start()) async def user_select_handler(self, body): @@ -283,13 +295,13 @@ def write_targets(self): json.dump(data, file, indent=2) async def launch(self): - print(f"[Cluster Manager] Starting a total of {config.BOT_CLUSTERS} clusters.") + logger.info(f"[Cluster Manager] Starting a total of {config.BOT_CLUSTERS} clusters.") - self.bot = ModMail(cluster_id=0, cluster_count=int(config.BOT_CLUSTERS)) + self.bot = ModMail(intents=discord.Intents.all(), cluster_id=0, cluster_count=int(config.BOT_CLUSTERS)) await self.bot.start(worker=False) - self.bot.id = (await self.bot.real_user()).id self.bot.state.id = self.bot.id + for i in range(int(config.BOT_CLUSTERS)): self.instances.append(Instance(i + 1, loop=self.loop, main=self)) @@ -307,13 +319,14 @@ async def launch(self): config = Config().load() - +discord.utils.setup_logging(level=logging.INFO) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) main = Main(loop=loop) loop.create_task(main.launch()) try: + logger.info("Entering forever loop") loop.run_forever() except KeyboardInterrupt: @@ -335,3 +348,4 @@ def shutdown_handler(_loop, context): finally: loop.run_until_complete(loop.shutdown_asyncgens()) loop.close() + diff --git a/requirements.txt b/requirements.txt index c60f8118..d2ff8f7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,27 @@ -aio-pika==9.4.0 -aiohttp==3.7.4.post0 +aio-pika==9.3.0 aioprometheus==23.3.0 -aioredis==1.3.1 -asyncpg==0.29.0 -dateparser==1.2.0 -orjson==3.9.13 -psutil==5.9.8 -python-dotenv==1.0.1 -subprocess32==3.5.4 +aiormq==6.7.7 +aiosignal==1.3.1 +async-timeout==4.0.3 +asyncpg==0.28.0 +attrs==23.1.0 +charset-normalizer==3.3.1 +dateparser==1.1.8 +discord.py==2.3.2 +frozenlist==1.4.0 +idna==3.4 +multidict==6.0.4 +orjson==3.9.9 +pamqp==3.2.1 +psutil==5.9.6 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +pytz==2023.3.post1 +quantile-python==1.1 +redis==5.0.1 +regex==2023.10.3 +six==1.16.0 +typing_extensions==4.8.0 +tzlocal==5.2 +yarl==1.9.2 -git+https://github.com/chamburr/discord.py.git diff --git a/worker.py b/worker.py index 505ec6a4..dfd387c4 100644 --- a/worker.py +++ b/worker.py @@ -1,4 +1,5 @@ import asyncio +import discord import logging import sys @@ -14,13 +15,13 @@ version = sys.argv[4] logging.basicConfig(level=logging.INFO) -logger = logging.getLogger() +logger = logging.getLogger(f"cluster-{cluster_id}") logger.setLevel(logging.INFO) handler = logging.FileHandler(filename=f"logs/cluster-{cluster_id}.log", encoding="utf-8", mode="w") handler.setFormatter(logging.Formatter("%(asctime)s:%(levelname)s:%(name)s: %(message)s")) logger.addHandler(handler) +logger.info("Logging initialized") -log = logging.getLogger(__name__) async def command_prefix(bot2, message): @@ -38,6 +39,7 @@ async def command_prefix(bot2, message): cluster_id=cluster_id, cluster_count=cluster_count, version=version, + intents = discord.Intents.all() ) @@ -45,5 +47,5 @@ async def command_prefix(bot2, message): async def on_message(_): pass - +logger.info("Starting bot") loop.run_until_complete(bot.start())