diff --git a/SnaffCore/Concurrency/BlockingTaskScheduler.cs b/SnaffCore/Concurrency/BlockingTaskScheduler.cs index f748bd3a..47d847cb 100644 --- a/SnaffCore/Concurrency/BlockingTaskScheduler.cs +++ b/SnaffCore/Concurrency/BlockingTaskScheduler.cs @@ -1,8 +1,13 @@ using System; using System.Collections.Generic; +using System.IO; +using System.Linq; using System.Numerics; +using System.Text; using System.Threading; using System.Threading.Tasks; +using System.Timers; +using static SnaffCore.Config.Options; namespace SnaffCore.Concurrency { @@ -34,6 +39,8 @@ public bool Done() Scheduler.RecalculateCounters(); TaskCounters taskCounters = Scheduler.GetTaskCounters(); + Console.WriteLine($"Checking if done - queued: {taskCounters.CurrentTasksQueued}, done: {taskCounters.CurrentTasksRunning}"); + if ((taskCounters.CurrentTasksQueued + taskCounters.CurrentTasksRunning == 0)) { return true; @@ -68,6 +75,168 @@ public void New(Action action) } } + public enum TaskFileType + { + None = 0, + Share = 1, + Tree = 2, + File = 3 + } + + public enum TaskFileEntryStatus + { + Pending = 0, + Completed = 1, + } + + public struct TaskFileEntry + { + public TaskFileEntryStatus status; + public string guid; + public TaskFileType type; + public string input; + + public override string ToString() + { + StringBuilder stringBuilder = new StringBuilder(); + + stringBuilder.Append(status.ToString()); + stringBuilder.Append("|"); + stringBuilder.Append(guid); + stringBuilder.Append("|"); + stringBuilder.Append(type.ToString()); + stringBuilder.Append("|"); + stringBuilder.Append(input); + + return stringBuilder.ToString(); + } + + public TaskFileEntry(TaskFileType type, string input) + { + guid = Guid.NewGuid().ToString(); + status = TaskFileEntryStatus.Pending; + this.type = type; + this.input = input; + } + + public TaskFileEntry(string entryLine) + { + string[] lineParts = entryLine.Split('|'); + + status = (TaskFileEntryStatus)Enum.Parse(typeof(TaskFileEntryStatus), lineParts[0]); + guid = lineParts[1]; + + type = (TaskFileType)Enum.Parse(typeof(TaskFileType), lineParts[2]); + input = lineParts[3]; + } + } + + public class ResumingTaskScheduler : BlockingStaticTaskScheduler + { + private static readonly Dictionary> pendingTasks = new Dictionary>(); + private static readonly object writeLock = new object(); + private static int pendingSaveCalls = 0; + + internal BlockingMq Mq { get; } + + public ResumingTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog) + { + this.Mq = BlockingMq.GetMq(); + } + + public void New(string taskType, Action action, string path) + { + string guid = null; + + if (MyOptions.TaskFile != null) + { + guid = Guid.NewGuid().ToString(); + pendingTasks.Add(guid, new Tuple(taskType, path)); + } + + New(() => + { + try + { + action(path); + } + catch (Exception e) + { + Mq.Error("Exception in " + taskType.ToString() + " task for host " + path); + Mq.Error(e.ToString()); + } + + if (guid != null) pendingTasks.Remove(guid); + }); + } + + public static void SaveState(object sender, ElapsedEventArgs e) + { + SaveState(); + } + + public static void SaveState() + { + // Guard against the possibility that someone forgot to check this + if (MyOptions.TaskFile == null) return; + + // This blocks more than one save call from being buffered at a time + // Prevents a situation where a bunch of buffered calls wait for the lock + // But still allows for you to "write continously" if you set an interval shorter than the file write time + if (pendingSaveCalls > 1) return; + pendingSaveCalls++; + + // In case the file takes longer to write than the set interval + lock (writeLock) + { + using (StreamWriter fileWriter = new StreamWriter(MyOptions.TaskFile, false)) + { + // Copy the values into the array to avoid an error in case the pending tasks are changed during the write loop + Tuple[] valuesSnapshot = pendingTasks.Values.ToArray(); + + foreach (Tuple value in valuesSnapshot) + { + fileWriter.WriteLine($"{value.Item1}|{value.Item2}"); + } + + fileWriter.Flush(); + } + + pendingSaveCalls--; + } + } + } + + public class ShareTaskScheduler : ResumingTaskScheduler + { + public ShareTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog) { } + + public void New(Action action, string share) + { + New("share", action, share); + } + } + + public class TreeTaskScheduler : ResumingTaskScheduler + { + public TreeTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog) { } + + public void New(Action action, string tree) + { + New("tree", action, tree); + } + } + + public class FileTaskScheduler : ResumingTaskScheduler + { + public FileTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog) { } + + public void New(Action action, string file) + { + New("file", action, file); + } + } + public class TaskCounters { public BigInteger TotalTasksQueued { get; set; } diff --git a/SnaffCore/Config/Options.cs b/SnaffCore/Config/Options.cs index 78c66282..44761572 100644 --- a/SnaffCore/Config/Options.cs +++ b/SnaffCore/Config/Options.cs @@ -14,6 +14,11 @@ public partial class Options { public static Options MyOptions { get; set; } + // Pause and resume functionality + public string TaskFile { get; set; } + public double TaskFileTimeOut { get; set; } = 5; + public string ResumeFrom { get; set; } + // Manual Targeting Options public List PathTargets { get; set; } = new List(); public string[] ComputerTargets { get; set; } diff --git a/SnaffCore/ShareFind/ShareFinder.cs b/SnaffCore/ShareFind/ShareFinder.cs index 4d1395c1..6bcb9543 100644 --- a/SnaffCore/ShareFind/ShareFinder.cs +++ b/SnaffCore/ShareFind/ShareFinder.cs @@ -17,7 +17,7 @@ namespace SnaffCore.ShareFind public class ShareFinder { private BlockingMq Mq { get; set; } - private BlockingStaticTaskScheduler TreeTaskScheduler { get; set; } + private TreeTaskScheduler TreeTaskScheduler { get; set; } private TreeWalker TreeWalker { get; set; } //private EffectivePermissions effectivePermissions { get; set; } = new EffectivePermissions(MyOptions.CurrentUser); @@ -184,18 +184,7 @@ internal void GetComputerShares(string computer) if (MyOptions.ScanFoundShares) { Mq.Trace("Creating a TreeWalker task for " + shareResult.SharePath); - TreeTaskScheduler.New(() => - { - try - { - TreeWalker.WalkTree(shareResult.SharePath); - } - catch (Exception e) - { - Mq.Error("Exception in TreeWalker task for share " + shareResult.SharePath); - Mq.Error(e.ToString()); - } - }); + TreeTaskScheduler.New(TreeWalker.WalkTree, shareResult.SharePath); } Mq.ShareResult(shareResult); } diff --git a/SnaffCore/SnaffCon.cs b/SnaffCore/SnaffCon.cs index 6caf1ab4..9094454d 100644 --- a/SnaffCore/SnaffCon.cs +++ b/SnaffCore/SnaffCon.cs @@ -16,6 +16,8 @@ using static SnaffCore.Config.Options; using Timer = System.Timers.Timer; using System.Net; +using System.IO; +using Nett; namespace SnaffCore { @@ -25,10 +27,10 @@ public class SnaffCon private BlockingMq Mq { get; set; } - private static BlockingStaticTaskScheduler ShareTaskScheduler; - private static BlockingStaticTaskScheduler TreeTaskScheduler; - private static BlockingStaticTaskScheduler FileTaskScheduler; - + private static ShareTaskScheduler ShareTaskScheduler; + private static TreeTaskScheduler TreeTaskScheduler; + private static FileTaskScheduler FileTaskScheduler; + private static ShareFinder ShareFinder; private static TreeWalker TreeWalker; private static FileScanner FileScanner; @@ -46,10 +48,10 @@ public SnaffCon(Options options) int treeThreads = MyOptions.TreeThreads; int fileThreads = MyOptions.FileThreads; - ShareTaskScheduler = new BlockingStaticTaskScheduler(shareThreads, MyOptions.MaxShareQueue); - TreeTaskScheduler = new BlockingStaticTaskScheduler(treeThreads, MyOptions.MaxTreeQueue); - FileTaskScheduler = new BlockingStaticTaskScheduler(fileThreads, MyOptions.MaxFileQueue); - + ShareTaskScheduler = new ShareTaskScheduler(shareThreads, MyOptions.MaxShareQueue); + TreeTaskScheduler = new TreeTaskScheduler(treeThreads, MyOptions.MaxTreeQueue); + FileTaskScheduler = new FileTaskScheduler(fileThreads, MyOptions.MaxFileQueue); + FileScanner = new FileScanner(); TreeWalker = new TreeWalker(); ShareFinder = new ShareFinder(); @@ -67,15 +69,15 @@ public static FileScanner GetFileScanner() { return FileScanner; } - public static BlockingStaticTaskScheduler GetShareTaskScheduler() + public static ShareTaskScheduler GetShareTaskScheduler() { return ShareTaskScheduler; } - public static BlockingStaticTaskScheduler GetTreeTaskScheduler() + public static TreeTaskScheduler GetTreeTaskScheduler() { return TreeTaskScheduler; } - public static BlockingStaticTaskScheduler GetFileTaskScheduler() + public static FileTaskScheduler GetFileTaskScheduler() { return FileTaskScheduler; } @@ -92,71 +94,122 @@ public void Execute() statusUpdateTimer.Start(); - // If we want to hunt for user IDs, we need data from the running user's domain. - // Future - walk trusts - if ( MyOptions.DomainUserRules) + if (MyOptions.TaskFile != null) { - DomainUserDiscovery(); + Timer saveStateTimer = new Timer(TimeSpan.FromMinutes(MyOptions.TaskFileTimeOut).TotalMilliseconds) { AutoReset = true }; + saveStateTimer.Elapsed += ResumingTaskScheduler.SaveState; + saveStateTimer.Start(); } - // Explicit folder setting overrides DFS - if (MyOptions.PathTargets.Count != 0 && (MyOptions.DfsShareDiscovery || MyOptions.DfsOnly)) + if (MyOptions.ResumeFrom != null) { - DomainDfsDiscovery(); - } + // Read the task file and parse it into tuples + Tuple[] taskFileEntries = File.ReadLines(MyOptions.ResumeFrom).Select(line => + { + string[] parts = line.Split('|'); + return parts.Length == 2 ? new Tuple(parts[0], parts[1]) : null; + }).Where(tuple => tuple != null).ToArray(); + + // Get all shares, they are non-recursive so just save all of them + Tuple[] shareEntries = taskFileEntries.Where(entry => entry.Item1 == "share").ToArray(); + + // Remove all entries where the path starts with a pending share + taskFileEntries = taskFileEntries.Where(entry => !shareEntries.Any(shareEntry => entry.Item2.StartsWith("\\\\" + shareEntry.Item2 + "\\"))).ToArray(); + + // Get all tree entires, they are recursive so we need to find the shortest path for each shared base + Tuple[] treeEntries = taskFileEntries.Where(entry => entry.Item1 == "tree" && !taskFileEntries.Any(otherEntry => entry.Item2.StartsWith(otherEntry.Item2 + "\\"))).ToArray(); - if (MyOptions.PathTargets.Count == 0 && MyOptions.ComputerTargets == null) + // Remove all entries where the path starts with one of the base pending tree + taskFileEntries = taskFileEntries.Where(entry => !treeEntries.Any(treeEntry => entry.Item2.StartsWith(treeEntry.Item2 + "\\"))).ToArray(); + + // Tasks should be deduplicated now, dispatch what is left pending + foreach (Tuple entry in taskFileEntries) + { + switch (entry.Item1) + { + case "share": + ShareFinder shareFinder = new ShareFinder(); + ShareTaskScheduler.New(shareFinder.GetComputerShares, entry.Item2); + break; + case "tree": + TreeTaskScheduler.New(TreeWalker.WalkTree, entry.Item2); + break; + case "file": + FileTaskScheduler.New(FileScanner.ScanFile, entry.Item2); + break; + } + } + } + else { - if (MyOptions.DfsSharesDict.Count == 0) + // If we want to hunt for user IDs, we need data from the running user's domain. + // Future - walk trusts + if (MyOptions.DomainUserRules) { - Mq.Info("Invoking DFS Discovery because no ComputerTargets or PathTargets were specified"); - DomainDfsDiscovery(); + DomainUserDiscovery(); } - if (!MyOptions.DfsOnly) + // Explicit folder setting overrides DFS + if (MyOptions.PathTargets.Count != 0 && (MyOptions.DfsShareDiscovery || MyOptions.DfsOnly)) { - Mq.Info("Invoking full domain computer discovery."); - DomainTargetDiscovery(); + DomainDfsDiscovery(); } - else + + if (MyOptions.PathTargets.Count == 0 && MyOptions.ComputerTargets == null) { - Mq.Info("Skipping domain computer discovery."); - foreach (string share in MyOptions.DfsSharesDict.Keys) + if (MyOptions.DfsSharesDict.Count == 0) { - if (!MyOptions.PathTargets.Contains(share)) + Mq.Info("Invoking DFS Discovery because no ComputerTargets or PathTargets were specified"); + DomainDfsDiscovery(); + } + + if (!MyOptions.DfsOnly) + { + Mq.Info("Invoking full domain computer discovery."); + DomainTargetDiscovery(); + } + else + { + Mq.Info("Skipping domain computer discovery."); + foreach (string share in MyOptions.DfsSharesDict.Keys) { - MyOptions.PathTargets.Add(share); + if (!MyOptions.PathTargets.Contains(share)) + { + MyOptions.PathTargets.Add(share); + } } + Mq.Info("Starting TreeWalker tasks on DFS shares."); + FileDiscovery(MyOptions.PathTargets.ToArray()); } - Mq.Info("Starting TreeWalker tasks on DFS shares."); + } + // otherwise we should have a set of path targets... + else if (MyOptions.PathTargets.Count != 0) + { FileDiscovery(MyOptions.PathTargets.ToArray()); } - } - // otherwise we should have a set of path targets... - else if (MyOptions.PathTargets.Count != 0) - { - FileDiscovery(MyOptions.PathTargets.ToArray()); - } - // or we've been told what computers to hit... - else if (MyOptions.ComputerTargets != null) - { - ShareDiscovery(MyOptions.ComputerTargets); - } + // or we've been told what computers to hit... + else if (MyOptions.ComputerTargets != null) + { + ShareDiscovery(MyOptions.ComputerTargets); + } - // but if that hasn't been done, something has gone wrong. - else - { - Mq.Error("OctoParrot says: AWK! I SHOULDN'T BE!"); + // but if that hasn't been done, something has gone wrong. + else + { + Mq.Error("OctoParrot says: AWK! I SHOULDN'T BE!"); + } } waitHandle.WaitOne(); StatusUpdate(); + DateTime finished = DateTime.Now; TimeSpan runSpan = finished.Subtract(StartTime); Mq.Info("Finished at " + finished.ToLocalTime()); Mq.Info("Snafflin' took " + runSpan); Mq.Finish(); + } private void DomainDfsDiscovery() @@ -297,19 +350,8 @@ private void ShareDiscovery(string[] computerTargets) } // ShareFinder Task Creation - this kicks off the rest of the flow Mq.Trace("Creating a ShareFinder task for " + computer); - ShareTaskScheduler.New(() => - { - try - { - ShareFinder shareFinder = new ShareFinder(); - shareFinder.GetComputerShares(computer); - } - catch (Exception e) - { - Mq.Error("Exception in ShareFinder task for host " + computer); - Mq.Error(e.ToString()); - } - }); + ShareFinder shareFinder = new ShareFinder(); + ShareTaskScheduler.New(shareFinder.GetComputerShares, computer); } Mq.Info("Created all sharefinder tasks."); } @@ -371,18 +413,7 @@ private void FileDiscovery(string[] pathTargets) { // TreeWalker Task Creation - this kicks off the rest of the flow Mq.Info("Creating a TreeWalker task for " + pathTarget); - TreeTaskScheduler.New(() => - { - try - { - TreeWalker.WalkTree(pathTarget); - } - catch (Exception e) - { - Mq.Error("Exception in TreeWalker task for path " + pathTarget); - Mq.Error(e.ToString()); - } - }); + TreeTaskScheduler.New(TreeWalker.WalkTree, pathTarget); } Mq.Info("Created all TreeWalker tasks."); diff --git a/SnaffCore/TreeWalk/TreeWalker.cs b/SnaffCore/TreeWalk/TreeWalker.cs index c3cf3fc3..af6f89d5 100644 --- a/SnaffCore/TreeWalk/TreeWalker.cs +++ b/SnaffCore/TreeWalk/TreeWalker.cs @@ -10,8 +10,8 @@ namespace SnaffCore.TreeWalk public class TreeWalker { private BlockingMq Mq { get; set; } - private BlockingStaticTaskScheduler FileTaskScheduler { get; set; } - private BlockingStaticTaskScheduler TreeTaskScheduler { get; set; } + private FileTaskScheduler FileTaskScheduler { get; set; } + private TreeTaskScheduler TreeTaskScheduler { get; set; } private FileScanner FileScanner { get; set; } public TreeWalker() @@ -38,18 +38,7 @@ public void WalkTree(string currentDir) // check if we actually like the files foreach (string file in files) { - FileTaskScheduler.New(() => - { - try - { - FileScanner.ScanFile(file); - } - catch (Exception e) - { - Mq.Error("Exception in FileScanner task for file " + file); - Mq.Trace(e.ToString()); - } - }); + FileTaskScheduler.New(FileScanner.ScanFile, file); } } catch (UnauthorizedAccessException) @@ -104,18 +93,7 @@ public void WalkTree(string currentDir) } if (scanDir == true) { - TreeTaskScheduler.New(() => - { - try - { - WalkTree(dirStr); - } - catch (Exception e) - { - Mq.Error("Exception in TreeWalker task for dir " + dirStr); - Mq.Error(e.ToString()); - } - }); + TreeTaskScheduler.New(WalkTree, dirStr); } else { diff --git a/Snaffler/Config.cs b/Snaffler/Config.cs index 0a7e9a57..ce66e366 100644 --- a/Snaffler/Config.cs +++ b/Snaffler/Config.cs @@ -106,6 +106,9 @@ private static Options ParseImpl(string[] args) ValueArgument logType = new ValueArgument('t', "logtype", "Type of log you would like to output. Currently supported options are plain and JSON. Defaults to plain."); ValueArgument timeOutArg = new ValueArgument('e', "timeout", "Interval between status updates (in minutes) also acts as a timeout for AD data to be gathered via LDAP. Turn this knob up if you aren't getting any computers from AD when you run Snaffler through a proxy or other slow link. Default = 5"); + ValueArgument taskFile = new ValueArgument('1', "taskfile", "Save tasks as they are created to a file to allow for resuming mid-operation."); + ValueArgument resumeFrom = new ValueArgument('2', "resumefrom", "Resume tasks from a file generated with --taskfile."); + ValueArgument taskFileTimeOut = new ValueArgument('3', "taskfiletimeout", "Interval between saving tasks to the task file (in minutes). Default = 5"); // list of letters i haven't used yet: gnqw CommandLineParser.CommandLineParser parser = new CommandLineParser.CommandLineParser(); @@ -132,6 +135,9 @@ private static Options ParseImpl(string[] args) parser.Arguments.Add(ruleDirArg); parser.Arguments.Add(logType); parser.Arguments.Add(compExclusionArg); + parser.Arguments.Add(taskFile); + parser.Arguments.Add(taskFileTimeOut); + parser.Arguments.Add(resumeFrom); // extra check to handle builtin behaviour from cmd line arg parser if ((args.Contains("--help") || args.Contains("/?") || args.Contains("help") || args.Contains("-h") || args.Length == 0)) @@ -150,6 +156,30 @@ private static Options ParseImpl(string[] args) { parser.ParseCommandLine(args); + if (taskFile.Parsed) + { + parsedConfig.TaskFile = taskFile.Value; + } + + if (taskFileTimeOut.Parsed && !String.IsNullOrWhiteSpace(taskFileTimeOut.Value)) + { + double timeOutVal; + if (double.TryParse(taskFileTimeOut.Value, out timeOutVal)) + { + Mq.Info("Set task file saving interval to " + timeOutVal.ToString() + " minutes."); + parsedConfig.TaskFileTimeOut = timeOutVal; + } + else + { + Mq.Error("Invalid task file timeout value passed, defaulting to 5 mins."); + } + } + + if (resumeFrom.Parsed) + { + parsedConfig.ResumeFrom = resumeFrom.Value; + } + if (timeOutArg.Parsed && !String.IsNullOrWhiteSpace(timeOutArg.Value)) { int timeOutVal;