Skip to content

Commit 06ea8e3

Browse files
rohanyInfinoid
authored andcommitted
taco: add parser support for windowing/striding/index sets
Fixes tensor-compiler#413. This commit adds support for the command line tool to accept index expressions containing windowing, striding and index sets. An example of each of these features is added to the help message of taco: ``` taco "a(i) = b(i(1, 5))" -d=a:4 # Slice b[1:4] taco "a(i) = b(i(1, 5, 2))" -d=a:2 # Slice b[1:4:2] taco "a(i) = b(i({1, 3, 5, 7}))" -d=a:4 # Slice b[[1, 3, 5, 7]] ```
1 parent ceeabe4 commit 06ea8e3

File tree

6 files changed

+90
-14
lines changed

6 files changed

+90
-14
lines changed

include/taco/index_notation/index_notation.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ Multi multi(IndexStmt stmt1, IndexStmt stmt2);
877877
class IndexVarInterface {
878878
public:
879879
virtual ~IndexVarInterface() = default;
880+
virtual IndexVar getIndexVar() const = 0;
880881

881882
/// match performs a dynamic case analysis of the implementers of IndexVarInterface
882883
/// as a utility for handling the different values within. It mimics the dynamic
@@ -912,7 +913,7 @@ class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public Index
912913
~WindowedIndexVar() = default;
913914

914915
/// getIndexVar returns the underlying IndexVar.
915-
IndexVar getIndexVar() const;
916+
IndexVar getIndexVar() const override;
916917

917918
/// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of
918919
/// this index variable.
@@ -940,7 +941,7 @@ class IndexSetVar : public util::Comparable<IndexSetVar>, public IndexVarInterfa
940941
~IndexSetVar() = default;
941942

942943
/// getIndexVar returns the underlying IndexVar.
943-
IndexVar getIndexVar() const;
944+
IndexVar getIndexVar() const override;
944945
/// getIndexSet returns the index set.
945946
const std::vector<int>& getIndexSet() const;
946947

@@ -957,6 +958,9 @@ class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
957958
~IndexVar() = default;
958959
IndexVar(const std::string& name);
959960

961+
// getIndexVar implements the IndexVarInterface.
962+
IndexVar getIndexVar() const override { return *this; }
963+
960964
/// Returns the name of the index variable.
961965
std::string getName() const;
962966

@@ -967,7 +971,7 @@ class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
967971
WindowedIndexVar operator()(int lo, int hi, int stride = 1);
968972

969973
/// Indexing into an IndexVar with a vector returns an index set into it.
970-
IndexSetVar operator()(std::vector<int> indexSet);
974+
IndexSetVar operator()(std::vector<int>&& indexSet);
971975
IndexSetVar operator()(std::vector<int>& indexSet);
972976

973977
private:

include/taco/parser/parser.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace taco {
1414
class TensorBase;
1515
class Format;
1616
class IndexVar;
17+
class IndexVarInterface;
1718
class IndexExpr;
1819
class Access;
1920

@@ -88,10 +89,13 @@ class Parser : public util::Uncopyable {
8889
Access parseAccess();
8990

9091
/// varlist ::= var {, var}
91-
std::vector<IndexVar> parseVarList();
92+
std::vector<std::shared_ptr<IndexVarInterface>> parseVarList();
9293

9394
/// var ::= identifier
94-
IndexVar parseVar();
95+
/// | identifier '(' int ',' int ')' -- Windowed access.
96+
/// | identifier '(' int ',' int ',' int ')' -- Windowed access with a stride.
97+
/// | identifier '(' '{' int, ... '}' ')' -- Access with an index set.
98+
std::shared_ptr<IndexVarInterface> parseVar();
9599

96100
std::string currentTokenString();
97101

src/index_notation/index_notation.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1978,7 +1978,7 @@ WindowedIndexVar IndexVar::operator()(int lo, int hi, int stride) {
19781978
return WindowedIndexVar(*this, lo, hi, stride);
19791979
}
19801980

1981-
IndexSetVar IndexVar::operator()(std::vector<int> indexSet) {
1981+
IndexSetVar IndexVar::operator()(std::vector<int>&& indexSet) {
19821982
return IndexSetVar(*this, indexSet);
19831983
}
19841984

src/parser/parser.cpp

+72-7
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ Access Parser::parseAccess() {
282282
consume(Token::identifier);
283283
names.push_back(tensorName);
284284

285-
vector<IndexVar> varlist;
285+
vector<std::shared_ptr<IndexVarInterface>> varlist;
286286
if (content->currentToken == Token::underscore) {
287287
consume(Token::underscore);
288288
if (content->currentToken == Token::lcurly) {
@@ -322,8 +322,8 @@ Access Parser::parseAccess() {
322322
if (util::contains(content->tensorDimensions, tensorName)) {
323323
tensorDimensions[i] = content->tensorDimensions.at(tensorName)[i];
324324
}
325-
else if (util::contains(content->indexVarDimensions, varlist[i])) {
326-
tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]);
325+
else if (util::contains(content->indexVarDimensions, varlist[i]->getIndexVar())) {
326+
tensorDimensions[i] = content->indexVarDimensions.at(varlist[i]->getIndexVar());
327327
}
328328
else {
329329
tensorDimensions[i] = content->defaultDimension;
@@ -347,8 +347,8 @@ Access Parser::parseAccess() {
347347
return tensor(varlist);
348348
}
349349

350-
vector<IndexVar> Parser::parseVarList() {
351-
vector<IndexVar> varlist;
350+
vector<std::shared_ptr<IndexVarInterface>> Parser::parseVarList() {
351+
vector<std::shared_ptr<IndexVarInterface>> varlist;
352352
varlist.push_back(parseVar());
353353
while (content->currentToken == Token::comma) {
354354
consume(Token::comma);
@@ -357,13 +357,78 @@ vector<IndexVar> Parser::parseVarList() {
357357
return varlist;
358358
}
359359

360-
IndexVar Parser::parseVar() {
360+
std::shared_ptr<IndexVarInterface> Parser::parseVar() {
361361
if (content->currentToken != Token::identifier) {
362362
throw ParseError("Expected index variable");
363363
}
364364
IndexVar var = getIndexVar(content->lexer.getIdentifier());
365365
consume(Token::identifier);
366-
return var;
366+
// If there is a paren after this identifier, then we may have a window
367+
// or index set access.
368+
if (this->content->currentToken == Token::lparen) {
369+
this->consume(Token::lparen);
370+
switch (this->content->currentToken) {
371+
case Token::int_scalar: {
372+
// In this case, we have a window or strided window. Start off by
373+
// parsing the lo and hi of the window.
374+
int lo, hi;
375+
// Parse out lo.
376+
std::istringstream value(this->content->lexer.getIdentifier());
377+
value >> lo;
378+
this->consume(Token::int_scalar);
379+
380+
// Parse the comma.
381+
this->consume(Token::comma);
382+
383+
// Parse out hi.
384+
value = std::istringstream(this->content->lexer.getIdentifier());
385+
value >> hi;
386+
this->consume(Token::int_scalar);
387+
388+
// Now, there might be the stride. If there is another comma, then there
389+
// is a stride value to parse. Otherwise, it's just the window of (lo, hi).
390+
if (this->content->currentToken == Token::comma) {
391+
this->consume(Token::comma);
392+
int stride;
393+
value = std::istringstream(this->content->lexer.getIdentifier());
394+
value >> stride;
395+
this->consume(Token::int_scalar);
396+
this->consume(Token::rparen);
397+
return std::make_shared<WindowedIndexVar>(var(lo, hi, stride));
398+
} else {
399+
this->consume(Token::rparen);
400+
return std::make_shared<WindowedIndexVar>(var(lo, hi));
401+
}
402+
}
403+
case Token::lcurly: {
404+
// If we see a curly brace, then an index set is being applied to the
405+
// IndexVar. So, we'll parse a list of integers.
406+
this->consume(Token::lcurly);
407+
std::vector<int> indexSet;
408+
bool first = true;
409+
do {
410+
// If this isn't the first iteration of the loop, consume a comma.
411+
if (!first) {
412+
this->consume(Token::comma);
413+
}
414+
first = false;
415+
// Parse and consume the next integer.
416+
std::istringstream value(this->content->lexer.getIdentifier());
417+
int index;
418+
value >> index;
419+
indexSet.push_back(index);
420+
this->consume(Token::int_scalar);
421+
// Break when we hit a '}' to end the list.
422+
} while (this->content->currentToken != Token::rcurly);
423+
this->consume(Token::rcurly);
424+
this->consume(Token::rparen);
425+
return std::make_shared<IndexSetVar>(var(indexSet));
426+
}
427+
default:
428+
throw ParseError("Expected windowing expression.");
429+
}
430+
}
431+
return std::make_shared<IndexVar>(var);
367432
}
368433

369434
bool Parser::hasIndexVar(std::string name) const {

src/tensor.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ struct AccessTensorNode : public AccessNode {
486486
// Ensure that it has at most dim(t, i) elements.
487487
taco_uassert(indexSet.size() <= size_t(tensor.getDimension(i)));
488488
// Pack up the index set into a sparse tensor.
489-
TensorBase indexSetTensor(tensor.getComponentType(), {int(indexSet.size())}, Compressed);
489+
Tensor<int> indexSetTensor({int(indexSet.size())}, Compressed);
490490
for (auto& coord : indexSet) {
491491
indexSetTensor.insert({coord}, 1);
492492
}

tools/taco.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ static void printUsageInfo() {
9090
cout << " taco \"a(i) = b(i) + c(i)\" -f=b:s -f=c:s -f=a:s # Sparse vector add" << endl;
9191
cout << " taco \"a(i) = B(i,j) * c(j)\" -f=B:ds # SpMV" << endl;
9292
cout << " taco \"A(i,l) = B(i,j,k) * C(j,l) * D(k,l)\" -f=B:sss # MTTKRP" << endl;
93+
cout << " taco \"a(i) = b(i(1, 5))\" -d=a:4 # Slice b[1:4]" << endl;
94+
cout << " taco \"a(i) = b(i(1, 5, 2))\" -d=a:2 # Slice b[1:4:2]" << endl;
95+
cout << " taco \"a(i) = b(i({1, 3, 5, 7}))\" -d=a:4 # Slice b[[1, 3, 5, 7]]" << endl;
9396
cout << endl;
9497
cout << "Options:" << endl;
9598
printFlag("d=<var/tensor>:<size>",

0 commit comments

Comments
 (0)