diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 3818741b..1bd88430 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -36,6 +36,9 @@ public class LLM : MonoBehaviour /// log the output of the LLM in the Unity Editor. [Tooltip("log the output of the LLM in the Unity Editor.")] [LLM] public bool debug = false; + /// Wait for native debugger to connect to the backend + [Tooltip("Wait for native debugger to connect to the backend")] + [LLMAdvanced] public bool UseNativeDebugger = false; /// number of prompts that can happen in parallel (-1 = number of LLMCaller objects) [Tooltip("number of prompts that can happen in parallel (-1 = number of LLMCaller objects)")] [LLMAdvanced] public int parallelPrompts = -1; @@ -53,6 +56,8 @@ public class LLM : MonoBehaviour public bool started { get; protected set; } = false; /// Boolean set to true if the server has failed to start. public bool failed { get; protected set; } = false; + /// Boolean set to true if the server has been destroyed. + public bool destroyed { get; protected set; } = false; /// Boolean set to true if the models were not downloaded successfully. public static bool modelSetupFailed { get; protected set; } = false; /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. @@ -127,6 +132,13 @@ void OnValidate() public async void Awake() { if (!enabled) return; + Load(); + } + + public async Awaitable Load() + { + await Awaitable.BackgroundThreadAsync(); + #if !UNITY_EDITOR modelSetupFailed = !await LLMManager.Setup(); #endif @@ -142,9 +154,13 @@ public async void Awake() failed = true; return; } - await Task.Run(() => StartLLMServer(arguments)); + await StartLLMServerAsync(arguments); if (!started) return; - if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject); + if (dontDestroyOnLoad) + { + await Awaitable.MainThreadAsync(); + DontDestroyOnLoad(transform.root.gameObject); + } } /// @@ -476,10 +492,11 @@ private void StopLogging() DestroyStreamWrapper(logStreamWrapper); } - private void StartLLMServer(string arguments) + private async Task StartLLMServerAsync(string arguments) { started = false; failed = false; + destroyed = false; bool useGPU = numGPULayers > 0; foreach (string arch in LLMLib.PossibleArchitectures(useGPU)) @@ -488,6 +505,19 @@ private void StartLLMServer(string arguments) try { InitLib(arch); +#if UNITY_EDITOR + if (UseNativeDebugger) + { + if (llmlib?.LLM_IsDebuggerAttached == null) + { + LLMUnitySetup.Log($"Tried architecture: {arch} is not debug library"); + Destroy(); + continue; + } + + await WaitNativeDebug(); + } +#endif InitService(arguments); LLMUnitySetup.Log($"Using architecture: {arch}"); break; @@ -504,6 +534,7 @@ private void StartLLMServer(string arguments) catch (Exception e) { error = $"{e.GetType()}: {e.Message}"; + Destroy(); } LLMUnitySetup.Log($"Tried architecture: {arch}, error: " + error); } @@ -514,20 +545,39 @@ private void StartLLMServer(string arguments) return; } CallWithLock(StartService); - LLMUnitySetup.Log("LLM service created"); + if (started) + LLMUnitySetup.Log("LLM service created"); } private void InitLib(string arch) { llmlib = new LLMLib(arch); - CheckLLMStatus(false); } +#if UNITY_EDITOR + private async Task WaitNativeDebug() + { + if (llmlib?.LLM_IsDebuggerAttached != null) + { + LLMUnitySetup.Log("waiting debugger"); + while (!destroyed) + { + if (llmlib.LLM_IsDebuggerAttached()) + { + LLMUnitySetup.Log("remote debugger attached"); + break; + } + await Task.Delay(100); + } + } + } +#endif + void CallWithLock(EmptyCallback fn) { lock (startLock) { - if (llmlib == null) throw new DestroyException(); + if (llmlib == null || destroyed) throw new DestroyException(); fn(); } } @@ -556,9 +606,12 @@ private void StartService() { llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); - while (!llmlib.LLM_Started(LLMObject)) {} - ApplyLoras(); - started = true; + while (!llmlib.LLM_Started(LLMObject) && !destroyed) { } + if (!destroyed) + { + ApplyLoras(); + started = true; + } } /// @@ -611,6 +664,7 @@ void AssertStarted() string error = null; if (failed) error = "LLM service couldn't be created"; else if (!started) error = "LLM service not started"; + else if (destroyed) error = "LLM service is being destroyed"; if (error != null) { LLMUnitySetup.LogError(error); @@ -807,6 +861,7 @@ public void CancelRequest(int id_slot) /// public void Destroy() { + destroyed = true; lock (staticLock) lock (startLock) { diff --git a/Runtime/LLMCaller.cs b/Runtime/LLMCaller.cs index 4b69137d..400d555f 100644 --- a/Runtime/LLMCaller.cs +++ b/Runtime/LLMCaller.cs @@ -238,6 +238,8 @@ protected virtual async Task PostRequestLocal(string json, string // send a post request to the server and call the relevant callbacks to convert the received content and handle it // this function has streaming functionality i.e. handles the answer while it is being received while (!llm.failed && !llm.started) await Task.Yield(); + if (llm.destroyed) + return default; string callResult = null; switch (endpoint) { diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 0bf91a15..1f672af6 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -685,6 +685,8 @@ protected override async Task PostRequestLocal(string json, strin if (endpoint != "completion") return await base.PostRequestLocal(json, endpoint, getContent, callback); while (!llm.failed && !llm.started) await Task.Yield(); + if (llm.destroyed) + return default; string callResult = null; bool callbackCalled = false; diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index f7417453..e42c0ee1 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -497,6 +497,12 @@ public LLMLib(string arch) StringWrapper_GetString = LibraryLoader.GetSymbolDelegate(libraryHandle, "StringWrapper_GetString"); Logging = LibraryLoader.GetSymbolDelegate(libraryHandle, "Logging"); StopLogging = LibraryLoader.GetSymbolDelegate(libraryHandle, "StopLogging"); + + // editor only +#if UNITY_EDITOR + var symbol = LibraryLoader.GetSymbol(libraryHandle, "LLM_IsDebuggerAttached"); + LLM_IsDebuggerAttached = (symbol != IntPtr.Zero) ? Marshal.GetDelegateForFunctionPointer(symbol) : null; +#endif } /// @@ -606,6 +612,9 @@ public static string GetArchitecturePath(string arch) public delegate void StringWrapper_DeleteDelegate(IntPtr instance); public delegate int StringWrapper_GetStringSizeDelegate(IntPtr instance); public delegate void StringWrapper_GetStringDelegate(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false); +#if UNITY_EDITOR + public delegate bool LLM_IsDebuggerAttachedDelegate(); +#endif public LoggingDelegate Logging; public StopLoggingDelegate StopLogging; @@ -631,6 +640,9 @@ public static string GetArchitecturePath(string arch) public StringWrapper_DeleteDelegate StringWrapper_Delete; public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize; public StringWrapper_GetStringDelegate StringWrapper_GetString; +#if UNITY_EDITOR + public LLM_IsDebuggerAttachedDelegate LLM_IsDebuggerAttached; +#endif #endif