Skip to content

Commit ce5a4a1

Browse files
committed
feat(unnecessary_fold): lint on folds with Add::add/Mul::mul
1 parent 4ce7628 commit ce5a4a1

File tree

4 files changed

+381
-92
lines changed

4 files changed

+381
-92
lines changed

clippy_lints/src/methods/unnecessary_fold.rs

Lines changed: 98 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,52 @@
11
use clippy_utils::diagnostics::span_lint_and_sugg;
2-
use clippy_utils::res::{MaybeDef, MaybeResPath, MaybeTypeckRes};
2+
use clippy_utils::res::{MaybeDef, MaybeQPath, MaybeResPath, MaybeTypeckRes};
33
use clippy_utils::source::snippet_with_applicability;
4-
use clippy_utils::{peel_blocks, strip_pat_refs};
4+
use clippy_utils::{DefinedTy, ExprUseNode, expr_use_ctxt, peel_blocks, strip_pat_refs};
55
use rustc_ast::ast;
66
use rustc_data_structures::packed::Pu128;
77
use rustc_errors::Applicability;
88
use rustc_hir as hir;
99
use rustc_hir::PatKind;
10+
use rustc_hir::def::{DefKind, Res};
1011
use rustc_lint::LateContext;
1112
use rustc_middle::ty;
12-
use rustc_span::{Span, sym};
13+
use rustc_span::{Span, Symbol, sym};
1314

1415
use super::UNNECESSARY_FOLD;
1516

1617
/// Do we need to suggest turbofish when suggesting a replacement method?
1718
/// Changing `fold` to `sum` needs it sometimes when the return type can't be
1819
/// inferred. This checks for some common cases where it can be safely omitted
19-
fn needs_turbofish(cx: &LateContext<'_>, expr: &hir::Expr<'_>) -> bool {
20-
let parent = cx.tcx.parent_hir_node(expr.hir_id);
21-
22-
// some common cases where turbofish isn't needed:
23-
// - assigned to a local variable with a type annotation
24-
if let hir::Node::LetStmt(local) = parent
25-
&& local.ty.is_some()
20+
fn needs_turbofish<'tcx>(cx: &LateContext<'tcx>, expr: &hir::Expr<'tcx>) -> bool {
21+
let use_cx = expr_use_ctxt(cx, expr);
22+
if use_cx.same_ctxt
23+
&& let use_node = use_cx.use_node(cx)
24+
&& let Some(ty) = use_node.defined_ty(cx)
2625
{
27-
return false;
28-
}
26+
// some common cases where turbofish isn't needed:
27+
match (use_node, ty) {
28+
// - assigned to a local variable with a type annotation
29+
(ExprUseNode::LetStmt(_), _) => return false,
2930

30-
// - part of a function call argument, can be inferred from the function signature (provided that
31-
// the parameter is not a generic type parameter)
32-
if let hir::Node::Expr(parent_expr) = parent
33-
&& let hir::ExprKind::Call(recv, args) = parent_expr.kind
34-
&& let hir::ExprKind::Path(ref qpath) = recv.kind
35-
&& let Some(fn_def_id) = cx.qpath_res(qpath, recv.hir_id).opt_def_id()
36-
&& let fn_sig = cx.tcx.fn_sig(fn_def_id).skip_binder().skip_binder()
37-
&& let Some(arg_pos) = args.iter().position(|arg| arg.hir_id == expr.hir_id)
38-
&& let Some(ty) = fn_sig.inputs().get(arg_pos)
39-
&& !matches!(ty.kind(), ty::Param(_))
40-
{
41-
return false;
31+
// - part of a function call argument, can be inferred from the function signature (provided that the
32+
// parameter is not a generic type parameter)
33+
(ExprUseNode::FnArg(..), DefinedTy::Mir { ty: arg_ty, .. })
34+
if !matches!(arg_ty.skip_binder().kind(), ty::Param(_)) =>
35+
{
36+
return false;
37+
},
38+
39+
// - the final expression in the body of a function with a simple return type
40+
(ExprUseNode::Return(_), DefinedTy::Mir { ty: fn_return_ty, .. })
41+
if !fn_return_ty
42+
.skip_binder()
43+
.walk()
44+
.any(|generic| generic.as_type().is_some_and(|ty| ty.is_impl_trait())) =>
45+
{
46+
return false;
47+
},
48+
_ => {},
49+
}
4250
}
4351

4452
// if it's neither of those, stay on the safe side and suggest turbofish,
@@ -60,7 +68,7 @@ fn check_fold_with_op(
6068
fold_span: Span,
6169
op: hir::BinOpKind,
6270
replacement: Replacement,
63-
) {
71+
) -> bool {
6472
if let hir::ExprKind::Closure(&hir::Closure { body, .. }) = acc.kind
6573
// Extract the body of the closure passed to fold
6674
&& let closure_body = cx.tcx.hir_body(body)
@@ -93,7 +101,7 @@ fn check_fold_with_op(
93101
r = snippet_with_applicability(cx, right_expr.span, "EXPR", &mut applicability),
94102
)
95103
} else {
96-
format!("{method}{turbofish}()", method = replacement.method_name,)
104+
format!("{method}{turbofish}()", method = replacement.method_name)
97105
};
98106

99107
span_lint_and_sugg(
@@ -105,12 +113,47 @@ fn check_fold_with_op(
105113
sugg,
106114
applicability,
107115
);
116+
return true;
108117
}
118+
false
109119
}
110120

111-
pub(super) fn check(
121+
fn check_fold_with_method(
112122
cx: &LateContext<'_>,
113123
expr: &hir::Expr<'_>,
124+
acc: &hir::Expr<'_>,
125+
fold_span: Span,
126+
method: Symbol,
127+
replacement: Replacement,
128+
) {
129+
// Extract the name of the function passed to `fold`
130+
if let Res::Def(DefKind::AssocFn, fn_did) = acc.res_if_named(cx, method)
131+
// Check if the function belongs to the operator
132+
&& cx.tcx.is_diagnostic_item(method, fn_did)
133+
{
134+
let applicability = Applicability::MachineApplicable;
135+
136+
let turbofish = if replacement.has_generic_return {
137+
format!("::<{}>", cx.typeck_results().expr_ty(expr))
138+
} else {
139+
String::new()
140+
};
141+
142+
span_lint_and_sugg(
143+
cx,
144+
UNNECESSARY_FOLD,
145+
fold_span.with_hi(expr.span.hi()),
146+
"this `.fold` can be written more succinctly using another method",
147+
"try",
148+
format!("{method}{turbofish}()", method = replacement.method_name),
149+
applicability,
150+
);
151+
}
152+
}
153+
154+
pub(super) fn check<'tcx>(
155+
cx: &LateContext<'tcx>,
156+
expr: &hir::Expr<'tcx>,
114157
init: &hir::Expr<'_>,
115158
acc: &hir::Expr<'_>,
116159
fold_span: Span,
@@ -124,60 +167,40 @@ pub(super) fn check(
124167
if let hir::ExprKind::Lit(lit) = init.kind {
125168
match lit.node {
126169
ast::LitKind::Bool(false) => {
127-
check_fold_with_op(
128-
cx,
129-
expr,
130-
acc,
131-
fold_span,
132-
hir::BinOpKind::Or,
133-
Replacement {
134-
method_name: "any",
135-
has_args: true,
136-
has_generic_return: false,
137-
},
138-
);
170+
let replacement = Replacement {
171+
method_name: "any",
172+
has_args: true,
173+
has_generic_return: false,
174+
};
175+
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Or, replacement);
139176
},
140177
ast::LitKind::Bool(true) => {
141-
check_fold_with_op(
142-
cx,
143-
expr,
144-
acc,
145-
fold_span,
146-
hir::BinOpKind::And,
147-
Replacement {
148-
method_name: "all",
149-
has_args: true,
150-
has_generic_return: false,
151-
},
152-
);
178+
let replacement = Replacement {
179+
method_name: "all",
180+
has_args: true,
181+
has_generic_return: false,
182+
};
183+
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::And, replacement);
153184
},
154185
ast::LitKind::Int(Pu128(0), _) => {
155-
check_fold_with_op(
156-
cx,
157-
expr,
158-
acc,
159-
fold_span,
160-
hir::BinOpKind::Add,
161-
Replacement {
162-
method_name: "sum",
163-
has_args: false,
164-
has_generic_return: needs_turbofish(cx, expr),
165-
},
166-
);
186+
let replacement = Replacement {
187+
method_name: "sum",
188+
has_args: false,
189+
has_generic_return: needs_turbofish(cx, expr),
190+
};
191+
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Add, replacement) {
192+
check_fold_with_method(cx, expr, acc, fold_span, sym::add, replacement);
193+
}
167194
},
168195
ast::LitKind::Int(Pu128(1), _) => {
169-
check_fold_with_op(
170-
cx,
171-
expr,
172-
acc,
173-
fold_span,
174-
hir::BinOpKind::Mul,
175-
Replacement {
176-
method_name: "product",
177-
has_args: false,
178-
has_generic_return: needs_turbofish(cx, expr),
179-
},
180-
);
196+
let replacement = Replacement {
197+
method_name: "product",
198+
has_args: false,
199+
has_generic_return: needs_turbofish(cx, expr),
200+
};
201+
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Mul, replacement) {
202+
check_fold_with_method(cx, expr, acc, fold_span, sym::mul, replacement);
203+
}
181204
},
182205
_ => (),
183206
}

tests/ui/unnecessary_fold.fixed

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,35 @@ fn is_any(acc: bool, x: usize) -> bool {
66

77
/// Calls which should trigger the `UNNECESSARY_FOLD` lint
88
fn unnecessary_fold() {
9+
use std::ops::{Add, Mul};
10+
911
// Can be replaced by .any
1012
let _ = (0..3).any(|x| x > 2);
1113
//~^ unnecessary_fold
14+
1215
// Can be replaced by .any (checking suggestion)
1316
let _ = (0..3).fold(false, is_any);
1417
//~^ redundant_closure
18+
1519
// Can be replaced by .all
1620
let _ = (0..3).all(|x| x > 2);
1721
//~^ unnecessary_fold
22+
1823
// Can be replaced by .sum
1924
let _: i32 = (0..3).sum();
2025
//~^ unnecessary_fold
26+
let _: i32 = (0..3).sum();
27+
//~^ unnecessary_fold
28+
let _: i32 = (0..3).sum();
29+
//~^ unnecessary_fold
30+
2131
// Can be replaced by .product
2232
let _: i32 = (0..3).product();
2333
//~^ unnecessary_fold
34+
let _: i32 = (0..3).product();
35+
//~^ unnecessary_fold
36+
let _: i32 = (0..3).product();
37+
//~^ unnecessary_fold
2438
}
2539

2640
/// Should trigger the `UNNECESSARY_FOLD` lint, with an error span including exactly `.fold(...)`
@@ -37,6 +51,43 @@ fn unnecessary_fold_should_ignore() {
3751
let _ = (0..3).fold(0, |acc, x| acc * x);
3852
let _ = (0..3).fold(0, |acc, x| 1 + acc + x);
3953

54+
struct Adder;
55+
impl Adder {
56+
fn add(lhs: i32, rhs: i32) -> i32 {
57+
unimplemented!()
58+
}
59+
fn mul(lhs: i32, rhs: i32) -> i32 {
60+
unimplemented!()
61+
}
62+
}
63+
// `add`/`mul` are inherent methods
64+
let _: i32 = (0..3).fold(0, Adder::add);
65+
let _: i32 = (0..3).fold(1, Adder::mul);
66+
67+
trait FakeAdd<Rhs = Self> {
68+
type Output;
69+
fn add(self, other: Rhs) -> Self::Output;
70+
}
71+
impl FakeAdd for i32 {
72+
type Output = Self;
73+
fn add(self, other: i32) -> Self::Output {
74+
self + other
75+
}
76+
}
77+
trait FakeMul<Rhs = Self> {
78+
type Output;
79+
fn mul(self, other: Rhs) -> Self::Output;
80+
}
81+
impl FakeMul for i32 {
82+
type Output = Self;
83+
fn mul(self, other: i32) -> Self::Output {
84+
self * other
85+
}
86+
}
87+
// `add`/`mul` come from an unrelated trait
88+
let _: i32 = (0..3).fold(0, FakeAdd::add);
89+
let _: i32 = (0..3).fold(1, FakeMul::mul);
90+
4091
// We only match against an accumulator on the left
4192
// hand side. We could lint for .sum and .product when
4293
// it's on the right, but don't for now (and this wouldn't
@@ -63,6 +114,7 @@ fn unnecessary_fold_over_multiple_lines() {
63114
fn issue10000() {
64115
use std::collections::HashMap;
65116
use std::hash::BuildHasher;
117+
use std::ops::{Add, Mul};
66118

67119
fn anything<T>(_: T) {}
68120
fn num(_: i32) {}
@@ -74,23 +126,56 @@ fn issue10000() {
74126
// more cases:
75127
let _ = map.values().sum::<i32>();
76128
//~^ unnecessary_fold
129+
let _ = map.values().sum::<i32>();
130+
//~^ unnecessary_fold
77131
let _ = map.values().product::<i32>();
78132
//~^ unnecessary_fold
133+
let _ = map.values().product::<i32>();
134+
//~^ unnecessary_fold
135+
let _: i32 = map.values().sum();
136+
//~^ unnecessary_fold
79137
let _: i32 = map.values().sum();
80138
//~^ unnecessary_fold
81139
let _: i32 = map.values().product();
82140
//~^ unnecessary_fold
141+
let _: i32 = map.values().product();
142+
//~^ unnecessary_fold
83143
anything(map.values().sum::<i32>());
84144
//~^ unnecessary_fold
145+
anything(map.values().sum::<i32>());
146+
//~^ unnecessary_fold
147+
anything(map.values().product::<i32>());
148+
//~^ unnecessary_fold
85149
anything(map.values().product::<i32>());
86150
//~^ unnecessary_fold
87151
num(map.values().sum());
88152
//~^ unnecessary_fold
153+
num(map.values().sum());
154+
//~^ unnecessary_fold
155+
num(map.values().product());
156+
//~^ unnecessary_fold
89157
num(map.values().product());
90158
//~^ unnecessary_fold
91159
}
92160

93161
smoketest_map(HashMap::new());
162+
163+
fn add_turbofish_not_necessary() -> i32 {
164+
(0..3).sum()
165+
//~^ unnecessary_fold
166+
}
167+
fn mul_turbofish_not_necessary() -> i32 {
168+
(0..3).product()
169+
//~^ unnecessary_fold
170+
}
171+
fn add_turbofish_necessary() -> impl Add {
172+
(0..3).sum::<i32>()
173+
//~^ unnecessary_fold
174+
}
175+
fn mul_turbofish_necessary() -> impl Mul {
176+
(0..3).product::<i32>()
177+
//~^ unnecessary_fold
178+
}
94179
}
95180

96181
fn main() {}

0 commit comments

Comments
 (0)