Commit 47a5291

added a proxy
1 parent cdda9d1

1 file changed: +61 -0 lines changed


koboldcpp.py

Lines changed: 61 additions & 0 deletions
@@ -143,6 +143,10 @@
 MaxMemory = [0]
 MaxFreeMemory = [0]

+server_process: subprocess.Popen | None = None
+is_proxy = "KOBOLDCPP_SERVER" not in os.environ
+current_model = None
+
 class logit_bias(ctypes.Structure):
     _fields_ = [("token_id", ctypes.c_int32),
                 ("bias", ctypes.c_float)]
@@ -1323,6 +1327,11 @@ def auto_set_backend_cli():
     print(f"Auto Selected Default Backend (flag={cpusupport})\n")

 def load_model(model_filename):
+    if is_proxy:
+        current_model = model_filename
+        print("Deferred model loading.", current_model)
+        return True
+
     global args
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
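
In proxy mode the heavy model load is skipped entirely; the proxy only records which file was requested and lets the spawned worker do the real load. One Python scoping detail is worth spelling out: because the function never declares global current_model, the assignment above binds a new function-local name, so the module-level current_model from the first hunk stays None until the request handler below assigns it under its own global declaration. A tiny standalone illustration of that behaviour (filename is hypothetical):

# Standalone illustration of the scoping behaviour described above.
current_model = None

def load_model(model_filename):
    current_model = model_filename   # binds a function-local name only
    return True

load_model("example.gguf")
print(current_model)                 # prints: None (module-level value untouched)
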
@@ -3851,6 +3860,58 @@ def do_POST(self):
         if args.foreground:
             bring_terminal_to_foreground()

+        # Proxy
+        if is_proxy:
+            global server_process
+            global current_model
+
+            model = genparams["model"]
+            if server_process is not None and current_model != model:
+                import psutil
+                parent = psutil.Process(server_process.pid)
+                processes = parent.children(recursive=True) + [parent]
+                for process in processes:
+                    process.terminate()
+                for process in processes:
+                    process.wait()
+
+                server_process = None
+
+            if server_process is None:
+                current_model = model
+                server_process = subprocess.Popen([sys.executable] + sys.argv + ["--port", str(args.port + 1), "--model", model], env={
+                    "KOBOLDCPP_SERVER": "True"
+                })
+
+                # Poke the server until it's alive
+                while True:
+                    try:
+                        with urllib.request.urlopen(urllib.request.Request(f"http://localhost:{args.port + 1}", method="HEAD"), timeout=1000) as response:
+                            if response.status == 200:
+                                break
+
+                        time.sleep(1)
+                    except Exception:
+                        time.sleep(1)
+
+
+            request = urllib.request.Request(f"http://localhost:{args.port + 1}" + self.path, data=body, headers=dict(self.headers), method="POST")
+            with urllib.request.urlopen(request) as response:
+                self.send_response_only(response.status)
+                for keyword, value in response.headers.items():
+                    self.send_header(keyword, value)
+                super(KcppServerRequestHandler, self).end_headers()
+
+                while True:
+                    chunk = response.read()
+                    if not chunk:
+                        break
+                    self.wfile.write(chunk)
+
+                self.wfile.flush()
+                self.close_connection = True
+            return
+
         if api_format > 0: #text gen
             # Check if streaming chat completions, if so, set stream mode to true
             if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
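
Taken together, the do_POST branch implements the whole proxy cycle: if a request names a different model than the running worker, the worker and all of its child processes are terminated via psutil; a fresh worker is then launched on args.port + 1 with KOBOLDCPP_SERVER set, polled with HEAD requests until it answers 200, and the incoming POST is finally replayed against it with the response streamed back to the client. A condensed, self-contained sketch of those steps as free functions (the helper names are hypothetical and not part of koboldcpp.py; the psutil and urllib calls mirror the ones in the commit):

# Condensed sketch of the proxy flow above; helper names are hypothetical and the
# logic is illustrative, not a drop-in replacement for the do_POST branch.
import subprocess
import sys
import time
import urllib.request

import psutil  # third-party dependency the commit relies on for process-tree teardown


def stop_worker(server_process):
    # Terminate the worker and every child it spawned, then wait for all of them.
    parent = psutil.Process(server_process.pid)
    processes = parent.children(recursive=True) + [parent]
    for process in processes:
        process.terminate()
    for process in processes:
        process.wait()


def spawn_worker(model, port):
    # Relaunch the current script as a worker on `port`, flagged via KOBOLDCPP_SERVER.
    return subprocess.Popen(
        [sys.executable] + sys.argv + ["--port", str(port), "--model", model],
        env={"KOBOLDCPP_SERVER": "True"},
    )


def wait_until_alive(port):
    # Poll the worker with HEAD requests until it answers 200.
    while True:
        try:
            req = urllib.request.Request(f"http://localhost:{port}", method="HEAD")
            with urllib.request.urlopen(req, timeout=1000) as response:
                if response.status == 200:
                    return
            time.sleep(1)
        except Exception:
            time.sleep(1)


def forward_post(port, path, body, headers):
    # Replay the client's POST against the worker; the commit streams the body back
    # to the client chunk by chunk, while this sketch simply returns it whole.
    req = urllib.request.Request(
        f"http://localhost:{port}" + path, data=body, headers=headers, method="POST"
    )
    with urllib.request.urlopen(req) as response:
        return response.status, dict(response.headers), response.read()

One design note: passing env={"KOBOLDCPP_SERVER": "True"} to subprocess.Popen replaces the child's environment entirely, so the worker starts with only that variable set; a variant that needed the parent's PATH or GPU-related variables would merge os.environ into the dict instead.
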
