From 8e13fc9b024a36c9931eb343d048561698294dd5 Mon Sep 17 00:00:00 2001 From: Ke Wang Date: Thu, 16 Nov 2023 12:21:11 -0800 Subject: [PATCH] internal PiperOrigin-RevId: 583125869 --- generative_computing/cc/authoring/BUILD | 1 - .../cc/authoring/constructor.cc | 8 ++++ .../cc/authoring/constructor.h | 38 ++++++++++++++++++- .../cc/authoring/constructor_test.cc | 20 ++++++++++ 4 files changed, 65 insertions(+), 2 deletions(-) diff --git a/generative_computing/cc/authoring/BUILD b/generative_computing/cc/authoring/BUILD index abef064e..4caa4b6a 100644 --- a/generative_computing/cc/authoring/BUILD +++ b/generative_computing/cc/authoring/BUILD @@ -12,7 +12,6 @@ cc_library( hdrs = ["constructor.h"], deps = [ "//generative_computing/cc/intrinsics:intrinsic_uris", - "//generative_computing/cc/runtime:status_macros", "//generative_computing/proto/v0:computation_cc_proto", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/generative_computing/cc/authoring/constructor.cc b/generative_computing/cc/authoring/constructor.cc index 91f308b9..45d83186 100644 --- a/generative_computing/cc/authoring/constructor.cc +++ b/generative_computing/cc/authoring/constructor.cc @@ -37,6 +37,14 @@ v0::Value CreateLabeledValue(std::string label, const v0::Value& value) { } // namespace +absl::StatusOr SmartChain::Build() { + // Enable BreakableChain. + if (num_iteration_ > 1) { + return CreateLoopChainCombo(num_iteration_, chained_ops_); + } + return CreateBasicChain(chained_ops_); +} + absl::StatusOr CreateRepeat(int num_steps, v0::Value body_fn) { v0::Value repeat_pb; v0::Intrinsic* const intrinsic_pb = repeat_pb.mutable_intrinsic(); diff --git a/generative_computing/cc/authoring/constructor.h b/generative_computing/cc/authoring/constructor.h index 9691cd63..27d590cd 100644 --- a/generative_computing/cc/authoring/constructor.h +++ b/generative_computing/cc/authoring/constructor.h @@ -16,7 +16,6 @@ limitations under the License #ifndef GENERATIVE_COMPUTING_CC_AUTHORING_CONSTRUCTOR_H_ #define GENERATIVE_COMPUTING_CC_AUTHORING_CONSTRUCTOR_H_ -#include #include #include "absl/status/statusor.h" @@ -25,6 +24,43 @@ limitations under the License namespace generative_computing { +// Used to build chains with loops and breakpoints. Will always choose the +// simpler and more efficient chain type based on chained computations. +class SmartChain { + public: + explicit SmartChain(int num_iteration = 0) : num_iteration_(num_iteration) {} + + // Enable Pipe operator op1 | op2 | op3 ... + SmartChain& operator|(const v0::Value& op) { + chained_ops_.push_back(op); + return *this; + } + + SmartChain& operator|(const absl::StatusOr& op) { + chained_ops_.push_back(op.value()); + return *this; + } + + SmartChain& operator|(SmartChain& other_chain) { + chained_ops_.push_back(other_chain.Build().value()); + return *this; + } + + // Sets number of iterations + SmartChain& operator|(int num_iteration) { + SetNumIteration(num_iteration); + return *this; + } + + absl::StatusOr Build(); + + void SetNumIteration(int i) { num_iteration_ = i; } + + private: + int num_iteration_; + std::vector chained_ops_; +}; + // Given arg_name & computation body create a Lambda that applies a computation // to the provided argument. absl::StatusOr CreateLambda(absl::string_view arg_name, diff --git a/generative_computing/cc/authoring/constructor_test.cc b/generative_computing/cc/authoring/constructor_test.cc index f4937878..38f99257 100644 --- a/generative_computing/cc/authoring/constructor_test.cc +++ b/generative_computing/cc/authoring/constructor_test.cc @@ -16,6 +16,7 @@ limitations under the License #include "generative_computing/cc/authoring/constructor.h" #include +#include #include "googletest/include/gtest/gtest.h" #include "absl/container/flat_hash_map.h" @@ -87,5 +88,24 @@ TEST(CreateWhileTest, ReturnsCorrectWhileProto) { EXPECT_EQ(kwargs.at("condition_fn").intrinsic().uri(), "regex_partial_match"); EXPECT_EQ(kwargs.at("body_fn").intrinsic().uri(), "model_inference"); } + +TEST(SmartChainTest, BuildsLoopChainComboViaPipe) { + v0::Value append_foo_fn = CreateModelInference("append_foo").value(); + v0::Value append_bar_fn = CreateModelInference("append_bar").value(); + v0::Value if_finish_then_break_fn = CreateRegexPartialMatch("FINISH").value(); + v0::Value count_to_3_append_finish_fn = + CreateModelInference("append_finish_when_counts_reaches_3").value(); + v0::Value expected_computation = + CreateLoopChainCombo(100, {append_foo_fn, if_finish_then_break_fn, + append_bar_fn, count_to_3_append_finish_fn}) + .value(); + + SmartChain smart_chain = SmartChain() | append_foo_fn | + if_finish_then_break_fn | append_bar_fn | + count_to_3_append_finish_fn | 100; + v0::Value computation = smart_chain.Build().value(); + EXPECT_EQ(computation.DebugString(), expected_computation.DebugString()); +} + } // namespace } // namespace generative_computing