Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion generative_computing/cc/authoring/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ cc_library(
hdrs = ["constructor.h"],
deps = [
"//generative_computing/cc/intrinsics:intrinsic_uris",
"//generative_computing/cc/runtime:status_macros",
"//generative_computing/proto/v0:computation_cc_proto",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
8 changes: 8 additions & 0 deletions generative_computing/cc/authoring/constructor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ v0::Value CreateLabeledValue(std::string label, const v0::Value& value) {

} // namespace

absl::StatusOr<v0::Value> SmartChain::Build() {
// Enable BreakableChain.
if (num_iteration_ > 1) {
return CreateLoopChainCombo(num_iteration_, chained_ops_);
}
return CreateBasicChain(chained_ops_);
}

absl::StatusOr<v0::Value> CreateRepeat(int num_steps, v0::Value body_fn) {
v0::Value repeat_pb;
v0::Intrinsic* const intrinsic_pb = repeat_pb.mutable_intrinsic();
Expand Down
38 changes: 37 additions & 1 deletion generative_computing/cc/authoring/constructor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License
#ifndef GENERATIVE_COMPUTING_CC_AUTHORING_CONSTRUCTOR_H_
#define GENERATIVE_COMPUTING_CC_AUTHORING_CONSTRUCTOR_H_

#include <string>
#include <vector>

#include "absl/status/statusor.h"
Expand All @@ -25,6 +24,43 @@ limitations under the License

namespace generative_computing {

// Used to build chains with loops and breakpoints. Will always choose the
// simpler and more efficient chain type based on chained computations.
class SmartChain {
public:
explicit SmartChain(int num_iteration = 0) : num_iteration_(num_iteration) {}

// Enable Pipe operator op1 | op2 | op3 ...
SmartChain& operator|(const v0::Value& op) {
chained_ops_.push_back(op);
return *this;
}

SmartChain& operator|(const absl::StatusOr<v0::Value>& op) {
chained_ops_.push_back(op.value());
return *this;
}

SmartChain& operator|(SmartChain& other_chain) {
chained_ops_.push_back(other_chain.Build().value());
return *this;
}

// Sets number of iterations
SmartChain& operator|(int num_iteration) {
SetNumIteration(num_iteration);
return *this;
}

absl::StatusOr<v0::Value> Build();

void SetNumIteration(int i) { num_iteration_ = i; }

private:
int num_iteration_;
std::vector<v0::Value> chained_ops_;
};

// Given arg_name & computation body create a Lambda that applies a computation
// to the provided argument.
absl::StatusOr<v0::Value> CreateLambda(absl::string_view arg_name,
Expand Down
20 changes: 20 additions & 0 deletions generative_computing/cc/authoring/constructor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License
#include "generative_computing/cc/authoring/constructor.h"

#include <string>
#include <vector>

#include "googletest/include/gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
Expand Down Expand Up @@ -87,5 +88,24 @@ TEST(CreateWhileTest, ReturnsCorrectWhileProto) {
EXPECT_EQ(kwargs.at("condition_fn").intrinsic().uri(), "regex_partial_match");
EXPECT_EQ(kwargs.at("body_fn").intrinsic().uri(), "model_inference");
}

TEST(SmartChainTest, BuildsLoopChainComboViaPipe) {
v0::Value append_foo_fn = CreateModelInference("append_foo").value();
v0::Value append_bar_fn = CreateModelInference("append_bar").value();
v0::Value if_finish_then_break_fn = CreateRegexPartialMatch("FINISH").value();
v0::Value count_to_3_append_finish_fn =
CreateModelInference("append_finish_when_counts_reaches_3").value();
v0::Value expected_computation =
CreateLoopChainCombo(100, {append_foo_fn, if_finish_then_break_fn,
append_bar_fn, count_to_3_append_finish_fn})
.value();

SmartChain smart_chain = SmartChain() | append_foo_fn |
if_finish_then_break_fn | append_bar_fn |
count_to_3_append_finish_fn | 100;
v0::Value computation = smart_chain.Build().value();
EXPECT_EQ(computation.DebugString(), expected_computation.DebugString());
}

} // namespace
} // namespace generative_computing