Skip to content

Commit d7651f4

Browse files
Backport #7315 and #7380 to release/15.x branch (#7383)
* Make Callable::call_argv_fast public (#7315) * Make Callable::call_argv_fast public * Add rough specification of the calling convention * Fix a typo * Add Callable default ctor + `defined()` method (#7380) * Add Callable default ctor + `defined()` method This allows it to behave like * Add user_assert + test --------- Co-authored-by: Tom Westerhout <[email protected]>
1 parent 4ce5009 commit d7651f4

File tree

3 files changed

+43
-3
lines changed

3 files changed

+43
-3
lines changed

src/Callable.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@ void destroy<CallableContents>(const CallableContents *p) {
4747
} // namespace Internal
4848

4949
Callable::Callable()
50-
: contents(new CallableContents) {
50+
: contents(nullptr) {
51+
}
52+
53+
bool Callable::defined() const {
54+
return contents.defined();
5155
}
5256

5357
Callable::Callable(const std::string &name,
@@ -136,6 +140,8 @@ Callable::FailureFn Callable::check_qcci(size_t argc, const QuickCallCheckInfo *
136140
}
137141

138142
Callable::FailureFn Callable::check_fcci(size_t argc, const FullCallCheckInfo *actual_fcci) const {
143+
user_assert(defined()) << "Cannot call() a default-constructed Callable.";
144+
139145
// Lazily create full_call_check_info upon the first call to make_std_function().
140146
if (contents->full_call_check_info.empty()) {
141147
contents->full_call_check_info.reserve(contents->jit_cache.arguments.size());
@@ -197,6 +203,8 @@ Callable::FailureFn Callable::check_fcci(size_t argc, const FullCallCheckInfo *a
197203
}
198204

199205
int Callable::call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_qcci) const {
206+
user_assert(defined()) << "Cannot call() a default-constructed Callable.";
207+
200208
// It's *essential* we call this for safety.
201209
const auto failure_fn = check_qcci(argc, actual_qcci);
202210
if (failure_fn) {

src/Callable.h

+25-2
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,13 @@ class Callable {
273273
}
274274
};
275275

276-
Callable();
277276
Callable(const std::string &name,
278277
const JITHandlers &jit_handlers,
279278
const std::map<std::string, JITExtern> &jit_externs,
280279
Internal::JITCache &&jit_cache);
281280

282281
// Note that the first entry in argv must always be a JITUserContext*.
283282
int call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_cci) const;
284-
int call_argv_fast(size_t argc, const void *const *argv) const;
285283

286284
using FailureFn = std::function<int(JITUserContext *)>;
287285

@@ -304,6 +302,13 @@ class Callable {
304302
const std::vector<Argument> &arguments() const;
305303

306304
public:
305+
/** Construct a default Callable. This is not usable (trying to call it will fail).
306+
* The defined() method will return false. */
307+
Callable();
308+
309+
/** Return true if the Callable is well-defined and usable, false if it is a default-constructed empty Callable. */
310+
bool defined() const;
311+
307312
template<typename... Args>
308313
HALIDE_FUNCTION_ATTRS int
309314
operator()(JITUserContext *context, Args &&...args) const {
@@ -380,6 +385,24 @@ class Callable {
380385
};
381386
}
382387
}
388+
389+
/** Unsafe low-overhead way of invoking the Callable.
390+
*
391+
* This function relies on the same calling convention as the argv-based
392+
* functions generated for ahead-of-time compiled Halide pilelines.
393+
*
394+
* Very rough specifications of the calling convention (but check the source
395+
* code to be sure):
396+
*
397+
* * Arguments are passed in the same order as they appear in the C
398+
* function argument list.
399+
* * The first entry in argv must always be a JITUserContext*. Please,
400+
* note that this means that argv[0] actually contains JITUserContext**.
401+
* * All scalar arguments are passed by pointer, not by value, regardless of size.
402+
* * All buffer arguments (input or output) are passed as halide_buffer_t*.
403+
*
404+
*/
405+
int call_argv_fast(size_t argc, const void *const *argv) const;
383406
};
384407

385408
} // namespace Halide

test/correctness/callable.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ HalideExtern_2(float, my_extern_func, int, float);
4444
int main(int argc, char **argv) {
4545
const Target t = get_jit_target_from_environment();
4646

47+
{
48+
// Check that we can default-construct a Callable.
49+
Callable c;
50+
assert(!c.defined());
51+
52+
// This will assert-fail.
53+
// c(0,1,2);
54+
}
55+
4756
{
4857
Param<int32_t> p_int(42);
4958
Param<float> p_float(1.0f);

0 commit comments

Comments
 (0)