Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,19 @@ private enum CountCheckStatus
HasCount,
}

private enum LinqPredicateCheckStatus
{
Unknown,
Any,
Count,
WhereAny,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding SingleAny and SingleCount for the .Single() Linq extension method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this just to add it as part of the enums because, there is no use of it for now in the codebase except if we have an issue that wants the analyzer to check for predicate that uses Single Linq

@Youssef1313 thoughts on this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant full support of .Single() in this analyzer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh okay.. please can you provide sample code of how it can be used and the expected suggestion code we want analyzer to give..

just to make sure we are on the same page. thank you.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see a few cases, but I'm not sure if they make sense. In all those cases the thrown exception would change at least in one of the possible failure cases. (And those old assertions are not good to start with.)
That's also why I didn't include those when I initially wrote the issue.

// Expected analyzer suggestion:
//Assert.ContainsSingle(x => x == 1, _enumerable);

// Questionable:
Assert.IsNotNull(_enumerable.Where(x => x == 1).SingleOrDefault());
Assert.IsNotNull(_enumerable.SingleOrDefault(x => x == 1));

// Even more questionable:
Assert.IsNotNull(_enumerable.Where(x => x == 1).Single());
Assert.IsNotNull(_enumerable.Single(x => x == 1));
// Expected analyzer suggestion:
//Assert.DoesNotContain(x => x == 1, _enumerable);

// Questionable:
Assert.IsNull(_enumerable.Where(x => x == 1).SingleOrDefault());
Assert.IsNull(_enumerable.SingleOrDefault(x => x == 1));

Copy link
Contributor Author

@AtolagbeMuiz AtolagbeMuiz Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the pull request is now updated with the requested changes.. analyzer has been updated for Single and SingleOrDefault @stan-sz @cremor

WhereCount,
Single,
SingleOrDefault,
WhereSingle,
WhereSingleOrDefault,
}

internal const string ProperAssertMethodNameKey = nameof(ProperAssertMethodNameKey);

/// <summary>
Expand Down Expand Up @@ -268,6 +281,56 @@ private static void AnalyzeInvocationOperation(OperationAnalysisContext context,
case "AreNotEqual":
AnalyzeAreEqualOrAreNotEqualInvocation(context, firstArgument, isAreEqualInvocation: false, objectTypeSymbol);
break;
case "IsNull":
AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: true);
break;

case "IsNotNull":
AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: false);
break;
}
}

private static void AnalyzeIsNullOrIsNotNullInvocation(OperationAnalysisContext context, IOperation argument, bool isNullCheck)
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");

// Check for Single/SingleOrDefault patterns
LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck(
argument,
out SyntaxNode? linqCollectionExpr,
out SyntaxNode? predicateExpr,
out _);

if (linqStatus is LinqPredicateCheckStatus.Single or
LinqPredicateCheckStatus.SingleOrDefault or
LinqPredicateCheckStatus.WhereSingle or
LinqPredicateCheckStatus.WhereSingleOrDefault &&
linqCollectionExpr != null)
{
// For Assert.IsNotNull(enumerable.Single[OrDefault](...)) -> Assert.ContainsSingle
// For Assert.IsNull(enumerable.Single[OrDefault](...)) -> Assert.DoesNotContain
string properAssertMethod = isNullCheck ? "DoesNotContain" : "ContainsSingle";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, predicateExpr != null ? CodeFixModeAddArgument : CodeFixModeSimple);

ImmutableArray<Location> additionalLocations = predicateExpr != null
? ImmutableArray.Create(
argument.Syntax.GetLocation(),
predicateExpr.GetLocation(),
linqCollectionExpr.GetLocation())
: ImmutableArray.Create(
argument.Syntax.GetLocation(),
linqCollectionExpr.GetLocation());

context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: additionalLocations,
properties: properties.ToImmutable(),
properAssertMethod,
isNullCheck ? "IsNull" : "IsNotNull"));
}
}

Expand Down Expand Up @@ -519,6 +582,146 @@ private static ComparisonCheckStatus RecognizeComparisonCheck(
return ComparisonCheckStatus.Unknown;
}

private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck(
IOperation operation,
out SyntaxNode? collectionExpression,
out SyntaxNode? predicateExpression,
out IOperation? countOperation)
{
collectionExpression = null;
predicateExpression = null;
countOperation = null;

// Check for enumerable.Any(predicate)
// Extension methods appear as: Instance=null, Arguments[0]=collection, Arguments[1]=predicate
if (operation is IInvocationOperation anyInvocation &&
anyInvocation.TargetMethod.Name == "Any" &&
anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
anyInvocation.Arguments.Length == 2)
{
collectionExpression = anyInvocation.Arguments[0].Value.Syntax;
predicateExpression = anyInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.Any;
}

// Check for enumerable.Count(predicate)
if (operation is IInvocationOperation countInvocation &&
countInvocation.TargetMethod.Name == "Count" &&
countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
countInvocation.Arguments.Length == 2)
{
collectionExpression = countInvocation.Arguments[0].Value.Syntax;
predicateExpression = countInvocation.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.Count;
}

// Check for enumerable.Where(predicate).Any()
if (operation is IInvocationOperation whereAnyInvocation &&
whereAnyInvocation.TargetMethod.Name == "Any" &&
whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereAnyInvocation.Arguments.Length == 1 &&
whereAnyInvocation.Arguments[0].Value is IInvocationOperation whereInvocation &&
whereInvocation.TargetMethod.Name == "Where" &&
whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation.Arguments.Length == 2)
{
collectionExpression = whereInvocation.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereAny;
}

// Check for enumerable.Where(predicate).Count()
if (operation is IInvocationOperation whereCountInvocation &&
whereCountInvocation.TargetMethod.Name == "Count" &&
whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereCountInvocation.Arguments.Length == 1 &&
whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 &&
whereInvocation2.TargetMethod.Name == "Where" &&
whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation2.Arguments.Length == 2)
{
collectionExpression = whereInvocation2.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation2.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.WhereCount;
}

// Check for enumerable.Where(predicate).Single()
if (operation is IInvocationOperation whereSingleInvocation &&
whereSingleInvocation.TargetMethod.Name == "Single" &&
whereSingleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereSingleInvocation.Arguments.Length == 1 &&
whereSingleInvocation.Arguments[0].Value is IInvocationOperation whereInvocation3 &&
whereInvocation3.TargetMethod.Name == "Where" &&
whereInvocation3.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation3.Arguments.Length == 2)
{
collectionExpression = whereInvocation3.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation3.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereSingle;
}

// Check for enumerable.Where(predicate).SingleOrDefault()
if (operation is IInvocationOperation whereSingleOrDefaultInvocation &&
whereSingleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" &&
whereSingleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereSingleOrDefaultInvocation.Arguments.Length == 1 &&
whereSingleOrDefaultInvocation.Arguments[0].Value is IInvocationOperation whereInvocation4 &&
whereInvocation4.TargetMethod.Name == "Where" &&
whereInvocation4.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation4.Arguments.Length == 2)
{
collectionExpression = whereInvocation4.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation4.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereSingleOrDefault;
}

// Check for enumerable.Single(predicate)
if (operation is IInvocationOperation singleInvocation &&
singleInvocation.TargetMethod.Name == "Single" &&
singleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable")
{
if (singleInvocation.Arguments.Length == 2)
{
// Extension method with predicate
collectionExpression = singleInvocation.Arguments[0].Value.Syntax;
predicateExpression = singleInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.Single;
}
else if (singleInvocation.Arguments.Length == 1)
{
// Instance method or extension without predicate
collectionExpression = singleInvocation.Instance?.Syntax ?? singleInvocation.Arguments[0].Value.Syntax;
predicateExpression = null;
return LinqPredicateCheckStatus.Single;
}
}

// Check for enumerable.SingleOrDefault(predicate)
if (operation is IInvocationOperation singleOrDefaultInvocation &&
singleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" &&
singleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable")
{
if (singleOrDefaultInvocation.Arguments.Length == 2)
{
// Extension method with predicate
collectionExpression = singleOrDefaultInvocation.Arguments[0].Value.Syntax;
predicateExpression = singleOrDefaultInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.SingleOrDefault;
}
else if (singleOrDefaultInvocation.Arguments.Length == 1)
{
// Instance method or extension without predicate
collectionExpression = singleOrDefaultInvocation.Instance?.Syntax ?? singleOrDefaultInvocation.Arguments[0].Value.Syntax;
predicateExpression = null;
return LinqPredicateCheckStatus.SingleOrDefault;
}
}

return LinqPredicateCheckStatus.Unknown;
}

private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext context, IOperation conditionArgument, bool isTrueInvocation, INamedTypeSymbol objectTypeSymbol)
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");
Expand Down Expand Up @@ -555,6 +758,36 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Check for LINQ predicate patterns that suggest Contains/DoesNotContain
LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck(
conditionArgument,
out SyntaxNode? linqCollectionExpr,
out SyntaxNode? predicateExpr,
out _);

if (linqStatus != LinqPredicateCheckStatus.Unknown && linqCollectionExpr != null && predicateExpr != null)
{
// For Any() and Where().Any() patterns
if (linqStatus is LinqPredicateCheckStatus.Any or LinqPredicateCheckStatus.WhereAny)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
predicateExpr.GetLocation(),
linqCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));
return;
}
}

// Check for string method patterns: myString.StartsWith/EndsWith/Contains(...)
StringMethodCheckStatus stringMethodStatus = RecognizeStringMethodCheck(conditionArgument, out SyntaxNode? stringExpr, out SyntaxNode? substringExpr);
if (stringMethodStatus != StringMethodCheckStatus.Unknown)
Expand Down Expand Up @@ -624,6 +857,54 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Special-case: enumerable.Count(predicate) > 0 → Assert.Contains(predicate, enumerable)
if (conditionArgument is IBinaryOperation binaryOp &&
binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan)
{
if (binaryOp.LeftOperand is IInvocationOperation countInvocation &&
binaryOp.RightOperand.ConstantValue.HasValue &&
binaryOp.RightOperand.ConstantValue.Value is int intValue &&
intValue == 0 &&
countInvocation.TargetMethod.Name == "Count")
{
SyntaxNode? countCollectionExpr = null;
SyntaxNode? countPredicateExpr = null;

if (countInvocation.Instance != null && countInvocation.Arguments.Length == 1)
{
countCollectionExpr = countInvocation.Instance.Syntax;
countPredicateExpr = countInvocation.Arguments[0].Value.Syntax;
}
else if (countInvocation.Instance == null && countInvocation.Arguments.Length == 2)
{
countCollectionExpr = countInvocation.Arguments[0].Value.Syntax;
countPredicateExpr = countInvocation.Arguments[1].Value.Syntax;
}

if (countCollectionExpr != null && countPredicateExpr != null)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);

context.ReportDiagnostic(
context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
countPredicateExpr.GetLocation(),
countCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));

return;
}
}
}

// Check for comparison patterns: a > b, a >= b, a < b, a <= b
ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr);
if (comparisonStatus != ComparisonCheckStatus.Unknown)
Expand Down Expand Up @@ -722,6 +1003,40 @@ private static void AnalyzeAreEqualOrAreNotEqualInvocation(OperationAnalysisCont
{
if (TryGetSecondArgumentValue((IInvocationOperation)context.Operation, out IOperation? actualArgumentValue))
{
// Check for LINQ predicate patterns that suggest ContainsSingle
LinqPredicateCheckStatus linqStatus2 = RecognizeLinqPredicateCheck(
actualArgumentValue!,
out SyntaxNode? linqCollectionExpr2,
out SyntaxNode? predicateExpr2,
out _);

if (isAreEqualInvocation &&
linqStatus2 is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount &&
linqCollectionExpr2 != null &&
predicateExpr2 != null &&
expectedArgument.ConstantValue.HasValue &&
expectedArgument.ConstantValue.Value is int expectedCountValue &&
expectedCountValue == 1)
{
// We have Assert.AreEqual(1, enumerable.Count(predicate))
// We want Assert.ContainsSingle(predicate, enumerable)
string properAssertMethod = "ContainsSingle";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
actualArgumentValue.Syntax.GetLocation(),
predicateExpr2.GetLocation(),
linqCollectionExpr2.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
"AreEqual"));
return;
}

// Check if we're comparing a count/length property
CountCheckStatus countStatus = RecognizeCountCheck(
expectedArgument,
Expand Down
3 changes: 2 additions & 1 deletion test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ private static void ValidateOutputIsNotMixed(IEnumerable<TestResult> testResults
Assert.Contains(methodName, message.Text);
Assert.Contains("TestInitialize", message.Text);
Assert.Contains("TestCleanup", message.Text);
Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
// Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
Assert.DoesNotContain(message.Text.Contains, shouldNotContain);
}

private static void ValidateInitializeAndCleanup(IEnumerable<TestResult> testResults, Func<TestResultMessage, bool> messageFilter)
Expand Down
Loading