diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index 45f4e90b07..78dae8a19f 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -125,6 +125,19 @@ private enum CountCheckStatus HasCount, } + private enum LinqPredicateCheckStatus + { + Unknown, + Any, + Count, + WhereAny, + WhereCount, + Single, + SingleOrDefault, + WhereSingle, + WhereSingleOrDefault, + } + internal const string ProperAssertMethodNameKey = nameof(ProperAssertMethodNameKey); /// @@ -268,6 +281,56 @@ private static void AnalyzeInvocationOperation(OperationAnalysisContext context, case "AreNotEqual": AnalyzeAreEqualOrAreNotEqualInvocation(context, firstArgument, isAreEqualInvocation: false, objectTypeSymbol); break; + case "IsNull": + AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: true); + break; + + case "IsNotNull": + AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: false); + break; + } + } + + private static void AnalyzeIsNullOrIsNotNullInvocation(OperationAnalysisContext context, IOperation argument, bool isNullCheck) + { + RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation."); + + // Check for Single/SingleOrDefault patterns + LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck( + argument, + out SyntaxNode? linqCollectionExpr, + out SyntaxNode? predicateExpr, + out _); + + if (linqStatus is LinqPredicateCheckStatus.Single or + LinqPredicateCheckStatus.SingleOrDefault or + LinqPredicateCheckStatus.WhereSingle or + LinqPredicateCheckStatus.WhereSingleOrDefault && + linqCollectionExpr != null) + { + // For Assert.IsNotNull(enumerable.Single[OrDefault](...)) -> Assert.ContainsSingle + // For Assert.IsNull(enumerable.Single[OrDefault](...)) -> Assert.DoesNotContain + string properAssertMethod = isNullCheck ? "DoesNotContain" : "ContainsSingle"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, predicateExpr != null ? CodeFixModeAddArgument : CodeFixModeSimple); + + ImmutableArray additionalLocations = predicateExpr != null + ? ImmutableArray.Create( + argument.Syntax.GetLocation(), + predicateExpr.GetLocation(), + linqCollectionExpr.GetLocation()) + : ImmutableArray.Create( + argument.Syntax.GetLocation(), + linqCollectionExpr.GetLocation()); + + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: additionalLocations, + properties: properties.ToImmutable(), + properAssertMethod, + isNullCheck ? "IsNull" : "IsNotNull")); } } @@ -519,6 +582,146 @@ private static ComparisonCheckStatus RecognizeComparisonCheck( return ComparisonCheckStatus.Unknown; } + private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck( + IOperation operation, + out SyntaxNode? collectionExpression, + out SyntaxNode? predicateExpression, + out IOperation? countOperation) + { + collectionExpression = null; + predicateExpression = null; + countOperation = null; + + // Check for enumerable.Any(predicate) + // Extension methods appear as: Instance=null, Arguments[0]=collection, Arguments[1]=predicate + if (operation is IInvocationOperation anyInvocation && + anyInvocation.TargetMethod.Name == "Any" && + anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + anyInvocation.Arguments.Length == 2) + { + collectionExpression = anyInvocation.Arguments[0].Value.Syntax; + predicateExpression = anyInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.Any; + } + + // Check for enumerable.Count(predicate) + if (operation is IInvocationOperation countInvocation && + countInvocation.TargetMethod.Name == "Count" && + countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + countInvocation.Arguments.Length == 2) + { + collectionExpression = countInvocation.Arguments[0].Value.Syntax; + predicateExpression = countInvocation.Arguments[1].Value.Syntax; + countOperation = operation; + return LinqPredicateCheckStatus.Count; + } + + // Check for enumerable.Where(predicate).Any() + if (operation is IInvocationOperation whereAnyInvocation && + whereAnyInvocation.TargetMethod.Name == "Any" && + whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereAnyInvocation.Arguments.Length == 1 && + whereAnyInvocation.Arguments[0].Value is IInvocationOperation whereInvocation && + whereInvocation.TargetMethod.Name == "Where" && + whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation.Arguments.Length == 2) + { + collectionExpression = whereInvocation.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.WhereAny; + } + + // Check for enumerable.Where(predicate).Count() + if (operation is IInvocationOperation whereCountInvocation && + whereCountInvocation.TargetMethod.Name == "Count" && + whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereCountInvocation.Arguments.Length == 1 && + whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 && + whereInvocation2.TargetMethod.Name == "Where" && + whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation2.Arguments.Length == 2) + { + collectionExpression = whereInvocation2.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation2.Arguments[1].Value.Syntax; + countOperation = operation; + return LinqPredicateCheckStatus.WhereCount; + } + + // Check for enumerable.Where(predicate).Single() + if (operation is IInvocationOperation whereSingleInvocation && + whereSingleInvocation.TargetMethod.Name == "Single" && + whereSingleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereSingleInvocation.Arguments.Length == 1 && + whereSingleInvocation.Arguments[0].Value is IInvocationOperation whereInvocation3 && + whereInvocation3.TargetMethod.Name == "Where" && + whereInvocation3.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation3.Arguments.Length == 2) + { + collectionExpression = whereInvocation3.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation3.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.WhereSingle; + } + + // Check for enumerable.Where(predicate).SingleOrDefault() + if (operation is IInvocationOperation whereSingleOrDefaultInvocation && + whereSingleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" && + whereSingleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereSingleOrDefaultInvocation.Arguments.Length == 1 && + whereSingleOrDefaultInvocation.Arguments[0].Value is IInvocationOperation whereInvocation4 && + whereInvocation4.TargetMethod.Name == "Where" && + whereInvocation4.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation4.Arguments.Length == 2) + { + collectionExpression = whereInvocation4.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation4.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.WhereSingleOrDefault; + } + + // Check for enumerable.Single(predicate) + if (operation is IInvocationOperation singleInvocation && + singleInvocation.TargetMethod.Name == "Single" && + singleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + if (singleInvocation.Arguments.Length == 2) + { + // Extension method with predicate + collectionExpression = singleInvocation.Arguments[0].Value.Syntax; + predicateExpression = singleInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.Single; + } + else if (singleInvocation.Arguments.Length == 1) + { + // Instance method or extension without predicate + collectionExpression = singleInvocation.Instance?.Syntax ?? singleInvocation.Arguments[0].Value.Syntax; + predicateExpression = null; + return LinqPredicateCheckStatus.Single; + } + } + + // Check for enumerable.SingleOrDefault(predicate) + if (operation is IInvocationOperation singleOrDefaultInvocation && + singleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" && + singleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + if (singleOrDefaultInvocation.Arguments.Length == 2) + { + // Extension method with predicate + collectionExpression = singleOrDefaultInvocation.Arguments[0].Value.Syntax; + predicateExpression = singleOrDefaultInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.SingleOrDefault; + } + else if (singleOrDefaultInvocation.Arguments.Length == 1) + { + // Instance method or extension without predicate + collectionExpression = singleOrDefaultInvocation.Instance?.Syntax ?? singleOrDefaultInvocation.Arguments[0].Value.Syntax; + predicateExpression = null; + return LinqPredicateCheckStatus.SingleOrDefault; + } + } + + return LinqPredicateCheckStatus.Unknown; + } + private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext context, IOperation conditionArgument, bool isTrueInvocation, INamedTypeSymbol objectTypeSymbol) { RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation."); @@ -555,6 +758,36 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co return; } + // Check for LINQ predicate patterns that suggest Contains/DoesNotContain + LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck( + conditionArgument, + out SyntaxNode? linqCollectionExpr, + out SyntaxNode? predicateExpr, + out _); + + if (linqStatus != LinqPredicateCheckStatus.Unknown && linqCollectionExpr != null && predicateExpr != null) + { + // For Any() and Where().Any() patterns + if (linqStatus is LinqPredicateCheckStatus.Any or LinqPredicateCheckStatus.WhereAny) + { + string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddArgument); + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + conditionArgument.Syntax.GetLocation(), + predicateExpr.GetLocation(), + linqCollectionExpr.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + isTrueInvocation ? "IsTrue" : "IsFalse")); + return; + } + } + // Check for string method patterns: myString.StartsWith/EndsWith/Contains(...) StringMethodCheckStatus stringMethodStatus = RecognizeStringMethodCheck(conditionArgument, out SyntaxNode? stringExpr, out SyntaxNode? substringExpr); if (stringMethodStatus != StringMethodCheckStatus.Unknown) @@ -624,6 +857,54 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co return; } + // Special-case: enumerable.Count(predicate) > 0 → Assert.Contains(predicate, enumerable) + if (conditionArgument is IBinaryOperation binaryOp && + binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan) + { + if (binaryOp.LeftOperand is IInvocationOperation countInvocation && + binaryOp.RightOperand.ConstantValue.HasValue && + binaryOp.RightOperand.ConstantValue.Value is int intValue && + intValue == 0 && + countInvocation.TargetMethod.Name == "Count") + { + SyntaxNode? countCollectionExpr = null; + SyntaxNode? countPredicateExpr = null; + + if (countInvocation.Instance != null && countInvocation.Arguments.Length == 1) + { + countCollectionExpr = countInvocation.Instance.Syntax; + countPredicateExpr = countInvocation.Arguments[0].Value.Syntax; + } + else if (countInvocation.Instance == null && countInvocation.Arguments.Length == 2) + { + countCollectionExpr = countInvocation.Arguments[0].Value.Syntax; + countPredicateExpr = countInvocation.Arguments[1].Value.Syntax; + } + + if (countCollectionExpr != null && countPredicateExpr != null) + { + string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddArgument); + + context.ReportDiagnostic( + context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + conditionArgument.Syntax.GetLocation(), + countPredicateExpr.GetLocation(), + countCollectionExpr.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + isTrueInvocation ? "IsTrue" : "IsFalse")); + + return; + } + } + } + // Check for comparison patterns: a > b, a >= b, a < b, a <= b ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr); if (comparisonStatus != ComparisonCheckStatus.Unknown) @@ -722,6 +1003,40 @@ private static void AnalyzeAreEqualOrAreNotEqualInvocation(OperationAnalysisCont { if (TryGetSecondArgumentValue((IInvocationOperation)context.Operation, out IOperation? actualArgumentValue)) { + // Check for LINQ predicate patterns that suggest ContainsSingle + LinqPredicateCheckStatus linqStatus2 = RecognizeLinqPredicateCheck( + actualArgumentValue!, + out SyntaxNode? linqCollectionExpr2, + out SyntaxNode? predicateExpr2, + out _); + + if (isAreEqualInvocation && + linqStatus2 is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount && + linqCollectionExpr2 != null && + predicateExpr2 != null && + expectedArgument.ConstantValue.HasValue && + expectedArgument.ConstantValue.Value is int expectedCountValue && + expectedCountValue == 1) + { + // We have Assert.AreEqual(1, enumerable.Count(predicate)) + // We want Assert.ContainsSingle(predicate, enumerable) + string properAssertMethod = "ContainsSingle"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddArgument); + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + actualArgumentValue.Syntax.GetLocation(), + predicateExpr2.GetLocation(), + linqCollectionExpr2.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + "AreEqual")); + return; + } + // Check if we're comparing a count/length property CountCheckStatus countStatus = RecognizeCountCheck( expectedArgument, diff --git a/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs b/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs index 02a68522af..785ae5d304 100644 --- a/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs +++ b/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs @@ -84,7 +84,8 @@ private static void ValidateOutputIsNotMixed(IEnumerable testResults Assert.Contains(methodName, message.Text); Assert.Contains("TestInitialize", message.Text); Assert.Contains("TestCleanup", message.Text); - Assert.IsFalse(shouldNotContain.Any(message.Text.Contains)); + // Assert.IsFalse(shouldNotContain.Any(message.Text.Contains)); + Assert.DoesNotContain(message.Text.Contains, shouldNotContain); } private static void ValidateInitializeAndCleanup(IEnumerable testResults, Func messageFilter) diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index deb75869dd..1d2c7d2e2e 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -2908,4 +2908,566 @@ await VerifyCS.VerifyCodeFixAsync( } #endregion + + #region Predicate Pattern Tests + [TestMethod] + public async Task WhenUsingIsTrueAnyWithPredicate_SuggestsContains() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsTrue(enumerable.Any(x => x == 1))|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsTrueWhereAnyWithPredicate_SuggestsContains() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsTrue(enumerable.Where(x => x == 1).Any())|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsFalseWhereAnyWithPredicate_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsFalse(enumerable.Where(x => x == 1).Any())|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsFalseWithAny_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsFalse(enumerable.Any(x => x == 1))|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsFalseWithWhereAny_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsFalse(enumerable.Where(x => x == 1).Any())|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsTrueCountGreaterThanZero_SuggestsContains() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsTrue(enumerable.Count(x => x == 1) > 0)|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } + } + """; + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsFalseCountGreaterThanZero_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsFalse(enumerable.Count(x => x == 1) > 0)|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + """; + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNotNullSingleOrDefaultWithPredicate_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.SingleOrDefault(x => x == 1))|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNotNullSingleWithPredicate_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.Single(x => x == 1))|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNotNullWhereSingleOrDefault_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.Where(x => x == 1).SingleOrDefault())|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNotNullWhereSingle_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.Where(x => x == 1).Single())|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNullSingleOrDefaultWithPredicate_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNull(enumerable.SingleOrDefault(x => x == 1))|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNullWhereSingleOrDefault_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNull(enumerable.Where(x => x == 1).SingleOrDefault())|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsNull"), + fixedCode); + } + + #endregion }