koboldcpp.py: 92 changes (87 additions, 5 deletions)
@@ -146,6 +146,10 @@
MaxMemory = [0]
MaxFreeMemory = [0]

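# Proxy-mode globals: a process launched without KOBOLDCPP_SERVER in its environment
# acts as a proxy, spawning and forwarding to a child server process that sets
# KOBOLDCPP_SERVER and performs the actual model loading.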
server_process: subprocess.Popen | None = None
is_proxy = "KOBOLDCPP_SERVER" not in os.environ
current_model = None

class logit_bias(ctypes.Structure):
_fields_ = [("token_id", ctypes.c_int32),
("bias", ctypes.c_float)]
@@ -1343,7 +1347,19 @@ def auto_set_backend_cli():
if not found_new_backend:
print(f"Auto Selected Default Backend (flag={cpusupport})\n")

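# Lists the .kcpps, .kcppt and .gguf files available in --admindir (empty if admin mode is off).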
def get_models():
if args.admin and args.admindir and os.path.exists(args.admindir):
from pathlib import Path
        return sorted(path for path in Path(args.admindir).iterdir() if (path.suffix in [".kcpps", ".kcppt", ".gguf"] and path.is_file()))
else:
return []

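# In proxy mode the model is not loaded here; the spawned child server loads it instead.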
def load_model(model_filename):
    global current_model
    if is_proxy:
current_model = model_filename
print("Deferred model loading.", current_model)
return True

global args
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
@@ -3260,9 +3276,8 @@ def do_GET(self):

elif self.path.endswith(('/api/admin/list_options')): #used by admin to get info about a kcpp instance
opts = []
-            if args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword):
-                dirpath = os.path.abspath(args.admindir)
-                opts = [f for f in sorted(os.listdir(dirpath)) if (f.endswith(".kcpps") or f.endswith(".kcppt") or f.endswith(".gguf")) and os.path.isfile(os.path.join(dirpath, f))]
+            if self.check_header_password(args.adminpassword):
+                opts = [path.name for path in get_models()]
opts.append("unload_model")
response_body = (json.dumps(opts).encode())

@@ -3332,7 +3347,7 @@ def do_GET(self):
response_body = (json.dumps({"logprobs":logprobsdict}).encode())

elif self.path.endswith('/v1/models'):
-                response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":int(time.time()),"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
+                response_body = (json.dumps({"object":"list","data":[{"id":path.stem,"object":"model","created":int(path.stat().st_mtime),"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"} for path in get_models()]}).encode())

elif self.path.endswith('/sdapi/v1/sd-models'):
if friendlysdmodelname=="inactive" or fullsdmodelpath=="":
@@ -3799,7 +3814,7 @@ def do_POST(self):
else:
dirpath = os.path.abspath(args.admindir)
targetfilepath = os.path.join(dirpath, targetfile)
-                opts = [f for f in os.listdir(dirpath) if (f.lower().endswith(".kcpps") or f.lower().endswith(".kcppt") or f.lower().endswith(".gguf")) and os.path.isfile(os.path.join(dirpath, f))]
+                opts = [path.name for path in get_models()]
if targetfile in opts and os.path.exists(targetfilepath):
global_memory["restart_override_config_target"] = ""
if targetfile.lower().endswith(".gguf") and overrideconfig:
@@ -4020,6 +4035,73 @@ def do_POST(self):
if args.foreground:
bring_terminal_to_foreground()

# Proxy
if is_proxy:
global server_process
global current_model

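            # Map the requested model id (the "model" field of the request body) to a file in --admindir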
model = genparams["model"]
model_path = next((str(path) for path in get_models() if path.stem == model), None)
if model_path is None:
self.send_response(404)
self.end_headers(content_type='application/json')
self.wfile.write(json.dumps({"detail": {
"error": "Model Not Found",
"msg": f"Model file {model} not found.",
"type": "bad_input",
}}).encode())
return

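            # First request: spawn the real server as a child process on port+1; afterwards, switch models via its admin API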
if server_process is None:
                server_process = subprocess.Popen([sys.executable] + sys.argv + ["--port", str(args.port + 1), "--model", model_path], env={
                    **os.environ,  # keep the parent environment (PATH, CUDA variables, etc.) intact for the child
                    "KOBOLDCPP_SERVER": "True"
                })
            elif current_model != model:
                try:
                    # reload_config expects a filename inside --admindir, as returned by /api/admin/list_options
                    reload_request = urllib.request.Request(f"http://localhost:{args.port + 1}/api/admin/reload_config", method="POST", data=json.dumps({"filename": os.path.basename(model_path)}).encode(), headers={
                        "Authorization": f"Bearer {args.adminpassword}"
                    })
                    with urllib.request.urlopen(reload_request):
                        pass  # urlopen raises on non-2xx statuses, so reaching here means the reload was accepted
                except Exception:
                    self.send_response(500)
                    self.end_headers(content_type='application/json')
                    self.wfile.write(json.dumps({"detail": {
                        "error": "Failed to switch model",
                        "msg": f"Failed to switch model to {model}.",
                        "type": "server_error",
                    }}).encode())
                    return

current_model = model

# Poke the server until it has the new model
while True:
try:
                    with urllib.request.urlopen(urllib.request.Request(f"http://localhost:{args.port + 1}/api/v1/model", method="GET"), timeout=5) as response:
data = json.loads(response.read().decode())
if response.status == 200 and data.get("result") == f"koboldcpp/{model}":
break

time.sleep(1)
except Exception:
time.sleep(1)

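            # Relay the original request to the child server and stream its response back to the client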
request = urllib.request.Request(f"http://localhost:{args.port + 1}" + self.path, data=body, headers=dict(self.headers), method="POST")
with urllib.request.urlopen(request) as response:
self.send_response_only(response.status)
for keyword, value in response.headers.items():
self.send_header(keyword, value)
super(KcppServerRequestHandler, self).end_headers()

while True:
                    chunk = response.read(4096)  # bounded reads so streamed (SSE) responses are relayed incrementally
if not chunk:
break
self.wfile.write(chunk)

self.wfile.flush()
self.close_connection = True
return

if api_format > 0: #text gen
# Check if streaming chat completions, if so, set stream mode to true
if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]: