Skip to content

Commit 90e78bb

Browse files
authored
Implement better extension/library watchdog (#58)
1 parent b9ae30b commit 90e78bb

File tree

6 files changed

+109
-54
lines changed

6 files changed

+109
-54
lines changed

bot/libs/utils/logger.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __enter__(self) -> None:
1717
max_bytes = 32 * 1024 * 1024 # 32 MiB
1818
self.log.setLevel(logging.INFO)
1919
logging.getLogger("discord").setLevel(logging.INFO)
20+
logging.getLogger("watchfiles").setLevel(logging.WARNING)
2021
handler = RotatingFileHandler(
2122
filename="rodhaj.log",
2223
encoding="utf-8",

bot/libs/utils/reloader.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import importlib
5+
import os
6+
import sys
7+
from pathlib import Path
8+
from typing import TYPE_CHECKING, Optional
9+
10+
from discord.ext import commands
11+
from watchfiles import Change, awatch
12+
13+
if TYPE_CHECKING:
14+
from rodhaj import Rodhaj
15+
16+
17+
class Reloader:
18+
"""An watchdog for reloading extensions and library files
19+
20+
This reloads/unloads extensions, and also reloads library modules.
21+
This does not implement a deep reload, as there is no way to do so
22+
that way.
23+
"""
24+
25+
def __init__(self, bot: Rodhaj, path: Path):
26+
self.bot = bot
27+
self.path = path
28+
29+
self.loop = asyncio.get_running_loop()
30+
self.logger = bot.logger
31+
self._cogs_path = self.path / "cogs"
32+
self._libs_path = self.path / "libs"
33+
34+
### Finding modules from the path directly
35+
36+
def find_modules_from_path(self, path: str) -> Optional[str]:
37+
root, ext = os.path.splitext(path)
38+
sys_path_index = len(sys.path[0].split("/"))
39+
if ext != ".py":
40+
return
41+
42+
local_path = root.split("/")[sys_path_index:]
43+
return ".".join(item for item in local_path)
44+
45+
### Loading/reloading extensions and library modules
46+
47+
async def reload_or_load_extension(self, module: str) -> None:
48+
try:
49+
await self.bot.reload_extension(module)
50+
self.logger.info("Reloaded extension: %s", module)
51+
except commands.ExtensionNotLoaded:
52+
await self.bot.load_extension(module)
53+
self.logger.info("Loaded extension: %s", module)
54+
55+
async def reload_library(self, module: str) -> None:
56+
try:
57+
actual_module = sys.modules[module]
58+
importlib.reload(actual_module)
59+
self.logger.info("Reloaded lib module: %s", module)
60+
except KeyError:
61+
self.logger.warning("Failed to reload module %s. Does it exist?", module)
62+
63+
async def reload_extension_or_library(self, module: str) -> None:
64+
if module.startswith("libs"):
65+
await self.reload_library(module)
66+
elif module.startswith("cogs"):
67+
await self.reload_or_load_extension(module)
68+
69+
### Internal coroutine to start the watch
70+
71+
async def _start(self) -> None:
72+
async for changes in awatch(self._cogs_path, self._libs_path):
73+
for ctype, cpath in changes:
74+
module = self.find_modules_from_path(cpath)
75+
if module is None:
76+
continue
77+
78+
if ctype == Change.modified or ctype == Change.added:
79+
await self.reload_extension_or_library(module)
80+
elif ctype == Change.deleted:
81+
await self.bot.unload_extension(module)
82+
83+
### Public method to start the reloader
84+
85+
def start(self) -> None:
86+
"""Starts the deep reloader"""
87+
self.loop.create_task(self._start())
88+
self.bot.dispatch("deepreloader_ready")

bot/rodhaj.py

+5-20
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,11 @@
1919
RodhajHelp,
2020
send_error_embed,
2121
)
22+
from libs.utils.reloader import Reloader
2223

2324
if TYPE_CHECKING:
2425
from cogs.tickets import Tickets
2526

26-
_fsw = True
27-
try:
28-
from watchfiles import awatch
29-
except ImportError:
30-
_fsw = False
31-
3227
TRANSPROGRAMMER_GUILD_ID = 1183302385020436480
3328

3429

@@ -65,6 +60,7 @@ def __init__(
6560
self.transprogrammer_guild_id = TRANSPROGRAMMER_GUILD_ID
6661
self.version = str(VERSION)
6762
self._dev_mode = dev_mode
63+
self._reloader = Reloader(self, Path(__file__).parent)
6864

6965
### Ticket related utils
7066
async def fetch_partial_config(self) -> Optional[PartialConfig]:
@@ -196,17 +192,6 @@ async def on_message(self, message: discord.Message) -> None:
196192
return
197193
await self.process_commands(message, ctx)
198194

199-
### Dev related utils
200-
201-
async def fs_watcher(self) -> None:
202-
cogs_path = Path(__file__).parent.joinpath("cogs")
203-
async for changes in awatch(cogs_path):
204-
changes_list = list(changes)[0]
205-
if changes_list[0].modified == 2:
206-
reload_file = Path(changes_list[1])
207-
self.logger.info(f"Reloading extension: {reload_file.name[:-3]}")
208-
await self.reload_extension(f"cogs.{reload_file.name[:-3]}")
209-
210195
### Internal core overrides
211196

212197
async def setup_hook(self) -> None:
@@ -219,9 +204,9 @@ async def setup_hook(self) -> None:
219204

220205
self.partial_config = await self.fetch_partial_config()
221206

222-
if self._dev_mode is True and _fsw is True:
223-
self.logger.info("Dev mode is enabled. Loading FSWatcher")
224-
self.loop.create_task(self.fs_watcher())
207+
if self._dev_mode:
208+
self.logger.info("Dev mode is enabled. Loading Reloader")
209+
self._reloader.start()
225210

226211
async def on_ready(self):
227212
if not hasattr(self, "uptime"):

poetry.lock

+12-32
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ environs = "^10.3.0"
2222
async-lru = "^2.0.4"
2323
msgspec = "^0.18.6"
2424
jishaku = "^2.5.2"
25+
watchfiles = "^0.21.0"
2526

2627
[tool.poetry.group.dev.dependencies]
2728
# These are pinned by major version
@@ -30,7 +31,6 @@ jishaku = "^2.5.2"
3031
pre-commit = "^3"
3132
pyright = "^1.1"
3233
ruff = "^0.1"
33-
watchfiles = "^0"
3434

3535
[tool.poetry.group.docs.dependencies]
3636
sphinx = "^7.2.6"

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ typing-extensions==4.9.0
1111
environs==10.3.0
1212
async-lru==2.0.4
1313
msgspec==0.18.6
14-
jishaku==2.5.2
14+
jishaku==2.5.2
15+
watchfiles>=0.21.0,<1

0 commit comments

Comments
 (0)