
Commit 0e99743

Zalman Stern authored and committed
Add show_timing flag to runner to show how long various operations
take and give a per-token time.
1 parent 806d3c6 commit 0e99743

File tree: 1 file changed (+109, -41 lines)


apps/hallmark/src/llm_runner.cpp

@@ -1,5 +1,8 @@
 #include "llm.h"
+
 #include <HalideRuntime.h>
+
+#include <chrono>
 #include <iomanip>
 #include <iostream>
 
@@ -20,6 +23,49 @@ ABSL_FLAG(int, max_tokens, 512,
           "Maximum number of input and output tokens. This value needs to be "
           "at least larger than the number of input tokens.");
 
+ABSL_FLAG(bool, show_timing, false,
+          "Show timing for operations.");
+
+namespace {
+
+// Prefer high_resolution_clock, but only if it's steady...
+template<bool HighResIsSteady = std::chrono::high_resolution_clock::is_steady>
+struct SteadyClock {
+    using type = std::chrono::high_resolution_clock;
+};
+
+// ...otherwise use steady_clock.
+template<>
+struct SteadyClock<false> {
+    using type = std::chrono::steady_clock;
+};
+
+
+struct TimingScope {
+    TimingScope(const char *name, int iterations = 1) : name(name), iterations(iterations) {
+        start = SteadyClock<>::type::now();
+    }
+
+    ~TimingScope() {
+        if (absl::GetFlag(FLAGS_show_timing)) {
+            SteadyClock<>::type::time_point end = SteadyClock<>::type::now();
+            double secs = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
+            std::cerr << name << ": took " << secs << "s";
+            if (iterations != 1) {
+                std::cerr << " " << secs / iterations << "s per iteration.\n";
+            } else {
+                std::cerr << "\n";
+            }
+        }
+    }
+
+    std::string name;
+    int iterations;
+    SteadyClock<>::type::time_point start;
+};
+
+}
+
 int main(int argc, char *argv[]) {
     absl::ParseCommandLine(argc, argv);
 
@@ -30,6 +76,7 @@ int main(int argc, char *argv[]) {
 
     sentencepiece::SentencePieceProcessor tokenizer;
     {
+        TimingScope load_tokenizer("Loading tokenizer");
         auto result = tokenizer.Load(tokenizer_path);
         if (!result.ok()) {
             std::cerr << result.message();
@@ -49,66 +96,87 @@ int main(int argc, char *argv[]) {
         auto result = tokenizer.Encode(bracketed_prompt, &prompt_tokens);
     }
 
-    std::cerr << "Loading LLM params.\n";
-    auto p = hallmark::LoadLlmParams(model_path);
-    if (!p.ok()) {
-        std::cerr << p.status() << "\n";
-        return 1;
+    hallmark::LlmParams llm_params;
+    {
+        TimingScope load_tokenizer("Loading LLM params");
+        auto p = hallmark::LoadLlmParams(model_path);
+        if (!p.ok()) {
+            std::cerr << p.status() << "\n";
+            return 1;
+        }
+        llm_params = std::move(p.value());
     }
-    auto llm_params = std::move(p.value());
     llm_params.seq_size_T = max_tokens;
 
-    std::cerr << "Loading LLM weights.\n";
-    auto w = hallmark::LoadLlmWeights(model_path, llm_params);
-    if (!w.ok()) {
-        std::cerr << w.status() << "\n";
-        return 1;
+    hallmark::LlmWeights llm_weights;
+    {
+        TimingScope load_tokenizer("Loading LLM weights");
+        auto w = hallmark::LoadLlmWeights(model_path, llm_params);
+        if (!w.ok()) {
+            std::cerr << w.status() << "\n";
+            return 1;
+        }
+        llm_weights = std::move(w.value());
     }
-    auto llm_weights = std::move(w.value());
 
-    std::cerr << "Creating LLM.\n";
-    auto l = hallmark::Llm::CreateLlm(llm_weights, llm_params);
-    if (!l.ok()) {
-        std::cerr << l.status() << "\n";
-        return 2;
+    std::unique_ptr<hallmark::Llm> llm;
+    {
+        TimingScope load_tokenizer("Creating LLM");
+        auto l = hallmark::Llm::CreateLlm(llm_weights, llm_params);
+        if (!l.ok()) {
+            std::cerr << l.status() << "\n";
+            return 2;
+        }
+        llm = std::move(l.value());
     }
-    auto llm = std::move(l.value());
 
     if (!llm->Reset().ok()) {
         std::cerr << "Reset fails\n";
         return 3;
     }
-    if (!llm->InitAttentionMaskValues(llm_params.seq_size_T).ok()) {
-        std::cerr << "InitAttentionMaskValues fails\n";
-        return 4;
+    {
+        TimingScope load_tokenizer("Init attention mask");
+        if (!llm->InitAttentionMaskValues(llm_params.seq_size_T).ok()) {
+            std::cerr << "InitAttentionMaskValues fails\n";
+            return 4;
+        }
     }
 
-    if (!llm->InitInputTokens(prompt_tokens).ok()) {
-        std::cerr << "InitInputTokens fails\n";
-        return 1;
+    {
+        TimingScope load_tokenizer("Init input tokens", prompt_tokens.size());
+        if (!llm->InitInputTokens(prompt_tokens).ok()) {
+            std::cerr << "InitInputTokens fails\n";
+            return 1;
+        }
     }
 
     std::cout << prompt << "\n";
 
-    for (int token = prompt_tokens.size(); token < max_tokens; token++) {
+    {
+        TimingScope generate("\nGenerate tokens", max_tokens);
         std::vector<int> output_tokens;
-        if (!llm->GetNextToken(&output_tokens).ok()) {
-            std::cerr << "GetNextToken fails\n";
-            return 6;
-        }
-        if (output_tokens.empty()) {
-            std::cerr << "Empty result from GetNextToken.\n";
-        }
-        std::string decoded_tokens;
-        if (!tokenizer.Decode(output_tokens, &decoded_tokens).ok()) {
-            std::cerr << "Decode fails\n";
-            return 7;
-        }
-        if (decoded_tokens.empty()) {
-            std::cout << "_";
+        for (int token = prompt_tokens.size(); token < max_tokens - 2; token += output_tokens.size()) {
+            output_tokens.clear();
+            if (!llm->GetNextToken(&output_tokens).ok()) {
+                std::cerr << "GetNextToken fails\n";
+                return 6;
+            }
+            if (output_tokens.empty()) {
+                std::cerr << "Empty result from GetNextToken.\n";
+            } else if (output_tokens.size() > 1) {
+                std::cerr << "More than one token returned from GetNextToken token " << token << ".\n";
+            }
+            std::string decoded_tokens;
+            if (!tokenizer.Decode(output_tokens, &decoded_tokens).ok()) {
+                std::cerr << "Decode fails\n";
+                return 7;
+            }
+            if (decoded_tokens.empty()) {
+                std::cout << "_";
+            }
+            std::cout << decoded_tokens;
+            std::cout.flush();
         }
-        std::cout << decoded_tokens;
-        std::cout.flush();
     }
 
     return 0;
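
For readers who want to try the RAII timing pattern this commit introduces outside the runner, here is a minimal, self-contained sketch. It is illustrative only: the real code gates output on absl::GetFlag(FLAGS_show_timing), which is replaced below by a plain show_timing bool, and the sleep loop in the demo main() merely stands in for the token-generation loop.

#include <chrono>
#include <iostream>
#include <string>
#include <thread>

// Illustrative stand-in for the runner's --show_timing flag; the real code
// reads it via absl::GetFlag(FLAGS_show_timing).
static bool show_timing = true;

// Prefer high_resolution_clock when it is steady, otherwise fall back to
// steady_clock (the same trait-selection trick used in the patch).
template<bool HighResIsSteady = std::chrono::high_resolution_clock::is_steady>
struct SteadyClock {
    using type = std::chrono::high_resolution_clock;
};
template<>
struct SteadyClock<false> {
    using type = std::chrono::steady_clock;
};

// RAII timer: records the start time on construction and prints the elapsed
// wall-clock time (plus a per-iteration time when iterations != 1) on scope exit.
struct TimingScope {
    TimingScope(const char *name, int iterations = 1)
        : name(name), iterations(iterations), start(SteadyClock<>::type::now()) {
    }
    ~TimingScope() {
        if (show_timing) {
            auto end = SteadyClock<>::type::now();
            double secs = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
            std::cerr << name << ": took " << secs << "s";
            if (iterations != 1) {
                std::cerr << " " << secs / iterations << "s per iteration.\n";
            } else {
                std::cerr << "\n";
            }
        }
    }
    std::string name;
    int iterations;
    SteadyClock<>::type::time_point start;
};

int main() {
    const int iterations = 8;
    TimingScope timer("Generate tokens (demo)", iterations);
    for (int i = 0; i < iterations; i++) {
        // Stand-in for llm->GetNextToken(...).
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    return 0;
}

When show_timing is true, the destructor prints a line such as "Generate tokens (demo): took 0.083s 0.0104s per iteration." to stderr as the scope exits, which is the same report format the runner now emits for loading the tokenizer, params, and weights, creating the LLM, and generating tokens.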
