diff --git a/labml_db/index.py b/labml_db/index.py index c5f49a9..068cbf1 100644 --- a/labml_db/index.py +++ b/labml_db/index.py @@ -7,11 +7,12 @@ class Index(Generic[_KT]): - __db_drivers: Dict[str, IndexDbDriver] + __db_drivers: Dict[str, IndexDbDriver] = {} @staticmethod def set_db_drivers(db_drivers: List[IndexDbDriver]): - Index.__db_drivers = {d.index_name: d for d in db_drivers} + for d in db_drivers: + Index.__db_drivers[d.index_name] = d @classmethod def delete(cls, index_key: str): @@ -41,3 +42,9 @@ def mget(cls, index_key: List[str]) -> List[Optional[Key[_KT]]]: def set(cls, index_key: str, model_key: Key[_KT]): db_driver = Index.__db_drivers[cls.__name__] db_driver.set(index_key, str(model_key)) + + @classmethod + def get_all(cls): + db_driver = Index.__db_drivers[cls.__name__] + keys = db_driver.get_all() + return keys diff --git a/labml_db/index_driver/__init__.py b/labml_db/index_driver/__init__.py index 78b31eb..d9f1585 100644 --- a/labml_db/index_driver/__init__.py +++ b/labml_db/index_driver/__init__.py @@ -19,3 +19,6 @@ def mget(self, index_key: List[str]) -> List[str]: def set(self, index_key: str, model_key: str): raise NotImplementedError + + def get_all(self) -> List[str]: + raise NotImplementedError diff --git a/labml_db/index_driver/redis.py b/labml_db/index_driver/redis.py index e16a8bc..5213565 100644 --- a/labml_db/index_driver/redis.py +++ b/labml_db/index_driver/redis.py @@ -27,6 +27,8 @@ def get(self, index_key: str) -> str: def mget(self, index_key: List[str]) -> List[str]: return self._db.hmget(self._index_key, index_key) - def set(self, index_key: str, model_key: str): self._db.hset(self._index_key, index_key, model_key) + + def get_all(self): + return self._db.hgetall(self._index_key) diff --git a/labml_db/model.py b/labml_db/model.py index f67ddc9..3e0364e 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -168,7 +168,7 @@ def defaults(self) -> Dict[str, Primitive]: class Model(Generic[_KT]): __models: Dict[str, ModelSpec] = {} - __db_drivers: Dict[str, 'DbDriver'] + __db_drivers: Dict[str, 'DbDriver'] = {} _defaults = Dict[str, Primitive] _values = Dict[str, Primitive] @@ -239,7 +239,8 @@ def key(self) -> 'Key[_KT]': @staticmethod def set_db_drivers(db_drivers: List['DbDriver']): - Model.__db_drivers = {d.model_name: d for d in db_drivers} + for d in db_drivers: + Model.__db_drivers[d.model_name] = d @classmethod def mread_dict(cls, key: List[str], db_driver: Optional['DbDriver'] = None) -> List[ModelDict]: