Skip to content
Merged
5 changes: 2 additions & 3 deletions graphgen/bases/base_generator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from graphgen.bases.base_llm_client import BaseLLMClient


@dataclass
class BaseGenerator(ABC):
"""
Generate QAs based on given prompts.
"""

llm_client: BaseLLMClient
def __init__(self, llm_client: BaseLLMClient):
self.llm_client = llm_client

@staticmethod
@abstractmethod
Expand Down
12 changes: 4 additions & 8 deletions graphgen/bases/base_kg_builder.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

from graphgen.bases.base_llm_client import BaseLLMClient
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk


@dataclass
class BaseKGBuilder(ABC):
llm_client: BaseLLMClient

_nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
_edges: Dict[Tuple[str, str], List[dict]] = field(
default_factory=lambda: defaultdict(list)
)
def __init__(self, llm_client: BaseLLMClient):
self.llm_client = llm_client
self._nodes: Dict[str, List[dict]] = defaultdict(list)
self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)

@abstractmethod
async def extract(
Expand Down
2 changes: 0 additions & 2 deletions graphgen/bases/base_partitioner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List

from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Community


@dataclass
class BasePartitioner(ABC):
@abstractmethod
async def partition(
Expand Down
23 changes: 15 additions & 8 deletions graphgen/bases/base_splitter.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
import copy
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Literal, Optional, Union

from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger


@dataclass
class BaseSplitter(ABC):
"""
Abstract base class for splitting text into smaller chunks.
"""

chunk_size: int = 1024
chunk_overlap: int = 100
length_function: Callable[[str], int] = len
keep_separator: bool = False
add_start_index: bool = False
strip_whitespace: bool = True
def __init__(
self,
chunk_size: int = 1024,
chunk_overlap: int = 100,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
add_start_index: bool = False,
strip_whitespace: bool = True,
):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.length_function = length_function
self.keep_separator = keep_separator
self.add_start_index = add_start_index
self.strip_whitespace = strip_whitespace

@abstractmethod
def split_text(self, text: str) -> List[str]:
Expand Down
10 changes: 3 additions & 7 deletions graphgen/bases/base_storage.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from dataclasses import dataclass
from typing import Generic, TypeVar, Union

T = TypeVar("T")


@dataclass
class StorageNameSpace:
working_dir: str = None
namespace: str = None
def __init__(self, working_dir: str = None, namespace: str = None):
self.working_dir = working_dir
self.namespace = namespace

async def index_done_callback(self):
"""commit the storage operations after indexing"""
Expand All @@ -16,7 +15,6 @@ async def query_done_callback(self):
"""commit the storage operations after querying"""


@dataclass
class BaseListStorage(Generic[T], StorageNameSpace):
async def all_items(self) -> list[T]:
raise NotImplementedError
Expand All @@ -34,7 +32,6 @@ async def drop(self):
raise NotImplementedError


@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]:
raise NotImplementedError
Expand All @@ -58,7 +55,6 @@ async def drop(self):
raise NotImplementedError


@dataclass
class BaseGraphStorage(StorageNameSpace):
async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
Expand Down
5 changes: 2 additions & 3 deletions graphgen/bases/base_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List


@dataclass
class BaseTokenizer(ABC):
model_name: str = "cl100k_base"
def __init__(self, model_name: str = "cl100k_base"):
self.model_name = model_name

@abstractmethod
def encode(self, text: str) -> List[int]:
Expand Down
7 changes: 3 additions & 4 deletions graphgen/models/evaluator/base_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import asyncio
from dataclasses import dataclass

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.bases.datatypes import QAPair
from graphgen.utils import create_event_loop


@dataclass
class BaseEvaluator:
max_concurrent: int = 100
results: list[float] = None
def __init__(self, max_concurrent: int = 100):
self.max_concurrent = max_concurrent
self.results: list[float] = None

def evaluate(self, pairs: list[QAPair]) -> list[float]:
"""
Expand Down
36 changes: 24 additions & 12 deletions graphgen/models/llm/topk_token_model.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Optional

from graphgen.bases import Token


@dataclass
class TopkTokenModel:
do_sample: bool = False
temperature: float = 0
max_tokens: int = 4096
repetition_penalty: float = 1.05
num_beams: int = 1
topk: int = 50
topp: float = 0.95

topk_per_token: int = 5 # number of topk tokens to generate for each token
class TopkTokenModel(ABC):
def __init__(
self,
do_sample: bool = False,
temperature: float = 0,
max_tokens: int = 4096,
repetition_penalty: float = 1.05,
num_beams: int = 1,
topk: int = 50,
topp: float = 0.95,
topk_per_token: int = 5,
):
self.do_sample = do_sample
self.temperature = temperature
self.max_tokens = max_tokens
self.repetition_penalty = repetition_penalty
self.num_beams = num_beams
self.topk = topk
self.topp = topp
self.topk_per_token = topk_per_token

@abstractmethod
async def generate_topk_per_token(self, text: str) -> List[Token]:
"""
Generate prob, text and candidates for each token of the model's output.
This function is used to visualize the inference process.
"""
raise NotImplementedError

@abstractmethod
async def generate_inputs_prob(
self, text: str, history: Optional[List[str]] = None
) -> List[Token]:
Expand All @@ -32,6 +43,7 @@ async def generate_inputs_prob(
"""
raise NotImplementedError

@abstractmethod
async def generate_answer(
self, text: str, history: Optional[List[str]] = None
) -> str:
Expand Down
3 changes: 0 additions & 3 deletions graphgen/models/search/db/uniprot_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dataclasses import dataclass

import requests
from fastapi import HTTPException

Expand All @@ -8,7 +6,6 @@
UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"


@dataclass
class UniProtSearch:
"""
UniProt Search client to search with UniProt.
Expand Down
2 changes: 0 additions & 2 deletions graphgen/models/search/kg/wiki_search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import List, Union

import wikipedia
Expand All @@ -7,7 +6,6 @@
from graphgen.utils import detect_main_language, logger


@dataclass
class WikiSearch:
@staticmethod
def set_language(language: str):
Expand Down
6 changes: 2 additions & 4 deletions graphgen/models/search/web/bing_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dataclasses import dataclass

import requests
from fastapi import HTTPException

Expand All @@ -9,13 +7,13 @@
BING_MKT = "en-US"


@dataclass
class BingSearch:
"""
Bing Search client to search with Bing.
"""

subscription_key: str
def __init__(self, subscription_key: str):
self.subscription_key = subscription_key

def search(self, query: str, num_results: int = 1):
"""
Expand Down
3 changes: 0 additions & 3 deletions graphgen/models/search/web/google_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dataclasses import dataclass

import requests
from fastapi import HTTPException

Expand All @@ -8,7 +6,6 @@
GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"


@dataclass
class GoogleSearch:
def __init__(self, subscription_key: str, cx: str):
"""
Expand Down
4 changes: 4 additions & 0 deletions graphgen/models/storage/json_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

@dataclass
class JsonKVStorage(BaseKVStorage):
working_dir: str = None
namespace: str = None
_data: dict[str, str] = None

def __post_init__(self):
Expand Down Expand Up @@ -53,6 +55,8 @@ async def drop(self):

@dataclass
class JsonListStorage(BaseListStorage):
working_dir: str = None
namespace: str = None
_data: list = None

def __post_init__(self):
Expand Down
3 changes: 3 additions & 0 deletions graphgen/models/storage/networkx_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

@dataclass
class NetworkXStorage(BaseGraphStorage):
working_dir: str = None
namespace: str = None

@staticmethod
def load_nx_graph(file_name) -> Optional[nx.Graph]:
if os.path.exists(file_name):
Expand Down
Loading