-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathCredentialManager.cs
258 lines (218 loc) · 8.86 KB
/
CredentialManager.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Coder.Desktop.App.Models;
using Coder.Desktop.Vpn.Utilities;
using CoderSdk;
namespace Coder.Desktop.App.Services;
/// <summary>
/// The credential payload that <see cref="CredentialManager"/> serializes to JSON and
/// stores in (and reads back from) the Windows credential store: the Coder deployment
/// URL together with the API token for that deployment.
/// </summary>
public class RawCredentials
{
    /// <summary>The Coder deployment URL.</summary>
    public required string CoderUrl
    {
        get;
        set;
    }

    /// <summary>The API (session) token used to authenticate against <see cref="CoderUrl"/>.</summary>
    public required string ApiToken
    {
        get;
        set;
    }
}
/// <summary>
/// System.Text.Json serialization context for <see cref="RawCredentials"/>. The body is
/// intentionally empty: the source generator fills in this partial class, and the
/// resulting <c>Default.RawCredentials</c> type info is used to serialize and deserialize
/// the credential payload without runtime reflection.
/// </summary>
[JsonSerializable(typeof(RawCredentials))]
public partial class RawCredentialsJsonContext : JsonSerializerContext
{
}
/// <summary>
/// Provides access to the Coder deployment URL and API token used by the application,
/// and notifies subscribers when that credential state changes.
/// </summary>
public interface ICredentialManager
{
    /// <summary>Raised with the new credential state whenever it changes.</summary>
    event EventHandler<CredentialModel> CredentialsChanged;

    /// <summary>Returns the current credential state.</summary>
    /// <returns>A snapshot of the current credentials.</returns>
    CredentialModel GetCredentials();

    /// <summary>Stores new credentials for a Coder deployment.</summary>
    /// <param name="coderUrl">The root URL of the Coder deployment.</param>
    /// <param name="apiToken">The API token for that deployment.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    Task SetCredentials(string coderUrl, string apiToken, CancellationToken ct = default);

    /// <summary>Removes any stored credentials.</summary>
    void ClearCredentials();
}
/// <summary>
/// <see cref="ICredentialManager"/> implementation backed by the Windows Credential Manager
/// (Advapi32 <c>CredReadW</c>/<c>CredWriteW</c>/<c>CredDeleteW</c>). The Coder URL and API
/// token are stored as a single JSON-serialized <see cref="RawCredentials"/> entry. The last
/// known state is cached in memory, and <see cref="CredentialsChanged"/> is raised whenever
/// that cache is updated.
/// </summary>
public class CredentialManager : ICredentialManager
{
    // Generic-credential target name under which the JSON payload is stored.
    private const string CredentialsTargetName = "Coder.Desktop.App.Credentials";

    // Maximum accepted length (in characters, after trimming) of a Coder URL.
    private const int MaxCoderUrlLength = 128;

    // Required length (in characters, after trimming) of an API token.
    private const int ApiTokenLength = 33;

    // Upper bound on how long server verification in SetCredentials may take.
    private static readonly TimeSpan VerificationTimeout = TimeSpan.FromSeconds(15);

    // Guards _latestCredentials. NOTE(review): assumes RaiiSemaphoreSlim.Lock() blocks until
    // the semaphore is acquired and the returned object releases it on Dispose — confirm.
    private readonly RaiiSemaphoreSlim _lock = new(1, 1);

    // In-memory cache of the last known credential state; null until first loaded or set.
    private CredentialModel? _latestCredentials;

    /// <summary>
    /// Raised after <see cref="SetCredentials"/> or <see cref="ClearCredentials"/> updates the
    /// cached state. Handlers receive a clone of the new state. The event is raised outside
    /// the internal lock.
    /// </summary>
    public event EventHandler<CredentialModel>? CredentialsChanged;

    /// <summary>
    /// Returns a copy of the current credential state. The first call loads the state from
    /// the credential store. Subsequent calls return the cached value.
    /// </summary>
    /// <returns>
    /// A <see cref="CredentialState.Valid"/> model with URL and token if a well-formed entry is
    /// stored; otherwise a <see cref="CredentialState.Invalid"/> model.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// The credential store could not be read for a reason other than the entry not existing.
    /// </exception>
    public CredentialModel GetCredentials()
    {
        using var _ = _lock.Lock();
        if (_latestCredentials is not null) return _latestCredentials.Clone();

        var rawCredentials = ReadCredentials();
        _latestCredentials = rawCredentials is null
            ? new CredentialModel
            {
                State = CredentialState.Invalid,
            }
            : new CredentialModel
            {
                State = CredentialState.Valid,
                CoderUrl = rawCredentials.CoderUrl,
                ApiToken = rawCredentials.ApiToken,
            };
        return _latestCredentials.Clone();
    }

    /// <summary>
    /// Validates the given URL and token, verifies them against the Coder server, then persists
    /// them to the credential store and updates the cached state.
    /// </summary>
    /// <param name="coderUrl">Absolute http(s) root URL of the deployment; surrounding whitespace is trimmed.</param>
    /// <param name="apiToken">API token; surrounding whitespace is trimmed.</param>
    /// <param name="ct">Cancels server verification (combined with an internal timeout).</param>
    /// <exception cref="ArgumentException">The URL or token is missing or malformed.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The URL is too long or the token has the wrong length.</exception>
    /// <exception cref="InvalidOperationException">
    /// The server could not be reached or verified (including timeout/cancellation), or the
    /// credential store write failed.
    /// </exception>
    public async Task SetCredentials(string coderUrl, string apiToken, CancellationToken ct = default)
    {
        if (string.IsNullOrWhiteSpace(coderUrl)) throw new ArgumentException("Coder URL is required", nameof(coderUrl));
        coderUrl = coderUrl.Trim();
        if (coderUrl.Length > MaxCoderUrlLength)
            throw new ArgumentOutOfRangeException(nameof(coderUrl), "Coder URL is too long");
        if (!Uri.TryCreate(coderUrl, UriKind.Absolute, out var uri))
            throw new ArgumentException($"Coder URL '{coderUrl}' is not a valid URL", nameof(coderUrl));
        // Uri.TryCreate also accepts schemes such as file:// or ftp://, whose path can be "/"
        // and would pass the root-URL check below. Reject them before any network call.
        if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps)
            throw new ArgumentException($"Coder URL '{coderUrl}' must use http or https", nameof(coderUrl));
        if (uri.PathAndQuery != "/") throw new ArgumentException("Coder URL must be the root URL", nameof(coderUrl));

        if (string.IsNullOrWhiteSpace(apiToken)) throw new ArgumentException("API token is required", nameof(apiToken));
        apiToken = apiToken.Trim();
        if (apiToken.Length != ApiTokenLength)
            throw new ArgumentOutOfRangeException(nameof(apiToken), $"API token must be {ApiTokenLength} characters long");

        try
        {
            // Linked so that either the caller's token or the timeout cancels verification.
            // Disposed on exit to release the CancelAfter timer and the registration on ct.
            using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
            cts.CancelAfter(VerificationTimeout);
            var sdkClient = new CoderApiClient(uri);
            sdkClient.SetSessionToken(apiToken);
            // TODO: we should probably perform a version check here too,
            // rather than letting the service do it on Start
            _ = await sdkClient.GetBuildInfo(cts.Token);
            _ = await sdkClient.GetUser(User.Me, cts.Token);
        }
        catch (Exception e)
        {
            throw new InvalidOperationException("Could not connect to or verify Coder server", e);
        }

        WriteCredentials(new RawCredentials
        {
            CoderUrl = coderUrl,
            ApiToken = apiToken,
        });
        UpdateState(new CredentialModel
        {
            State = CredentialState.Valid,
            CoderUrl = coderUrl,
            ApiToken = apiToken,
        });
    }

    /// <summary>
    /// Deletes the stored credential entry (a missing entry is not an error) and marks the
    /// cached state as invalid.
    /// </summary>
    /// <exception cref="InvalidOperationException">The credential store delete failed.</exception>
    public void ClearCredentials()
    {
        NativeApi.DeleteCredentials(CredentialsTargetName);
        UpdateState(new CredentialModel
        {
            State = CredentialState.Invalid,
            CoderUrl = null,
            ApiToken = null,
        });
    }

    /// <summary>
    /// Replaces the cached state with a clone of <paramref name="newModel"/> under the lock, then
    /// raises <see cref="CredentialsChanged"/> (outside the lock) with another clone, so neither
    /// the cache nor handlers share the caller's instance.
    /// </summary>
    private void UpdateState(CredentialModel newModel)
    {
        using (_lock.Lock())
        {
            _latestCredentials = newModel.Clone();
        }

        CredentialsChanged?.Invoke(this, newModel.Clone());
    }

    /// <summary>
    /// Reads and deserializes the stored credential entry.
    /// </summary>
    /// <returns>
    /// The stored credentials, or null if no entry exists, the payload is not valid JSON for
    /// <see cref="RawCredentials"/>, or either field is blank.
    /// </returns>
    private static RawCredentials? ReadCredentials()
    {
        var raw = NativeApi.ReadCredentials(CredentialsTargetName);
        if (raw is null) return null;

        RawCredentials? credentials;
        try
        {
            credentials = JsonSerializer.Deserialize(raw, RawCredentialsJsonContext.Default.RawCredentials);
        }
        catch (JsonException)
        {
            // A corrupt or foreign payload is treated the same as "no credentials".
            return null;
        }

        if (credentials is null || string.IsNullOrWhiteSpace(credentials.CoderUrl) ||
            string.IsNullOrWhiteSpace(credentials.ApiToken)) return null;
        return credentials;
    }

    /// <summary>Serializes <paramref name="credentials"/> to JSON and writes it to the credential store.</summary>
    private static void WriteCredentials(RawCredentials credentials)
    {
        var raw = JsonSerializer.Serialize(credentials, RawCredentialsJsonContext.Default.RawCredentials);
        NativeApi.WriteCredentials(CredentialsTargetName, raw);
    }

    /// <summary>P/Invoke wrappers over the Advapi32 generic-credential APIs.</summary>
    private static class NativeApi
    {
        // CRED_TYPE_GENERIC
        private const int CredentialTypeGeneric = 1;

        // CRED_PERSIST_LOCAL_MACHINE
        private const int PersistenceTypeLocalComputer = 2;

        // ERROR_NOT_FOUND: returned when no credential exists for the target name.
        private const int ErrorNotFound = 1168;

        // CRED_MAX_CREDENTIAL_BLOB_SIZE, in bytes.
        private const int CredMaxCredentialBlobSize = 5 * 512;

        /// <summary>
        /// Reads the generic credential for <paramref name="targetName"/> and returns its blob
        /// decoded as UTF-16 text.
        /// </summary>
        /// <returns>
        /// The secret, <see cref="string.Empty"/> if the entry has an empty blob, or null if no
        /// entry exists.
        /// </returns>
        /// <exception cref="InvalidOperationException">CredReadW failed for any other reason.</exception>
        public static string? ReadCredentials(string targetName)
        {
            if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr))
            {
                var error = Marshal.GetLastWin32Error();
                if (error == ErrorNotFound) return null;
                throw new InvalidOperationException($"Failed to read credentials (Error {error})");
            }

            try
            {
                var cred = Marshal.PtrToStructure<CREDENTIAL>(credentialPtr);
                // An entry can exist with no blob. Marshal.PtrToStringUni rejects a null pointer,
                // so report an empty secret instead of throwing.
                if (cred.CredentialBlob == IntPtr.Zero || cred.CredentialBlobSize <= 0) return string.Empty;
                return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char));
            }
            finally
            {
                // The buffer returned by CredReadW must be released with CredFree.
                CredFree(credentialPtr);
            }
        }

        /// <summary>
        /// Creates or overwrites the generic credential for <paramref name="targetName"/>, storing
        /// <paramref name="secret"/> as UTF-16 bytes with local-machine persistence.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">The encoded secret exceeds the blob size limit.</exception>
        /// <exception cref="InvalidOperationException">CredWriteW failed.</exception>
        public static void WriteCredentials(string targetName, string secret)
        {
            var byteCount = Encoding.Unicode.GetByteCount(secret);
            if (byteCount > CredMaxCredentialBlobSize)
                throw new ArgumentOutOfRangeException(nameof(secret),
                    $"The secret is greater than {CredMaxCredentialBlobSize} bytes");

            var credentialBlob = Marshal.StringToHGlobalUni(secret);
            var cred = new CREDENTIAL
            {
                Type = CredentialTypeGeneric,
                TargetName = targetName,
                CredentialBlobSize = byteCount,
                CredentialBlob = credentialBlob,
                Persist = PersistenceTypeLocalComputer,
            };
            try
            {
                if (!CredWriteW(ref cred, 0))
                {
                    var error = Marshal.GetLastWin32Error();
                    throw new InvalidOperationException($"Failed to write credentials (Error {error})");
                }
            }
            finally
            {
                Marshal.FreeHGlobal(credentialBlob);
            }
        }

        /// <summary>
        /// Deletes the generic credential for <paramref name="targetName"/>. A missing entry is
        /// ignored.
        /// </summary>
        /// <exception cref="InvalidOperationException">CredDeleteW failed for any other reason.</exception>
        public static void DeleteCredentials(string targetName)
        {
            if (!CredDeleteW(targetName, CredentialTypeGeneric, 0))
            {
                var error = Marshal.GetLastWin32Error();
                if (error == ErrorNotFound) return;
                throw new InvalidOperationException($"Failed to delete credentials (Error {error})");
            }
        }

        [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr);

        [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern bool CredWriteW([In] ref CREDENTIAL userCredential, [In] uint flags);

        [DllImport("Advapi32.dll", SetLastError = true)]
        private static extern void CredFree([In] IntPtr cred);

        [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern bool CredDeleteW(string target, int type, int flags);

        /// <summary>Managed mirror of the Win32 <c>CREDENTIALW</c> structure; field order must match.</summary>
        [StructLayout(LayoutKind.Sequential)]
        private struct CREDENTIAL
        {
            public int Flags;
            public int Type;
            [MarshalAs(UnmanagedType.LPWStr)] public string TargetName;
            [MarshalAs(UnmanagedType.LPWStr)] public string? Comment;

            // Native FILETIME (two DWORDs), represented as a single 64-bit value.
            public long LastWritten;

            // Size of CredentialBlob in bytes.
            public int CredentialBlobSize;
            public IntPtr CredentialBlob;
            public int Persist;
            public int AttributeCount;
            public IntPtr Attributes;
            [MarshalAs(UnmanagedType.LPWStr)] public string? TargetAlias;
            [MarshalAs(UnmanagedType.LPWStr)] public string? UserName;
        }
    }
}