143 | 143 | MaxMemory = [0]
144 | 144 | MaxFreeMemory = [0]
145 | 145 |
| 146 | +server_process: subprocess.Popen | None = None  # handle to the spawned worker process
| 147 | +is_proxy = "KOBOLDCPP_SERVER" not in os.environ  # True unless we are the spawned worker (which has KOBOLDCPP_SERVER set)
| 148 | +current_model = None  # model file the worker is currently serving
| 149 | +
146 | 150 | class logit_bias(ctypes.Structure):
147 | 151 |     _fields_ = [("token_id", ctypes.c_int32),
148 | 152 |                 ("bias", ctypes.c_float)]
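
The three globals above are the heart of the scheme: a process that does not see `KOBOLDCPP_SERVER` in its environment becomes a proxy, and the worker it spawns (with the variable set) is the one that actually loads the model. A minimal standalone sketch of the same pattern, using an illustrative `WORKER` marker rather than the patch's actual variable:

    import os
    import subprocess
    import sys

    if "WORKER" in os.environ:
        # Child process: this is where the real server would load the model.
        print("worker: loading model and serving requests")
    else:
        # Parent process: re-exec this same script as the worker, with the marker set.
        # Merging os.environ keeps PATH etc. intact for the child.
        subprocess.run([sys.executable] + sys.argv, env={**os.environ, "WORKER": "1"})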
@@ -1323,6 +1327,12 @@ def auto_set_backend_cli():
1323 | 1327 |     print(f"Auto Selected Default Backend (flag={cpusupport})\n")
1324 | 1328 |
1325 | 1329 | def load_model(model_filename):
| 1330 | +    global current_model
| 1331 | +    if is_proxy:
| 1332 | +        current_model = model_filename  # record only; the spawned worker does the real load
| 1333 | +        print("Deferred model loading.", current_model)
| 1334 | +        return True
| 1335 | +
1326 | 1336 |     global args
1327 | 1337 |     inputs = load_model_inputs()
1328 | 1338 |     inputs.model_filename = model_filename.encode("UTF-8")
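
One fix worth calling out in the hunk above: the added `global current_model` declaration. Without it, Python treats the assignment inside `load_model` as creating a new function-local name, and the module-level `current_model` silently stays `None`. A generic illustration of that scoping rule (not code from this patch):

    counter = None

    def set_wrong():
        counter = 1           # binds a new local name; the module-level counter is untouched

    def set_right():
        global counter        # rebind the module-level name instead
        counter = 1

    set_wrong(); print(counter)   # None
    set_right(); print(counter)   # 1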
@@ -3851,6 +3861,58 @@ def do_POST(self):
3851 | 3861 |         if args.foreground:
3852 | 3862 |             bring_terminal_to_foreground()
3853 | 3863 |
| 3864 | +        # Proxy mode: forward the request to the worker process, respawning it when the model changes
| 3865 | +        if is_proxy:
| 3866 | +            global server_process
| 3867 | +            global current_model
| 3868 | +
| 3869 | +            model = genparams["model"]
| 3870 | +            if server_process is not None and current_model != model:
| 3871 | +                # A different model was requested: stop the worker and its whole process tree first
| 3872 | +                import psutil
| 3873 | +                parent = psutil.Process(server_process.pid)
| 3874 | +                processes = parent.children(recursive=True) + [parent]
| 3875 | +                for process in processes:
| 3876 | +                    process.terminate()
| 3877 | +                for process in processes:
| 3878 | +                    process.wait()
| 3879 | +
| 3880 | +                server_process = None
| 3881 | +
| 3882 | +            if server_process is None:
| 3883 | +                current_model = model
| 3884 | +                # Re-launch this script as the worker; the trailing --port/--model override any earlier flags
| 3885 | +                server_process = subprocess.Popen([sys.executable] + sys.argv + ["--port", str(args.port + 1), "--model", model],
| 3886 | +                                                  env={**os.environ, "KOBOLDCPP_SERVER": "True"})
| 3887 | +
| 3888 | +                # Poke the worker until it's alive; each attempt times out after one second
| 3889 | +                while True:
| 3890 | +                    try:
| 3891 | +                        with urllib.request.urlopen(urllib.request.Request(f"http://localhost:{args.port + 1}", method="HEAD"), timeout=1) as response:
| 3892 | +                            if response.status == 200:
| 3893 | +                                break
| 3894 | +                            time.sleep(1)
| 3895 | +                    except Exception:
| 3896 | +                        time.sleep(1)
| 3897 | +
| 3898 | +            # Relay the request and stream the worker's response back to the client
| 3899 | +            request = urllib.request.Request(f"http://localhost:{args.port + 1}" + self.path, data=body, headers=dict(self.headers), method="POST")
| 3900 | +            with urllib.request.urlopen(request) as response:
| 3901 | +                self.send_response_only(response.status)
| 3902 | +                for keyword, value in response.headers.items():
| 3903 | +                    self.send_header(keyword, value)
| 3904 | +                super(KcppServerRequestHandler, self).end_headers()
| 3905 | +
| 3906 | +                while True:
| 3907 | +                    chunk = response.read(4096)  # bounded reads so streamed (SSE) responses are relayed as they arrive
| 3908 | +                    if not chunk:
| 3909 | +                        break
| 3910 | +                    self.wfile.write(chunk)
| 3911 | +
| 3912 | +            self.wfile.flush()
| 3913 | +            self.close_connection = True
| 3914 | +            return
| 3915 | +
3854 | 3916 |         if api_format > 0: #text gen
3855 | 3917 |             # Check if streaming chat completions, if so, set stream mode to true
3856 | 3918 |             if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
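
The poll-and-forward sequence above is easy to get subtly wrong, which is why the rewrite uses a one-second per-attempt timeout and bounded reads. For reference, a self-contained version of the readiness check; `wait_until_ready` is a hypothetical helper name, not something defined in the patch:

    import time
    import urllib.request

    def wait_until_ready(url: str, poll_interval: float = 1.0) -> None:
        # Block until `url` answers a HEAD request with HTTP 200.
        while True:
            try:
                req = urllib.request.Request(url, method="HEAD")
                with urllib.request.urlopen(req, timeout=1) as resp:
                    if resp.status == 200:
                        return
            except Exception:
                pass  # typically "connection refused" while the worker is still starting
            time.sleep(poll_interval)

    # wait_until_ready(f"http://localhost:{args.port + 1}")  -- same URL the proxy polls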