diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 3818741b..1bd88430 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -36,6 +36,9 @@ public class LLM : MonoBehaviour
/// log the output of the LLM in the Unity Editor.
[Tooltip("log the output of the LLM in the Unity Editor.")]
[LLM] public bool debug = false;
+ /// Wait for native debugger to connect to the backend
+ [Tooltip("Wait for native debugger to connect to the backend")]
+ [LLMAdvanced] public bool UseNativeDebugger = false;
/// number of prompts that can happen in parallel (-1 = number of LLMCaller objects)
[Tooltip("number of prompts that can happen in parallel (-1 = number of LLMCaller objects)")]
[LLMAdvanced] public int parallelPrompts = -1;
@@ -53,6 +56,8 @@ public class LLM : MonoBehaviour
public bool started { get; protected set; } = false;
/// Boolean set to true if the server has failed to start.
public bool failed { get; protected set; } = false;
+ /// Boolean set to true if the server has been destroyed.
+ public bool destroyed { get; protected set; } = false;
/// Boolean set to true if the models were not downloaded successfully.
public static bool modelSetupFailed { get; protected set; } = false;
/// Boolean set to true if the server has started and is ready to receive requests, false otherwise.
@@ -127,6 +132,13 @@ void OnValidate()
public async void Awake()
{
if (!enabled) return;
+ Load();
+ }
+
+ public async Awaitable Load()
+ {
+ await Awaitable.BackgroundThreadAsync();
+
#if !UNITY_EDITOR
modelSetupFailed = !await LLMManager.Setup();
#endif
@@ -142,9 +154,13 @@ public async void Awake()
failed = true;
return;
}
- await Task.Run(() => StartLLMServer(arguments));
+ await StartLLMServerAsync(arguments);
if (!started) return;
- if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject);
+ if (dontDestroyOnLoad)
+ {
+ await Awaitable.MainThreadAsync();
+ DontDestroyOnLoad(transform.root.gameObject);
+ }
}
///
@@ -476,10 +492,11 @@ private void StopLogging()
DestroyStreamWrapper(logStreamWrapper);
}
- private void StartLLMServer(string arguments)
+ private async Task StartLLMServerAsync(string arguments)
{
started = false;
failed = false;
+ destroyed = false;
bool useGPU = numGPULayers > 0;
foreach (string arch in LLMLib.PossibleArchitectures(useGPU))
@@ -488,6 +505,19 @@ private void StartLLMServer(string arguments)
try
{
InitLib(arch);
+#if UNITY_EDITOR
+ if (UseNativeDebugger)
+ {
+ if (llmlib?.LLM_IsDebuggerAttached == null)
+ {
+ LLMUnitySetup.Log($"Tried architecture: {arch} is not debug library");
+ Destroy();
+ continue;
+ }
+
+ await WaitNativeDebug();
+ }
+#endif
InitService(arguments);
LLMUnitySetup.Log($"Using architecture: {arch}");
break;
@@ -504,6 +534,7 @@ private void StartLLMServer(string arguments)
catch (Exception e)
{
error = $"{e.GetType()}: {e.Message}";
+ Destroy();
}
LLMUnitySetup.Log($"Tried architecture: {arch}, error: " + error);
}
@@ -514,20 +545,39 @@ private void StartLLMServer(string arguments)
return;
}
CallWithLock(StartService);
- LLMUnitySetup.Log("LLM service created");
+ if (started)
+ LLMUnitySetup.Log("LLM service created");
}
private void InitLib(string arch)
{
llmlib = new LLMLib(arch);
- CheckLLMStatus(false);
}
+#if UNITY_EDITOR
+ private async Task WaitNativeDebug()
+ {
+ if (llmlib?.LLM_IsDebuggerAttached != null)
+ {
+ LLMUnitySetup.Log("waiting debugger");
+ while (!destroyed)
+ {
+ if (llmlib.LLM_IsDebuggerAttached())
+ {
+ LLMUnitySetup.Log("remote debugger attached");
+ break;
+ }
+ await Task.Delay(100);
+ }
+ }
+ }
+#endif
+
void CallWithLock(EmptyCallback fn)
{
lock (startLock)
{
- if (llmlib == null) throw new DestroyException();
+ if (llmlib == null || destroyed) throw new DestroyException();
fn();
}
}
@@ -556,9 +606,12 @@ private void StartService()
{
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
- while (!llmlib.LLM_Started(LLMObject)) {}
- ApplyLoras();
- started = true;
+ while (!llmlib.LLM_Started(LLMObject) && !destroyed) { }
+ if (!destroyed)
+ {
+ ApplyLoras();
+ started = true;
+ }
}
///
@@ -611,6 +664,7 @@ void AssertStarted()
string error = null;
if (failed) error = "LLM service couldn't be created";
else if (!started) error = "LLM service not started";
+ else if (destroyed) error = "LLM service is being destroyed";
if (error != null)
{
LLMUnitySetup.LogError(error);
@@ -807,6 +861,7 @@ public void CancelRequest(int id_slot)
///
public void Destroy()
{
+ destroyed = true;
lock (staticLock)
lock (startLock)
{
diff --git a/Runtime/LLMCaller.cs b/Runtime/LLMCaller.cs
index 4b69137d..400d555f 100644
--- a/Runtime/LLMCaller.cs
+++ b/Runtime/LLMCaller.cs
@@ -238,6 +238,8 @@ protected virtual async Task PostRequestLocal(string json, string
// send a post request to the server and call the relevant callbacks to convert the received content and handle it
// this function has streaming functionality i.e. handles the answer while it is being received
while (!llm.failed && !llm.started) await Task.Yield();
+ if (llm.destroyed)
+ return default;
string callResult = null;
switch (endpoint)
{
diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs
index 0bf91a15..1f672af6 100644
--- a/Runtime/LLMCharacter.cs
+++ b/Runtime/LLMCharacter.cs
@@ -685,6 +685,8 @@ protected override async Task PostRequestLocal(string json, strin
if (endpoint != "completion") return await base.PostRequestLocal(json, endpoint, getContent, callback);
while (!llm.failed && !llm.started) await Task.Yield();
+ if (llm.destroyed)
+ return default;
string callResult = null;
bool callbackCalled = false;
diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs
index f7417453..e42c0ee1 100644
--- a/Runtime/LLMLib.cs
+++ b/Runtime/LLMLib.cs
@@ -497,6 +497,12 @@ public LLMLib(string arch)
StringWrapper_GetString = LibraryLoader.GetSymbolDelegate(libraryHandle, "StringWrapper_GetString");
Logging = LibraryLoader.GetSymbolDelegate(libraryHandle, "Logging");
StopLogging = LibraryLoader.GetSymbolDelegate(libraryHandle, "StopLogging");
+
+ // editor only
+#if UNITY_EDITOR
+ var symbol = LibraryLoader.GetSymbol(libraryHandle, "LLM_IsDebuggerAttached");
+ LLM_IsDebuggerAttached = (symbol != IntPtr.Zero) ? Marshal.GetDelegateForFunctionPointer(symbol) : null;
+#endif
}
///
@@ -606,6 +612,9 @@ public static string GetArchitecturePath(string arch)
public delegate void StringWrapper_DeleteDelegate(IntPtr instance);
public delegate int StringWrapper_GetStringSizeDelegate(IntPtr instance);
public delegate void StringWrapper_GetStringDelegate(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false);
+#if UNITY_EDITOR
+ public delegate bool LLM_IsDebuggerAttachedDelegate();
+#endif
public LoggingDelegate Logging;
public StopLoggingDelegate StopLogging;
@@ -631,6 +640,9 @@ public static string GetArchitecturePath(string arch)
public StringWrapper_DeleteDelegate StringWrapper_Delete;
public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize;
public StringWrapper_GetStringDelegate StringWrapper_GetString;
+#if UNITY_EDITOR
+ public LLM_IsDebuggerAttachedDelegate LLM_IsDebuggerAttached;
+#endif
#endif