Skip to content

Commit e4cfba0

Browse files
committed
feat(unnecessary_fold): lint on folds with Add::add/Mul::mul
1 parent 4ce7628 commit e4cfba0

File tree

4 files changed

+385
-92
lines changed

4 files changed

+385
-92
lines changed

clippy_lints/src/methods/unnecessary_fold.rs

Lines changed: 102 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,56 @@
11
use clippy_utils::diagnostics::span_lint_and_sugg;
2-
use clippy_utils::res::{MaybeDef, MaybeResPath, MaybeTypeckRes};
2+
use clippy_utils::res::{MaybeDef, MaybeQPath, MaybeResPath, MaybeTypeckRes};
33
use clippy_utils::source::snippet_with_applicability;
4-
use clippy_utils::{peel_blocks, strip_pat_refs};
4+
use clippy_utils::{DefinedTy, ExprUseNode, expr_use_ctxt, peel_blocks, strip_pat_refs};
55
use rustc_ast::ast;
66
use rustc_data_structures::packed::Pu128;
77
use rustc_errors::Applicability;
88
use rustc_hir as hir;
99
use rustc_hir::PatKind;
10+
use rustc_hir::def::{DefKind, Res};
1011
use rustc_lint::LateContext;
11-
use rustc_middle::ty;
12-
use rustc_span::{Span, sym};
12+
use rustc_middle::ty::{self, Ty};
13+
use rustc_span::{Span, Symbol, sym};
1314

1415
use super::UNNECESSARY_FOLD;
1516

1617
/// Do we need to suggest turbofish when suggesting a replacement method?
1718
/// Changing `fold` to `sum` needs it sometimes when the return type can't be
1819
/// inferred. This checks for some common cases where it can be safely omitted
19-
fn needs_turbofish(cx: &LateContext<'_>, expr: &hir::Expr<'_>) -> bool {
20-
let parent = cx.tcx.parent_hir_node(expr.hir_id);
20+
fn needs_turbofish<'tcx>(cx: &LateContext<'tcx>, expr: &hir::Expr<'tcx>) -> bool {
21+
// whether a type contains free generics (`T`) / or opaques (`impl Trait`)
22+
let is_generic = |ty: Ty<'_>| {
23+
ty.walk().any(|generic| {
24+
generic
25+
.as_type()
26+
.is_some_and(|ty| matches!(ty.kind(), ty::Param(_) | ty::Alias(ty::AliasTyKind::Opaque, _)))
27+
})
28+
};
2129

22-
// some common cases where turbofish isn't needed:
23-
// - assigned to a local variable with a type annotation
24-
if let hir::Node::LetStmt(local) = parent
25-
&& local.ty.is_some()
30+
let use_cx = expr_use_ctxt(cx, expr);
31+
if use_cx.same_ctxt
32+
&& let use_node = use_cx.use_node(cx)
33+
&& let Some(ty) = use_node.defined_ty(cx)
2634
{
27-
return false;
28-
}
35+
// some common cases where turbofish isn't needed:
36+
match (use_node, ty) {
37+
// - assigned to a local variable with a type annotation
38+
(ExprUseNode::LetStmt(_), _) => return false,
2939

30-
// - part of a function call argument, can be inferred from the function signature (provided that
31-
// the parameter is not a generic type parameter)
32-
if let hir::Node::Expr(parent_expr) = parent
33-
&& let hir::ExprKind::Call(recv, args) = parent_expr.kind
34-
&& let hir::ExprKind::Path(ref qpath) = recv.kind
35-
&& let Some(fn_def_id) = cx.qpath_res(qpath, recv.hir_id).opt_def_id()
36-
&& let fn_sig = cx.tcx.fn_sig(fn_def_id).skip_binder().skip_binder()
37-
&& let Some(arg_pos) = args.iter().position(|arg| arg.hir_id == expr.hir_id)
38-
&& let Some(ty) = fn_sig.inputs().get(arg_pos)
39-
&& !matches!(ty.kind(), ty::Param(_))
40-
{
41-
return false;
40+
// - part of a function call argument, can be inferred from the function signature (provided that the
41+
// parameter is not a generic type parameter)
42+
(ExprUseNode::FnArg(..), DefinedTy::Mir { ty: arg_ty, .. }) if !is_generic(arg_ty.skip_binder()) => {
43+
return false;
44+
},
45+
46+
// - the final expression in the body of a function with a simple return type
47+
(ExprUseNode::Return(_), DefinedTy::Mir { ty: fn_return_ty, .. })
48+
if !is_generic(fn_return_ty.skip_binder()) =>
49+
{
50+
return false;
51+
},
52+
_ => {},
53+
}
4254
}
4355

4456
// if it's neither of those, stay on the safe side and suggest turbofish,
@@ -60,7 +72,7 @@ fn check_fold_with_op(
6072
fold_span: Span,
6173
op: hir::BinOpKind,
6274
replacement: Replacement,
63-
) {
75+
) -> bool {
6476
if let hir::ExprKind::Closure(&hir::Closure { body, .. }) = acc.kind
6577
// Extract the body of the closure passed to fold
6678
&& let closure_body = cx.tcx.hir_body(body)
@@ -93,7 +105,7 @@ fn check_fold_with_op(
93105
r = snippet_with_applicability(cx, right_expr.span, "EXPR", &mut applicability),
94106
)
95107
} else {
96-
format!("{method}{turbofish}()", method = replacement.method_name,)
108+
format!("{method}{turbofish}()", method = replacement.method_name)
97109
};
98110

99111
span_lint_and_sugg(
@@ -105,12 +117,47 @@ fn check_fold_with_op(
105117
sugg,
106118
applicability,
107119
);
120+
return true;
108121
}
122+
false
109123
}
110124

111-
pub(super) fn check(
125+
fn check_fold_with_method(
112126
cx: &LateContext<'_>,
113127
expr: &hir::Expr<'_>,
128+
acc: &hir::Expr<'_>,
129+
fold_span: Span,
130+
method: Symbol,
131+
replacement: Replacement,
132+
) {
133+
// Extract the name of the function passed to `fold`
134+
if let Res::Def(DefKind::AssocFn, fn_did) = acc.res_if_named(cx, method)
135+
// Check if the function belongs to the operator
136+
&& cx.tcx.is_diagnostic_item(method, fn_did)
137+
{
138+
let applicability = Applicability::MachineApplicable;
139+
140+
let turbofish = if replacement.has_generic_return {
141+
format!("::<{}>", cx.typeck_results().expr_ty(expr))
142+
} else {
143+
String::new()
144+
};
145+
146+
span_lint_and_sugg(
147+
cx,
148+
UNNECESSARY_FOLD,
149+
fold_span.with_hi(expr.span.hi()),
150+
"this `.fold` can be written more succinctly using another method",
151+
"try",
152+
format!("{method}{turbofish}()", method = replacement.method_name),
153+
applicability,
154+
);
155+
}
156+
}
157+
158+
pub(super) fn check<'tcx>(
159+
cx: &LateContext<'tcx>,
160+
expr: &hir::Expr<'tcx>,
114161
init: &hir::Expr<'_>,
115162
acc: &hir::Expr<'_>,
116163
fold_span: Span,
@@ -124,60 +171,40 @@ pub(super) fn check(
124171
if let hir::ExprKind::Lit(lit) = init.kind {
125172
match lit.node {
126173
ast::LitKind::Bool(false) => {
127-
check_fold_with_op(
128-
cx,
129-
expr,
130-
acc,
131-
fold_span,
132-
hir::BinOpKind::Or,
133-
Replacement {
134-
method_name: "any",
135-
has_args: true,
136-
has_generic_return: false,
137-
},
138-
);
174+
let replacement = Replacement {
175+
method_name: "any",
176+
has_args: true,
177+
has_generic_return: false,
178+
};
179+
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Or, replacement);
139180
},
140181
ast::LitKind::Bool(true) => {
141-
check_fold_with_op(
142-
cx,
143-
expr,
144-
acc,
145-
fold_span,
146-
hir::BinOpKind::And,
147-
Replacement {
148-
method_name: "all",
149-
has_args: true,
150-
has_generic_return: false,
151-
},
152-
);
182+
let replacement = Replacement {
183+
method_name: "all",
184+
has_args: true,
185+
has_generic_return: false,
186+
};
187+
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::And, replacement);
153188
},
154189
ast::LitKind::Int(Pu128(0), _) => {
155-
check_fold_with_op(
156-
cx,
157-
expr,
158-
acc,
159-
fold_span,
160-
hir::BinOpKind::Add,
161-
Replacement {
162-
method_name: "sum",
163-
has_args: false,
164-
has_generic_return: needs_turbofish(cx, expr),
165-
},
166-
);
190+
let replacement = Replacement {
191+
method_name: "sum",
192+
has_args: false,
193+
has_generic_return: needs_turbofish(cx, expr),
194+
};
195+
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Add, replacement) {
196+
check_fold_with_method(cx, expr, acc, fold_span, sym::add, replacement);
197+
}
167198
},
168199
ast::LitKind::Int(Pu128(1), _) => {
169-
check_fold_with_op(
170-
cx,
171-
expr,
172-
acc,
173-
fold_span,
174-
hir::BinOpKind::Mul,
175-
Replacement {
176-
method_name: "product",
177-
has_args: false,
178-
has_generic_return: needs_turbofish(cx, expr),
179-
},
180-
);
200+
let replacement = Replacement {
201+
method_name: "product",
202+
has_args: false,
203+
has_generic_return: needs_turbofish(cx, expr),
204+
};
205+
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Mul, replacement) {
206+
check_fold_with_method(cx, expr, acc, fold_span, sym::mul, replacement);
207+
}
181208
},
182209
_ => (),
183210
}

tests/ui/unnecessary_fold.fixed

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,35 @@ fn is_any(acc: bool, x: usize) -> bool {
66

77
/// Calls which should trigger the `UNNECESSARY_FOLD` lint
88
fn unnecessary_fold() {
9+
use std::ops::{Add, Mul};
10+
911
// Can be replaced by .any
1012
let _ = (0..3).any(|x| x > 2);
1113
//~^ unnecessary_fold
14+
1215
// Can be replaced by .any (checking suggestion)
1316
let _ = (0..3).fold(false, is_any);
1417
//~^ redundant_closure
18+
1519
// Can be replaced by .all
1620
let _ = (0..3).all(|x| x > 2);
1721
//~^ unnecessary_fold
22+
1823
// Can be replaced by .sum
1924
let _: i32 = (0..3).sum();
2025
//~^ unnecessary_fold
26+
let _: i32 = (0..3).sum();
27+
//~^ unnecessary_fold
28+
let _: i32 = (0..3).sum();
29+
//~^ unnecessary_fold
30+
2131
// Can be replaced by .product
2232
let _: i32 = (0..3).product();
2333
//~^ unnecessary_fold
34+
let _: i32 = (0..3).product();
35+
//~^ unnecessary_fold
36+
let _: i32 = (0..3).product();
37+
//~^ unnecessary_fold
2438
}
2539

2640
/// Should trigger the `UNNECESSARY_FOLD` lint, with an error span including exactly `.fold(...)`
@@ -37,6 +51,43 @@ fn unnecessary_fold_should_ignore() {
3751
let _ = (0..3).fold(0, |acc, x| acc * x);
3852
let _ = (0..3).fold(0, |acc, x| 1 + acc + x);
3953

54+
struct Adder;
55+
impl Adder {
56+
fn add(lhs: i32, rhs: i32) -> i32 {
57+
unimplemented!()
58+
}
59+
fn mul(lhs: i32, rhs: i32) -> i32 {
60+
unimplemented!()
61+
}
62+
}
63+
// `add`/`mul` are inherent methods
64+
let _: i32 = (0..3).fold(0, Adder::add);
65+
let _: i32 = (0..3).fold(1, Adder::mul);
66+
67+
trait FakeAdd<Rhs = Self> {
68+
type Output;
69+
fn add(self, other: Rhs) -> Self::Output;
70+
}
71+
impl FakeAdd for i32 {
72+
type Output = Self;
73+
fn add(self, other: i32) -> Self::Output {
74+
self + other
75+
}
76+
}
77+
trait FakeMul<Rhs = Self> {
78+
type Output;
79+
fn mul(self, other: Rhs) -> Self::Output;
80+
}
81+
impl FakeMul for i32 {
82+
type Output = Self;
83+
fn mul(self, other: i32) -> Self::Output {
84+
self * other
85+
}
86+
}
87+
// `add`/`mul` come from an unrelated trait
88+
let _: i32 = (0..3).fold(0, FakeAdd::add);
89+
let _: i32 = (0..3).fold(1, FakeMul::mul);
90+
4091
// We only match against an accumulator on the left
4192
// hand side. We could lint for .sum and .product when
4293
// it's on the right, but don't for now (and this wouldn't
@@ -63,6 +114,7 @@ fn unnecessary_fold_over_multiple_lines() {
63114
fn issue10000() {
64115
use std::collections::HashMap;
65116
use std::hash::BuildHasher;
117+
use std::ops::{Add, Mul};
66118

67119
fn anything<T>(_: T) {}
68120
fn num(_: i32) {}
@@ -74,23 +126,56 @@ fn issue10000() {
74126
// more cases:
75127
let _ = map.values().sum::<i32>();
76128
//~^ unnecessary_fold
129+
let _ = map.values().sum::<i32>();
130+
//~^ unnecessary_fold
77131
let _ = map.values().product::<i32>();
78132
//~^ unnecessary_fold
133+
let _ = map.values().product::<i32>();
134+
//~^ unnecessary_fold
135+
let _: i32 = map.values().sum();
136+
//~^ unnecessary_fold
79137
let _: i32 = map.values().sum();
80138
//~^ unnecessary_fold
81139
let _: i32 = map.values().product();
82140
//~^ unnecessary_fold
141+
let _: i32 = map.values().product();
142+
//~^ unnecessary_fold
83143
anything(map.values().sum::<i32>());
84144
//~^ unnecessary_fold
145+
anything(map.values().sum::<i32>());
146+
//~^ unnecessary_fold
147+
anything(map.values().product::<i32>());
148+
//~^ unnecessary_fold
85149
anything(map.values().product::<i32>());
86150
//~^ unnecessary_fold
87151
num(map.values().sum());
88152
//~^ unnecessary_fold
153+
num(map.values().sum());
154+
//~^ unnecessary_fold
155+
num(map.values().product());
156+
//~^ unnecessary_fold
89157
num(map.values().product());
90158
//~^ unnecessary_fold
91159
}
92160

93161
smoketest_map(HashMap::new());
162+
163+
fn add_turbofish_not_necessary() -> i32 {
164+
(0..3).sum()
165+
//~^ unnecessary_fold
166+
}
167+
fn mul_turbofish_not_necessary() -> i32 {
168+
(0..3).product()
169+
//~^ unnecessary_fold
170+
}
171+
fn add_turbofish_necessary() -> impl Add {
172+
(0..3).sum::<i32>()
173+
//~^ unnecessary_fold
174+
}
175+
fn mul_turbofish_necessary() -> impl Mul {
176+
(0..3).product::<i32>()
177+
//~^ unnecessary_fold
178+
}
94179
}
95180

96181
fn main() {}

0 commit comments

Comments
 (0)