Skip to content

Commit d54412d

Browse files
committed
Add DqJoin with label list
1 parent 64b5be5 commit d54412d

File tree

2 files changed

+92
-31
lines changed

2 files changed

+92
-31
lines changed

ydb/library/yql/dq/opt/dq_opt_join.cpp

+15-6
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ TExprBase BuildDqJoinInput(TExprContext& ctx, TPositionHandle pos, const TExprBa
116116
return partition;
117117
}
118118

119+
TExprNode::TPtr CreateLabelList(const THashSet<TStringBuf>& labels, const TPositionHandle& position, TExprContext& ctx) {
120+
TExprNode::TListType newKeys;
121+
for (const auto& label : labels) {
122+
newKeys.push_back(ctx.NewAtom(position, label));
123+
}
124+
return ctx.NewList(position, std::move(newKeys));
125+
}
126+
119127
TMaybe<TJoinInputDesc> BuildDqJoin(
120128
const TCoEquiJoinTuple& joinTuple,
121129
const THashMap<TStringBuf, TJoinInputDesc>& inputs,
@@ -191,12 +199,13 @@ TMaybe<TJoinInputDesc> BuildDqJoin(
191199
resultKeys.insert(right->Keys.begin(), right->Keys.end());
192200
}
193201

194-
auto leftTableLabel = left->IsRealTable()
195-
? BuildAtom(leftLabel, left->Input.Pos(), ctx).Ptr()
196-
: Build<TCoVoid>(ctx, left->Input.Pos()).Done().Ptr();
197-
auto rightTableLabel = right->IsRealTable()
198-
? BuildAtom(rightLabel, right->Input.Pos(), ctx).Ptr()
199-
: Build<TCoVoid>(ctx, right->Input.Pos()).Done().Ptr();
202+
auto leftTableLabel = left->IsRealTable() ? (left->Labels->size() > 1 ? CreateLabelList(*(left->Labels), left->Input.Pos(), ctx)
203+
: BuildAtom(leftLabel, left->Input.Pos(), ctx).Ptr())
204+
: Build<TCoVoid>(ctx, left->Input.Pos()).Done().Ptr();
205+
206+
auto rightTableLabel = right->IsRealTable() ? (right->Labels->size() > 1 ? CreateLabelList(*(right->Labels), right->Input.Pos(), ctx)
207+
: BuildAtom(rightLabel, right->Input.Pos(), ctx).Ptr())
208+
: Build<TCoVoid>(ctx, right->Input.Pos()).Done().Ptr();
200209

201210
size_t joinKeysCount = joinTuple.LeftKeys().Size() / 2;
202211
TVector<TCoAtom> leftJoinKeys;

ydb/library/yql/dq/type_ann/dq_type_ann.cpp

+77-25
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ TStatus AnnotateStage(const TExprNode::TPtr& stage, TExprContext& ctx) {
258258
}
259259

260260
THashMap<TStringBuf, THashMap<TStringBuf, const TTypeAnnotationNode*>>
261-
ParseJoinInputType(const TStructExprType& rowType, TStringBuf tableLabel, TExprContext& ctx, bool optional) {
261+
ParseJoinInputType(const TStructExprType& rowType, const THashSet<TStringBuf>& tableLabels, TExprContext& ctx, bool optional) {
262262
THashMap<TStringBuf, THashMap<TStringBuf, const TTypeAnnotationNode*>> result;
263263
for (auto member : rowType.GetItems()) {
264264
TStringBuf label, column;
@@ -268,7 +268,7 @@ ParseJoinInputType(const TStructExprType& rowType, TStringBuf tableLabel, TExprC
268268
column = member->GetName();
269269
}
270270
const bool isSystemKeyColumn = column.starts_with("_yql_dq_key_");
271-
if (label.empty() && tableLabel.empty() && !isSystemKeyColumn) {
271+
if (label.empty() && (tableLabels.size() == 1 && tableLabels.begin()->empty()) && !isSystemKeyColumn) {
272272
ctx.AddError(TIssue(TStringBuilder() << "Invalid join input type " << FormatType(&rowType)));
273273
result.clear();
274274
return result;
@@ -277,23 +277,30 @@ ParseJoinInputType(const TStructExprType& rowType, TStringBuf tableLabel, TExprC
277277
if (optional && !memberType->IsOptionalOrNull()) {
278278
memberType = ctx.MakeType<TOptionalExprType>(memberType);
279279
}
280-
if (!tableLabel.empty() && label.empty()) {
281-
result[tableLabel][member->GetName()] = memberType;
282-
} else {
280+
if (tableLabels.size() > 1) {
281+
YQL_ENSURE(label);
282+
YQL_ENSURE(column);
283283
result[label][column] = memberType;
284+
} else {
285+
YQL_ENSURE(tableLabels.size() == 1);
286+
if (!(tableLabels.begin())->empty()) {
287+
result[*(tableLabels.begin())][member->GetName()] = memberType;
288+
} else {
289+
result[label][column] = memberType;
290+
}
284291
}
285292
}
286293
return result;
287294
}
288295

289296
template <bool IsMapJoin>
290297
const TStructExprType* GetDqJoinResultType(TPositionHandle pos, const TStructExprType& leftRowType,
291-
const TStringBuf& leftLabel, const TStructExprType& rightRowType, const TStringBuf& rightLabel,
298+
const THashSet<TStringBuf>& leftLabels, const TStructExprType& rightRowType, const THashSet<TStringBuf>& rightLabels,
292299
const TStringBuf& joinType, const TDqJoinKeyTupleList& joinKeys, TExprContext& ctx)
293300
{
294301
// check left
295302
bool isLeftOptional = IsLeftJoinSideOptional(joinType);
296-
auto leftType = ParseJoinInputType(leftRowType, leftLabel, ctx, isLeftOptional);
303+
auto leftType = ParseJoinInputType(leftRowType, leftLabels, ctx, isLeftOptional);
297304
if (leftType.empty() && joinType != "Cross") {
298305
TStringStream str; str << "Cannot parse left join input type: ";
299306
leftRowType.Out(str);
@@ -303,7 +310,7 @@ const TStructExprType* GetDqJoinResultType(TPositionHandle pos, const TStructExp
303310

304311
// check right
305312
bool isRightOptional = IsRightJoinSideOptional(joinType);
306-
auto rightType = ParseJoinInputType(rightRowType, rightLabel, ctx, isRightOptional);
313+
auto rightType = ParseJoinInputType(rightRowType, rightLabels, ctx, isRightOptional);
307314
if (rightType.empty() && joinType != "Cross") {
308315
TStringStream str; str << "Cannot parse right join input type: ";
309316
rightRowType.Out(str);
@@ -331,11 +338,11 @@ const TStructExprType* GetDqJoinResultType(TPositionHandle pos, const TStructExp
331338
auto rightKeyLabel = key.RightLabel().Value();
332339
auto rightKeyColumn = key.RightColumn().Value();
333340

334-
if (leftLabel && leftLabel != leftKeyLabel) {
341+
if ((leftLabels.size() && !leftLabels.begin()->empty()) && !leftLabels.contains(leftKeyLabel)) {
335342
ctx.AddError(TIssue(ctx.GetPosition(pos), "different labels for left table"));
336343
return nullptr;
337344
}
338-
if (rightLabel && rightLabel != rightKeyLabel) {
345+
if ((rightLabels.size() && !rightLabels.begin()->empty()) && !rightLabels.contains(rightKeyLabel)) {
339346
ctx.AddError(TIssue(ctx.GetPosition(pos), "different labels for right table"));
340347
return nullptr;
341348
}
@@ -402,14 +409,26 @@ const TStructExprType* GetDqJoinResultType(const TExprNode::TPtr& input, bool st
402409
}
403410

404411
if (!input->Child(TDqJoin::idx_LeftLabel)->IsCallable("Void")) {
405-
if (!EnsureAtom(*input->Child(TDqJoin::idx_LeftLabel), ctx)) {
406-
return nullptr;
412+
if ((input->Child(TDqJoin::idx_LeftLabel)->IsAtom())) {
413+
if (!EnsureAtom(*input->Child(TDqJoin::idx_LeftLabel), ctx)) {
414+
return nullptr;
415+
}
416+
} else {
417+
if (!EnsureTupleOfAtoms(*input->Child(TDqJoin::idx_LeftLabel), ctx)) {
418+
return nullptr;
419+
}
407420
}
408421
}
409422

410423
if (!input->Child(TDqJoin::idx_RightLabel)->IsCallable("Void")) {
411-
if (!EnsureAtom(*input->Child(TDqJoin::idx_RightLabel), ctx)) {
412-
return nullptr;
424+
if ((input->Child(TDqJoin::idx_RightLabel)->IsAtom())) {
425+
if (!EnsureAtom(*input->Child(TDqJoin::idx_RightLabel), ctx)) {
426+
return nullptr;
427+
}
428+
} else {
429+
if (!EnsureTupleOfAtoms(*input->Child(TDqJoin::idx_RightLabel), ctx)) {
430+
return nullptr;
431+
}
413432
}
414433
}
415434

@@ -459,18 +478,32 @@ const TStructExprType* GetDqJoinResultType(const TExprNode::TPtr& input, bool st
459478
return nullptr;
460479
}
461480
auto leftStructType = leftInputItemType.Cast<TStructExprType>();
462-
auto leftTableLabel = join.LeftLabel().Maybe<TCoAtom>()
463-
? join.LeftLabel().Cast<TCoAtom>().Value()
464-
: TStringBuf("");
481+
THashSet<TStringBuf> leftTableLabels;
482+
if (join.LeftLabel().Maybe<TCoAtom>()) {
483+
leftTableLabels.emplace(join.LeftLabel().Cast<TCoAtom>().Value());
484+
} else if (join.LeftLabel().Maybe<TCoAtomList>()) {
485+
for (auto label : join.LeftLabel().Cast<TCoAtomList>()) {
486+
leftTableLabels.emplace(label.Value());
487+
}
488+
} else {
489+
leftTableLabels.emplace("");
490+
}
465491

466492
const auto& rightInputItemType = GetSeqItemType(*rightInputType);
467493
if (!EnsureStructType(join.Pos(), rightInputItemType, ctx)) {
468494
return nullptr;
469495
}
470496
auto rightStructType = rightInputItemType.Cast<TStructExprType>();
471-
auto rightTableLabel = join.RightLabel().Maybe<TCoAtom>()
472-
? join.RightLabel().Cast<TCoAtom>().Value()
473-
: TStringBuf("");
497+
THashSet<TStringBuf> rightTableLabels;
498+
if (join.RightLabel().Maybe<TCoAtom>()) {
499+
rightTableLabels.emplace(join.RightLabel().Cast<TCoAtom>().Value());
500+
} else if (join.RightLabel().Maybe<TCoAtomList>()) {
501+
for (auto label : join.RightLabel().Cast<TCoAtomList>()) {
502+
rightTableLabels.emplace(label.Value());
503+
}
504+
} else {
505+
rightTableLabels.emplace("");
506+
}
474507

475508
if (input->ChildrenSize() > TDqJoin::idx_JoinAlgoOptions) {
476509
const auto& joinAlgo = *input->Child(TDqJoin::idx_JoinAlgo);
@@ -511,9 +544,9 @@ const TStructExprType* GetDqJoinResultType(const TExprNode::TPtr& input, bool st
511544
}
512545
}
513546

514-
return GetDqJoinResultType<IsMapJoin>(join.Pos(), *leftStructType, leftTableLabel, *rightStructType,
515-
rightTableLabel, join.JoinType(), join.JoinKeys(), ctx);
516-
}
547+
return GetDqJoinResultType<IsMapJoin>(join.Pos(), *leftStructType, leftTableLabels, *rightStructType,
548+
rightTableLabels, join.JoinType(), join.JoinKeys(), ctx);
549+
}
517550

518551
} // unnamed
519552

@@ -689,12 +722,31 @@ TStatus AnnotateDqCnStreamLookup(const TExprNode::TPtr& input, TExprContext& ctx
689722
if (!EnsureStructType(input->Pos(), rightRowType, ctx)) {
690723
return TStatus::Error;
691724
}
725+
726+
THashSet<TStringBuf> leftLabels;
727+
if (cnStreamLookup.LeftLabel().Maybe<TCoAtom>()) {
728+
leftLabels.emplace(cnStreamLookup.LeftLabel().Cast<TCoAtom>().Value());
729+
} else {
730+
for (auto label : cnStreamLookup.LeftLabel().Cast<TCoAtomList>()) {
731+
leftLabels.emplace(label.Value());
732+
}
733+
}
734+
735+
THashSet<TStringBuf> rightLabels;
736+
if (cnStreamLookup.RightLabel().Maybe<TCoAtom>()) {
737+
rightLabels.emplace(cnStreamLookup.RightLabel().Cast<TCoAtom>().Value());
738+
} else {
739+
for (auto label : cnStreamLookup.RightLabel().Cast<TCoAtomList>()) {
740+
rightLabels.emplace(label.Value());
741+
}
742+
}
743+
692744
const auto outputRowType = GetDqJoinResultType<true>(
693745
input->Pos(),
694746
*leftRowType.Cast<TStructExprType>(),
695-
cnStreamLookup.LeftLabel().Cast<TCoAtom>().StringValue(),
747+
leftLabels,
696748
*rightRowType.Cast<TStructExprType>(),
697-
cnStreamLookup.RightLabel().StringValue(),
749+
rightLabels,
698750
cnStreamLookup.JoinType().StringValue(),
699751
cnStreamLookup.JoinKeys(),
700752
ctx

0 commit comments

Comments
 (0)