Skip to content

Commit 34f780d

Browse files
committed
safer interface for ExprLambda's formals
1 parent e438888 commit 34f780d

File tree

12 files changed

+105
-82
lines changed

12 files changed

+105
-82
lines changed

src/libexpr-tests/primops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ TEST_F(PrimOpTest, derivation)
771771
ASSERT_EQ(v.type(), nFunction);
772772
ASSERT_TRUE(v.isLambda());
773773
ASSERT_NE(v.lambda().fun, nullptr);
774-
ASSERT_TRUE(v.lambda().fun->hasFormals);
774+
ASSERT_TRUE(v.lambda().fun->getFormals());
775775
}
776776

777777
TEST_F(PrimOpTest, currentTime)

src/libexpr-tests/value/print.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,8 @@ TEST_F(ValuePrintingTests, vLambda)
110110
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
111111
auto posIdx = state.positions.add(origin, 0);
112112
auto body = ExprInt(0);
113-
auto formals = Formals{};
114113

115-
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
114+
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), &body);
116115

117116
Value vLambda;
118117
vLambda.mkLambda(&env, &eLambda);
@@ -500,9 +499,8 @@ TEST_F(ValuePrintingTests, ansiColorsLambda)
500499
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
501500
auto posIdx = state.positions.add(origin, 0);
502501
auto body = ExprInt(0);
503-
auto formals = Formals{};
504502

505-
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
503+
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), &body);
506504

507505
Value vLambda;
508506
vLambda.mkLambda(&env, &eLambda);

src/libexpr/eval.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,15 +1496,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
14961496

14971497
ExprLambda & lambda(*vCur.lambda().fun);
14981498

1499-
auto size = (!lambda.arg ? 0 : 1) + (lambda.hasFormals ? lambda.getFormals().size() : 0);
1499+
auto size = (!lambda.arg ? 0 : 1) + (lambda.getFormals() ? lambda.getFormals()->formals.size() : 0);
15001500
Env & env2(mem.allocEnv(size));
15011501
env2.up = vCur.lambda().env;
15021502

15031503
Displacement displ = 0;
15041504

1505-
if (!lambda.hasFormals)
1506-
env2.values[displ++] = args[0];
1507-
else {
1505+
if (auto formals = lambda.getFormals()) {
15081506
try {
15091507
forceAttrs(*args[0], lambda.pos, "while evaluating the value passed for the lambda argument");
15101508
} catch (Error & e) {
@@ -1520,7 +1518,7 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
15201518
there is no matching actual argument but the formal
15211519
argument has a default, use the default. */
15221520
size_t attrsUsed = 0;
1523-
for (auto & i : lambda.getFormals()) {
1521+
for (auto & i : formals->formals) {
15241522
auto j = args[0]->attrs()->get(i.name);
15251523
if (!j) {
15261524
if (!i.def) {
@@ -1542,13 +1540,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
15421540

15431541
/* Check that each actual argument is listed as a formal
15441542
argument (unless the attribute match specifies a `...'). */
1545-
if (!lambda.ellipsis && attrsUsed != args[0]->attrs()->size()) {
1543+
if (!formals->ellipsis && attrsUsed != args[0]->attrs()->size()) {
15461544
/* Nope, so show the first unexpected argument to the
15471545
user. */
15481546
for (auto & i : *args[0]->attrs())
1549-
if (!lambda.hasFormal(i.name)) {
1547+
if (!formals->has(i.name)) {
15501548
StringSet formalNames;
1551-
for (auto & formal : lambda.getFormals())
1549+
for (auto & formal : formals->formals)
15521550
formalNames.insert(std::string(symbols[formal.name]));
15531551
auto suggestions = Suggestions::bestMatches(formalNames, symbols[i.name]);
15541552
error<TypeError>(
@@ -1563,6 +1561,8 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
15631561
}
15641562
unreachable();
15651563
}
1564+
} else {
1565+
env2.values[displ++] = args[0];
15661566
}
15671567

15681568
nrFunctionCalls++;
@@ -1747,22 +1747,23 @@ void EvalState::autoCallFunction(const Bindings & args, Value & fun, Value & res
17471747
}
17481748
}
17491749

1750-
if (!fun.isLambda() || !fun.lambda().fun->hasFormals) {
1750+
if (!fun.isLambda() || !fun.lambda().fun->getFormals()) {
17511751
res = fun;
17521752
return;
17531753
}
1754+
auto formals = fun.lambda().fun->getFormals();
17541755

1755-
auto attrs = buildBindings(std::max(static_cast<uint32_t>(fun.lambda().fun->nFormals), args.size()));
1756+
auto attrs = buildBindings(std::max(static_cast<uint32_t>(formals->formals.size()), args.size()));
17561757

1757-
if (fun.lambda().fun->ellipsis) {
1758+
if (formals->ellipsis) {
17581759
// If the formals have an ellipsis (eg the function accepts extra args) pass
17591760
// all available automatic arguments (which includes arguments specified on
17601761
// the command line via --arg/--argstr)
17611762
for (auto & v : args)
17621763
attrs.insert(v);
17631764
} else {
17641765
// Otherwise, only pass the arguments that the function accepts
1765-
for (auto & i : fun.lambda().fun->getFormals()) {
1766+
for (auto & i : formals->formals) {
17661767
auto j = args.get(i.name);
17671768
if (j) {
17681769
attrs.insert(*j);

src/libexpr/include/nix/expr/nixexpr.hh

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ struct Formal
466466
Expr * def;
467467
};
468468

469-
struct Formals
469+
struct FormalsBuilder
470470
{
471471
typedef std::vector<Formal> Formals_;
472472
/**
@@ -483,26 +483,67 @@ struct Formals
483483
}
484484
};
485485

486+
struct Formals
487+
{
488+
std::span<Formal> formals;
489+
bool ellipsis;
490+
491+
Formals(std::span<Formal> formals, bool ellipsis)
492+
: formals(formals)
493+
, ellipsis(ellipsis) {};
494+
495+
bool has(Symbol arg) const
496+
{
497+
auto it = std::lower_bound(
498+
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
499+
return it != formals.end() && it->name == arg;
500+
}
501+
502+
std::vector<Formal> lexicographicOrder(const SymbolTable & symbols) const
503+
{
504+
std::vector<Formal> result(formals.begin(), formals.end());
505+
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
506+
std::string_view sa = symbols[a.name], sb = symbols[b.name];
507+
return sa < sb;
508+
});
509+
return result;
510+
}
511+
};
512+
486513
struct ExprLambda : Expr
487514
{
488515
PosIdx pos;
489516
Symbol name;
490517
Symbol arg;
491518

492-
bool ellipsis;
519+
private:
493520
bool hasFormals;
521+
bool ellipsis;
494522
uint16_t nFormals;
495523
Formal * formalsStart;
524+
public:
525+
526+
std::optional<Formals> getFormals() const
527+
{
528+
if (hasFormals)
529+
return Formals{{formalsStart, nFormals}, ellipsis};
530+
else
531+
return std::nullopt;
532+
}
496533

497534
Expr * body;
498535
DocComment docComment;
499536

500537
ExprLambda(
501-
std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Symbol arg, const Formals & formals, Expr * body)
538+
std::pmr::polymorphic_allocator<char> & alloc,
539+
PosIdx pos,
540+
Symbol arg,
541+
const FormalsBuilder & formals,
542+
Expr * body)
502543
: pos(pos)
503544
, arg(arg)
504-
, ellipsis(formals.ellipsis)
505545
, hasFormals(true)
546+
, ellipsis(formals.ellipsis)
506547
, nFormals(formals.formals.size())
507548
, formalsStart(alloc.allocate_object<Formal>(nFormals))
508549
, body(body)
@@ -514,44 +555,22 @@ struct ExprLambda : Expr
514555
: pos(pos)
515556
, arg(arg)
516557
, hasFormals(false)
558+
, ellipsis(false)
559+
, nFormals(0)
517560
, formalsStart(nullptr)
518561
, body(body) {};
519562

520-
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Formals formals, Expr * body)
563+
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, FormalsBuilder formals, Expr * body)
521564
: ExprLambda(alloc, pos, Symbol(), formals, body) {};
522565

523-
bool hasFormal(Symbol arg) const
524-
{
525-
auto formals = getFormals();
526-
auto it = std::lower_bound(
527-
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
528-
return it != formals.end() && it->name == arg;
529-
}
530-
531566
void setName(Symbol name) override;
532567
std::string showNamePos(const EvalState & state) const;
533568

534-
std::vector<Formal> getFormalsLexicographic(const SymbolTable & symbols) const
535-
{
536-
std::vector<Formal> result(getFormals().begin(), getFormals().end());
537-
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
538-
std::string_view sa = symbols[a.name], sb = symbols[b.name];
539-
return sa < sb;
540-
});
541-
return result;
542-
}
543-
544569
PosIdx getPos() const override
545570
{
546571
return pos;
547572
}
548573

549-
std::span<Formal> getFormals() const
550-
{
551-
assert(hasFormals);
552-
return {formalsStart, nFormals};
553-
}
554-
555574
virtual void setDocComment(DocComment docComment) override;
556575
COMMON_METHODS
557576
};

src/libexpr/include/nix/expr/parser-state.hh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ struct ParserState
9393
void addAttr(
9494
ExprAttrs * attrs, AttrPath && attrPath, const ParserLocation & loc, Expr * e, const ParserLocation & exprLoc);
9595
void addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symbol, ExprAttrs::AttrDef && def);
96-
void validateFormals(Formals & formals, PosIdx pos = noPos, Symbol arg = {});
96+
void validateFormals(FormalsBuilder & formals, PosIdx pos = noPos, Symbol arg = {});
9797
Expr * stripIndentation(const PosIdx pos, std::vector<std::pair<PosIdx, std::variant<Expr *, StringToken>>> && es);
9898
PosIdx at(const ParserLocation & loc);
9999
};
@@ -213,7 +213,7 @@ ParserState::addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symb
213213
}
214214
}
215215

216-
inline void ParserState::validateFormals(Formals & formals, PosIdx pos, Symbol arg)
216+
inline void ParserState::validateFormals(FormalsBuilder & formals, PosIdx pos, Symbol arg)
217217
{
218218
std::sort(formals.formals.begin(), formals.formals.end(), [](const auto & a, const auto & b) {
219219
return std::tie(a.name, a.pos) < std::tie(b.name, b.pos);

src/libexpr/nixexpr.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,13 @@ void ExprList::show(const SymbolTable & symbols, std::ostream & str) const
154154
void ExprLambda::show(const SymbolTable & symbols, std::ostream & str) const
155155
{
156156
str << "(";
157-
if (hasFormals) {
157+
if (auto formals = getFormals()) {
158158
str << "{ ";
159159
bool first = true;
160160
// the natural Symbol ordering is by creation time, which can lead to the
161161
// same expression being printed in two different ways depending on its
162162
// context. always use lexicographic ordering to avoid this.
163-
for (auto & i : getFormalsLexicographic(symbols)) {
163+
for (auto & i : formals->lexicographicOrder(symbols)) {
164164
if (first)
165165
first = false;
166166
else
@@ -451,20 +451,21 @@ void ExprLambda::bindVars(EvalState & es, const std::shared_ptr<const StaticEnv>
451451
if (es.debugRepl)
452452
es.exprEnvs.insert(std::make_pair(this, env));
453453

454-
auto newEnv = std::make_shared<StaticEnv>(nullptr, env, (hasFormals ? getFormals().size() : 0) + (!arg ? 0 : 1));
454+
auto newEnv =
455+
std::make_shared<StaticEnv>(nullptr, env, (getFormals() ? getFormals()->formals.size() : 0) + (!arg ? 0 : 1));
455456

456457
Displacement displ = 0;
457458

458459
if (arg)
459460
newEnv->vars.emplace_back(arg, displ++);
460461

461-
if (hasFormals) {
462-
for (auto & i : getFormals())
462+
if (auto formals = getFormals()) {
463+
for (auto & i : formals->formals)
463464
newEnv->vars.emplace_back(i.name, displ++);
464465

465466
newEnv->sort();
466467

467-
for (auto & i : getFormals())
468+
for (auto & i : formals->formals)
468469
if (i.def)
469470
i.def->bindVars(es, newEnv);
470471
}

src/libexpr/parser.y

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ static Expr * makeCall(PosIdx pos, Expr * fn, Expr * arg) {
131131
%type <nix::Expr *> expr_pipe_from expr_pipe_into
132132
%type <std::vector<Expr *>> list
133133
%type <nix::ExprAttrs *> binds binds1
134-
%type <nix::Formals> formals formal_set
134+
%type <nix::FormalsBuilder> formals formal_set
135135
%type <nix::Formal> formal
136136
%type <std::vector<nix::AttrName>> attrpath
137137
%type <std::vector<std::pair<nix::AttrName, nix::PosIdx>>> attrs

src/libexpr/primops.cc

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3363,21 +3363,20 @@ static void prim_functionArgs(EvalState & state, const PosIdx pos, Value ** args
33633363
if (!args[0]->isLambda())
33643364
state.error<TypeError>("'functionArgs' requires a function").atPos(pos).debugThrow();
33653365

3366-
if (!args[0]->lambda().fun->hasFormals) {
3366+
if (const auto & formals = args[0]->lambda().fun->getFormals()) {
3367+
auto attrs = state.buildBindings(formals->formals.size());
3368+
for (auto & i : formals->formals)
3369+
attrs.insert(i.name, state.getBool(i.def), i.pos);
3370+
/* Optimization: avoid sorting bindings. `formals` must already be sorted according to
3371+
(std::tie(a.name, a.pos) < std::tie(b.name, b.pos)) predicate, so the following assertion
3372+
always holds:
3373+
assert(std::is_sorted(attrs.alreadySorted()->begin(), attrs.alreadySorted()->end()));
3374+
.*/
3375+
v.mkAttrs(attrs.alreadySorted());
3376+
} else {
33673377
v.mkAttrs(&Bindings::emptyBindings);
33683378
return;
33693379
}
3370-
3371-
const auto & formals = args[0]->lambda().fun->getFormals();
3372-
auto attrs = state.buildBindings(formals.size());
3373-
for (auto & i : formals)
3374-
attrs.insert(i.name, state.getBool(i.def), i.pos);
3375-
/* Optimization: avoid sorting bindings. `formals` must already be sorted according to
3376-
(std::tie(a.name, a.pos) < std::tie(b.name, b.pos)) predicate, so the following assertion
3377-
always holds:
3378-
assert(std::is_sorted(attrs.alreadySorted()->begin(), attrs.alreadySorted()->end()));
3379-
.*/
3380-
v.mkAttrs(attrs.alreadySorted());
33813380
}
33823381

33833382
static RegisterPrimOp primop_functionArgs({

src/libexpr/value-to-xml.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,14 @@ static void printValueAsXML(
145145
posToXML(state, xmlAttrs, state.positions[v.lambda().fun->pos]);
146146
XMLOpenElement _(doc, "function", xmlAttrs);
147147

148-
if (v.lambda().fun->hasFormals) {
148+
if (auto formals = v.lambda().fun->getFormals()) {
149149
XMLAttrs attrs;
150150
if (v.lambda().fun->arg)
151151
attrs["name"] = state.symbols[v.lambda().fun->arg];
152-
if (v.lambda().fun->ellipsis)
152+
if (formals->ellipsis)
153153
attrs["ellipsis"] = "1";
154154
XMLOpenElement _(doc, "attrspat", attrs);
155-
for (auto & i : v.lambda().fun->getFormalsLexicographic(state.symbols))
155+
for (auto & i : formals->lexicographicOrder(state.symbols))
156156
doc.writeEmptyElement("attr", singletonAttrs("name", state.symbols[i.name]));
157157
} else
158158
doc.writeEmptyElement("varpat", singletonAttrs("name", state.symbols[v.lambda().fun->arg]));

src/libflake/flake.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,15 @@ static Flake readFlake(
281281
if (auto outputs = vInfo.attrs()->get(sOutputs)) {
282282
expectType(state, nFunction, *outputs->value, outputs->pos);
283283

284-
if (outputs->value->isLambda() && outputs->value->lambda().fun->hasFormals) {
285-
for (auto & formal : outputs->value->lambda().fun->getFormals()) {
286-
if (formal.name != state.s.self)
287-
flake.inputs.emplace(
288-
state.symbols[formal.name],
289-
FlakeInput{.ref = parseFlakeRef(state.fetchSettings, std::string(state.symbols[formal.name]))});
284+
if (outputs->value->isLambda()) {
285+
if (auto formals = outputs->value->lambda().fun->getFormals()) {
286+
for (auto & formal : formals->formals) {
287+
if (formal.name != state.s.self)
288+
flake.inputs.emplace(
289+
state.symbols[formal.name],
290+
FlakeInput{
291+
.ref = parseFlakeRef(state.fetchSettings, std::string(state.symbols[formal.name]))});
292+
}
290293
}
291294
}
292295

0 commit comments

Comments
 (0)