#include "llm.h"

#include <HalideRuntime.h>

#include <chrono>
#include <iomanip>
#include <iostream>
#include <type_traits>
5
8
@@ -20,6 +23,49 @@ ABSL_FLAG(int, max_tokens, 512,
20
23
" Maximum number of input and output tokens. This value needs to be "
21
24
" at least larger than the number of input tokens." );
22
25
26
// --show_timing: when set, each TimingScope below reports the wall-clock
// duration of its scope to stderr. Off by default so normal runs stay quiet.
ABSL_FLAG(bool, show_timing, false,
          "Show timing for operations.");
28
+
29
+ namespace {
30
+
31
+ // Prefer high_resolution_clock, but only if it's steady...
32
+ template <bool HighResIsSteady = std::chrono::high_resolution_clock::is_steady>
33
+ struct SteadyClock {
34
+ using type = std::chrono::high_resolution_clock;
35
+ };
36
+
37
+ // ...otherwise use steady_clock.
38
+ template <>
39
+ struct SteadyClock <false > {
40
+ using type = std::chrono::steady_clock;
41
+ };
42
+
43
+
44
+ struct TimingScope {
45
+ TimingScope (const char *name, int iterations = 1 ) : name(name), iterations(iterations) {
46
+ start = SteadyClock<>::type::now ();
47
+ }
48
+
49
+ ~TimingScope () {
50
+ if (absl::GetFlag (FLAGS_show_timing)) {
51
+ SteadyClock<>::type::time_point end = SteadyClock<>::type::now ();
52
+ double secs = std::chrono::duration_cast<std::chrono::duration<double >>(end - start).count ();
53
+ std::cerr << name << " : took " << secs << " s" ;
54
+ if (iterations != 1 ) {
55
+ std::cerr << " " << secs / iterations << " s per iteration.\n " ;
56
+ } else {
57
+ std::cerr << " \n " ;
58
+ }
59
+ }
60
+ }
61
+
62
+ std::string name;
63
+ int iterations;
64
+ SteadyClock<>::type::time_point start;
65
+ };
66
+
67
+ }
68
+
23
69
int main(int argc, char *argv[]) {
    absl::ParseCommandLine(argc, argv);

    // ... (unchanged lines elided from this diff view; presumably the flag
    // values — tokenizer_path, model_path, prompt, max_tokens — are read
    // here; confirm against the full file) ...

    sentencepiece::SentencePieceProcessor tokenizer;
    {
        // Time the tokenizer load; reported to stderr only when
        // --show_timing is set.
        TimingScope load_tokenizer("Loading tokenizer");
        auto result = tokenizer.Load(tokenizer_path);
        if (!result.ok()) {
            std::cerr << result.message();
            // (remainder of this error branch elided from this diff view)
@@ -49,66 +96,87 @@ int main(int argc, char *argv[]) {
49
96
auto result = tokenizer.Encode (bracketed_prompt, &prompt_tokens);
50
97
}
51
98
52
- std::cerr << " Loading LLM params.\n " ;
53
- auto p = hallmark::LoadLlmParams (model_path);
54
- if (!p.ok ()) {
55
- std::cerr << p.status () << " \n " ;
56
- return 1 ;
99
+ hallmark::LlmParams llm_params;
100
+ {
101
+ TimingScope load_tokenizer (" Loading LLM params" );
102
+ auto p = hallmark::LoadLlmParams (model_path);
103
+ if (!p.ok ()) {
104
+ std::cerr << p.status () << " \n " ;
105
+ return 1 ;
106
+ }
107
+ llm_params = std::move (p.value ());
57
108
}
58
- auto llm_params = std::move (p.value ());
59
109
llm_params.seq_size_T = max_tokens;
60
110
61
- std::cerr << " Loading LLM weights.\n " ;
62
- auto w = hallmark::LoadLlmWeights (model_path, llm_params);
63
- if (!w.ok ()) {
64
- std::cerr << w.status () << " \n " ;
65
- return 1 ;
111
+ hallmark::LlmWeights llm_weights;
112
+ {
113
+ TimingScope load_tokenizer (" Loading LLM params" );
114
+ auto w = hallmark::LoadLlmWeights (model_path, llm_params);
115
+ if (!w.ok ()) {
116
+ std::cerr << w.status () << " \n " ;
117
+ return 1 ;
118
+ }
119
+ llm_weights = std::move (w.value ());
66
120
}
67
- auto llm_weights = std::move (w.value ());
68
121
69
- std::cerr << " Creating LLM.\n " ;
70
- auto l = hallmark::Llm::CreateLlm (llm_weights, llm_params);
71
- if (!l.ok ()) {
72
- std::cerr << l.status () << " \n " ;
73
- return 2 ;
122
+ std::unique_ptr<hallmark::Llm> llm;
123
+ {
124
+ TimingScope load_tokenizer (" Creating LLM" );
125
+ auto l = hallmark::Llm::CreateLlm (llm_weights, llm_params);
126
+ if (!l.ok ()) {
127
+ std::cerr << l.status () << " \n " ;
128
+ return 2 ;
129
+ }
130
+ llm = std::move (l.value ());
74
131
}
75
- auto llm = std::move (l.value ());
76
132
77
133
if (!llm->Reset ().ok ()) {
78
134
std::cerr << " Reset fails\n " ;
79
135
return 3 ;
80
136
}
81
- if (!llm->InitAttentionMaskValues (llm_params.seq_size_T ).ok ()) {
82
- std::cerr << " InitAttentionMaskValues fails\n " ;
83
- return 4 ;
137
+ {
138
+ TimingScope load_tokenizer (" Init attention mask" );
139
+ if (!llm->InitAttentionMaskValues (llm_params.seq_size_T ).ok ()) {
140
+ std::cerr << " InitAttentionMaskValues fails\n " ;
141
+ return 4 ;
142
+ }
84
143
}
85
144
86
- if (!llm->InitInputTokens (prompt_tokens).ok ()) {
87
- std::cerr << " InitInputTokens fails\n " ;
88
- return 1 ;
145
+ {
146
+ TimingScope load_tokenizer (" Init input tokens" , prompt_tokens.size ());
147
+ if (!llm->InitInputTokens (prompt_tokens).ok ()) {
148
+ std::cerr << " InitInputTokens fails\n " ;
149
+ return 1 ;
150
+ }
89
151
}
90
152
91
153
std::cout << prompt << " \n " ;
92
154
93
- for (int token = prompt_tokens.size (); token < max_tokens; token++) {
155
+ {
156
+ TimingScope generate (" \n Generate tokens" , max_tokens);
94
157
std::vector<int > output_tokens;
95
- if (!llm->GetNextToken (&output_tokens).ok ()) {
96
- std::cerr << " GetNextToken fails\n " ;
97
- return 6 ;
98
- }
99
- if (output_tokens.empty ()) {
100
- std::cerr << " Empty result from GetNextToken.\n " ;
101
- }
102
- std::string decoded_tokens;
103
- if (!tokenizer.Decode (output_tokens, &decoded_tokens).ok ()) {
104
- std::cerr << " Decode fails\n " ;
105
- return 7 ;
106
- }
107
- if (decoded_tokens.empty ()) {
108
- std::cout << " _" ;
158
+ for (int token = prompt_tokens.size (); token < max_tokens - 2 ; token += output_tokens.size ()) {
159
+ output_tokens.clear ();
160
+ if (!llm->GetNextToken (&output_tokens).ok ()) {
161
+ std::cerr << " GetNextToken fails\n " ;
162
+ return 6 ;
163
+ }
164
+ if (output_tokens.empty ()) {
165
+ std::cerr << " Empty result from GetNextToken.\n " ;
166
+ } else if (output_tokens.size () > 1 ) {
167
+ std::cerr << " More than one token returned from GetNextToken token " << token << " .\n " ;
168
+ }
169
+ std::string decoded_tokens;
170
+ if (!tokenizer.Decode (output_tokens, &decoded_tokens).ok ()) {
171
+ std::cerr << " Decode fails\n " ;
172
+ return 7 ;
173
+ }
174
+ if (decoded_tokens.empty ()) {
175
+ std::cout << " _" ;
176
+ }
177
+ std::cout << decoded_tokens;
178
+ std::cout.flush ();
109
179
}
110
- std::cout << decoded_tokens;
111
- std::cout.flush ();
112
180
}
113
181
114
182
return 0 ;
0 commit comments