Skip to content

Commit 7af7995

Browse files
committed
Implement load cancellation flag.
1 parent 30115cf commit 7af7995

22 files changed

+230
-13
lines changed

Diff for: include/onnxruntime/core/session/onnxruntime_c_api.h

+12
Original file line numberDiff line numberDiff line change
@@ -4898,6 +4898,18 @@ struct OrtApi {
48984898
_In_ const int64_t* shape, size_t shape_len,
48994899
ONNXTensorElementDataType type,
49004900
_Outptr_ OrtValue** out);
4901+
4902+
/** \brief Sets or clears the load cancellation flag on the session options
4903+
*
4904+
* \param[in] options options instance that was passed to the session at creation time.
4905+
* \param[in] is_cancel setting this to true after model loading process was initiated will
4906+
* cancel the loading process within some reasonable time frame.
4907+
*
4908+
* \snippet{doc} snippets.dox OrtStatus Return Value
4909+
*
4910+
*/
4911+
ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
4912+
_In_ bool is_cancel);
49014913
};
49024914

49034915
/*

Diff for: include/onnxruntime/core/session/onnxruntime_cxx_api.h

+2
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
928928

929929
SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
930930

931+
SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag
932+
931933
SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
932934
SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
933935

Diff for: include/onnxruntime/core/session/onnxruntime_cxx_inline.h

+6
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,12 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionM
747747
return *this;
748748
}
749749

750+
/// Sets or clears the load cancellation flag on the wrapped OrtSessionOptions.
/// Forwards to OrtApi::SessionOptionsSetLoadCancellationFlag and throws (via ThrowOnError)
/// if the C API reports a failure status.
/// \param value true to request cancellation of model loading, false to clear the request.
/// \return *this, so calls can be chained.
/// Note: p_ is accessed as this->p_ because it is inherited from a base class that depends
/// on T; an unqualified p_ is not found by two-phase name lookup on conforming compilers.
template <typename T>
751+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLoadCancellationFlag(bool value) {
752+
ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value));
753+
return *this;
754+
}
755+
750756
template <typename T>
751757
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
752758
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));

Diff for: onnxruntime/core/framework/graph_partitioner.cc

+14
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
413413
return Status::OK();
414414
}
415415

416+
const volatile auto& load_cancellaton_flag = graph_optimizer_registry.GetLoadCancellationFlagRef();
417+
416418
// recurse into nested graphs first to partition bottom up.
417419
for (auto& node : graph.Nodes()) {
418420
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
@@ -506,6 +508,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
506508
capabilities_to_complete_fuse.push_back(std::move(capability));
507509
}
508510
}
511+
ORT_RETURN_IF(load_cancellaton_flag, "Graph partitioning is canceled due to user request.");
509512
}
510513

511514
// NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode
@@ -561,6 +564,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
561564
graph.FinalizeFuseSubGraph(indexed_sub_graph, *node);
562565
}
563566
}
567+
ORT_RETURN_IF(load_cancellaton_flag, "Graph partitioning is canceled due to user request.");
564568
}
565569

566570
if (!nodes_to_complete_fuse.empty()) {
@@ -641,6 +645,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
641645
return Status::OK();
642646
}
643647

648+
const volatile bool& load_cancellation_flag = graph_optimizer_registry.GetLoadCancellationFlagRef();
649+
644650
for (auto& node : graph.Nodes()) {
645651
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
646652
Graph* subgraph = entry.second;
@@ -689,6 +695,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
689695
}
690696
}
691697
}
698+
ORT_RETURN_IF(load_cancellation_flag, "AOT inlining is canceled due to user request.");
692699
}
693700

694701
// TODO: Insert version check. We need to collect all the versions
@@ -701,6 +708,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
701708
if (claimed_by_ep.count(node_index) == 0) {
702709
ORT_RETURN_IF_ERROR(graph.InlineFunction(*node));
703710
++inlined_count;
711+
ORT_RETURN_IF(load_cancellation_flag, "AOT inlining is canceled due to user request.");
704712
} else {
705713
// OpType is the same as function name.
706714
auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType());
@@ -1032,6 +1040,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
10321040
return Status::OK();
10331041
}
10341042

1043+
ORT_RETURN_IF(IsLoadCancellationFlagSet(), "AOT inlining is canceled due to user request.");
1044+
10351045
auto& graph = model.MainGraph();
10361046
InlinedHashSet<std::string> not_inlined;
10371047
do {
@@ -1048,6 +1058,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
10481058
break;
10491059
}
10501060

1061+
ORT_RETURN_IF(IsLoadCancellationFlagSet(), "AOT inlining is canceled due to user request.");
10511062
ORT_RETURN_IF_ERROR(graph.Resolve());
10521063
} while (true);
10531064

@@ -1082,6 +1093,9 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
10821093
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified.");
10831094
}
10841095

1096+
const volatile auto& load_cancellaton_flag = graph_optimizer_registry_->GetLoadCancellationFlagRef();
1097+
ORT_RETURN_IF(load_cancellaton_flag, "Graph partitioning is canceled due user request.");
1098+
10851099
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
10861100
// fused_kernel_registry is preparing the kernels created on the fly for fused sub graph.
10871101
// It is only visible for current session.

Diff for: onnxruntime/core/framework/graph_partitioner.h

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class GraphPartitioner {
4141
Mode mode = Mode::kNormal,
4242
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;
4343

44+
// Returns the current value of the load cancellation flag, read from the session options
// that are reachable through graph_optimizer_registry_.
// graph_optimizer_registry_ is dereferenced unconditionally, so it must be non-null here.
bool IsLoadCancellationFlagSet() const noexcept {
45+
return graph_optimizer_registry_->GetSessionOptions().IsLoadCancellationFlagSet();
46+
}
47+
4448
#ifndef ORT_MINIMAL_BUILD
4549
/// <summary>
4650
// Ahead of Time Function inlining. The main purpose of the function is to inline as many

Diff for: onnxruntime/core/framework/session_options.h

+11
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,17 @@ struct SessionOptions {
181181
void AddCustomOpLibraryHandle(PathString library_name, void* library_handle);
182182
#endif
183183

184+
// The load cancellation flag is held through a shared_ptr because session_options are copied:
// every copy then refers to the same flag, so setting it through one copy is seen by all.
185+
std::shared_ptr<bool> load_cancellation_flag = std::make_shared<bool>(false);
186+
187+
// Returns a snapshot of the current value of the bool owned by load_cancellation_flag.
bool IsLoadCancellationFlagSet() const noexcept {
188+
return *load_cancellation_flag;
189+
}
190+
191+
// Returns a const reference to the bool owned by load_cancellation_flag (not a copy of its
// value), so callers can keep the reference and re-read the flag later.
// The reference is only valid while that shared bool is still alive.
// NOTE(review): the flag is a plain bool; if it is written by one thread while another reads
// it, there is no synchronization here — confirm whether std::atomic<bool> is needed.
const bool& GetLoadCancellationFlagRef() const noexcept {
192+
return *load_cancellation_flag;
193+
}
194+
184195
// User specified logging func and param
185196
OrtLoggingFunction user_logging_function = nullptr;
186197
void* user_logging_param = nullptr;

Diff for: onnxruntime/core/framework/session_state.cc

+5
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,8 @@ Status SessionState::PrepackConstantInitializedTensors(
585585
constant_initialized_tensors.erase(ort_value_idx);
586586
}
587587
}
588+
ORT_RETURN_IF(sess_options_.IsLoadCancellationFlagSet(),
589+
"Weight pre-packing was canceled due to user request.");
588590
}
589591
// stop searching in 2 cases:
590592
// 1. value is not from OuterScope
@@ -1528,6 +1530,9 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
15281530
p_seq_exec_plan_);
15291531
ORT_RETURN_IF_ERROR(status);
15301532

1533+
ORT_RETURN_IF(session_options.IsLoadCancellationFlagSet(),
1534+
"SessionState finalize is canceled due to user request");
1535+
15311536
// Record the allocation plan
15321537

15331538
// Uncomment the below to dump the allocation plan to std::cout

Diff for: onnxruntime/core/framework/session_state_utils.cc

+4
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ common::Status SaveInitializedTensors(
378378

379379
// 3. create weight tensors based on weights buffer
380380
for (const auto& entry : id_to_initialized_tensor) {
381+
// We check for cancellation for every initializer since mapping from disk can be costly
382+
ORT_RETURN_IF(session_options.IsLoadCancellationFlagSet(),
383+
"Saving session state weights is canceled due to user request.");
384+
381385
int ort_value_index = entry.first;
382386
const std::string& name = entry.second->name();
383387

Diff for: onnxruntime/core/graph/graph.cc

+11
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,10 @@ Graph::Graph(const Model& owning_model,
12681268
#endif
12691269
}
12701270

1271+
if (owning_model_.IsCancellationFalgSet()) {
1272+
ORT_THROW("Graph loading canceled due to user request.");
1273+
}
1274+
12711275
// Remove constant nodes as they're replaced with initializers above.
12721276
const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes{graph_proto_->mutable_node()};
12731277
graph_mutable_nodes->erase(
@@ -1300,6 +1304,9 @@ Graph::Graph(const Model& owning_model,
13001304
delete graph_proto_->mutable_sparse_initializer()->ReleaseCleared();
13011305
}
13021306
#endif
1307+
if (owning_model_.IsCancellationFalgSet()) {
1308+
ORT_THROW("Graph loading canceled due to user request.");
1309+
}
13031310
}
13041311
#endif
13051312

@@ -1365,6 +1372,10 @@ Graph::Graph(const Model& owning_model,
13651372
}
13661373
}
13671374

1375+
if (owning_model_.IsCancellationFalgSet()) {
1376+
ORT_THROW("Graph loading canceled due to user request.");
1377+
}
1378+
13681379
for (auto& graph_output : graph_proto_->output()) {
13691380
if (utils::HasName(graph_output) && utils::HasType(graph_output)) {
13701381
auto& name = graph_output.name();

Diff for: onnxruntime/core/graph/model.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ Model::Model(const std::string& graph_name,
8282
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_local_functions,
8383
const logging::Logger& logger,
8484
const ModelOptions& options)
85-
: model_path_(model_path) {
85+
: model_path_(model_path), load_cancellation_flag_(options.load_cancellation_flag) {
8686
model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
8787
model_proto_.mutable_graph()->set_name(graph_name);
8888
model_metadata_ = model_metadata;
@@ -161,7 +161,7 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path,
161161
Model::Model(ModelProto&& model_proto, const PathString& model_path,
162162
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
163163
const logging::Logger& logger, const ModelOptions& options)
164-
: model_path_(model_path) {
164+
: model_path_(model_path), load_cancellation_flag_(options.load_cancellation_flag) {
165165
if (!utils::HasGraph(model_proto)) {
166166
ORT_THROW("ModelProto does not have a graph.");
167167
}

Diff for: onnxruntime/core/graph/model.h

+13-2
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@ struct ModelOptions {
3838
// be returned.
3939
bool strict_shape_type_inference;
4040

41-
ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference)
41+
const bool* load_cancellation_flag = nullptr;
42+
43+
ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference,
44+
const bool* load_cancellation_flag = nullptr)
4245
: allow_released_opsets_only(allow_released_opsets_only),
43-
strict_shape_type_inference(strict_shape_type_inference) {}
46+
strict_shape_type_inference(strict_shape_type_inference),
47+
load_cancellation_flag(load_cancellation_flag) {}
4448

4549
ModelOptions() : ModelOptions(true, false) {}
4650
};
@@ -143,6 +147,11 @@ class Model {
143147

144148
const NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>>& GetModelLocalFunctionTemplates() const;
145149

150+
// Check for load cancellation.
// Returns true only when a flag pointer was supplied (load_cancellation_flag_ is non-null)
// and the bool it points to currently reads true; returns false otherwise.
// NOTE(review): "Falg" is a misspelling of "Flag"; the name is left unchanged because
// callers use this spelling.
151+
bool IsCancellationFalgSet() const noexcept {
152+
return load_cancellation_flag_ && *load_cancellation_flag_;
153+
}
154+
146155
#else
147156
// Get model's IR version.
148157
// Return <kNoVersion> if not specified.
@@ -343,5 +352,7 @@ class Model {
343352

344353
// Main graph of the model.
345354
std::unique_ptr<Graph> graph_;
355+
356+
const bool* load_cancellation_flag_ = nullptr;
346357
};
347358
} // namespace onnxruntime

Diff for: onnxruntime/core/optimizer/graph_optimizer_registry.h

+10
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ class GraphOptimizerRegistry {
5555
*/
5656
const onnxruntime::SessionOptions& GetSessionOptions() const { return *session_options_; }
5757

58+
/**
59+
* Get Load Cancellation Flag.
60+
* This flag is used to cancel the loading of the model if it takes too long.
61+
* It is set to true when the user cancels the loading process.
62+
* The flag is shared across all threads, so it is read-only through this accessor.
63+
*
* @return const reference to the flag as returned by the session options
*         (session_options_->GetLoadCancellationFlagRef()), not a copy of its value.
* NOTE(review): the referenced flag is a plain bool; confirm that unsynchronized
* cross-thread reads of it are intended.
*/
64+
const bool& GetLoadCancellationFlagRef() const {
65+
return session_options_->GetLoadCancellationFlagRef();
66+
}
67+
5868
/**
5969
* Get Logger.
6070
*/

Diff for: onnxruntime/core/optimizer/graph_transformer_mgr.cc

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor
3535
bool modified = false;
3636
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger));
3737
graph_changed = graph_changed || modified;
38+
ORT_RETURN_IF(IsLoadCancellationFlagSet(),
39+
"Graph transformation canceled due to user request.");
3840
}
3941
if (!graph_changed) {
4042
break;

Diff for: onnxruntime/core/optimizer/graph_transformer_mgr.h

+11
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ class GraphTransformerManager {
2424
// Get the maximum number of graph transformation steps
2525
common::Status GetSteps(unsigned& steps) const;
2626

27+
// Registers the load cancellation flag to poll while applying transformers.
// Only the address of the referenced bool is stored (non-owning), so the flag must
// remain alive for as long as this manager may read it.
28+
void SetLoadCancellationFlagRef(const bool& load_cancellation_flag) {
29+
load_cancellation_flag_ = &load_cancellation_flag;
30+
}
31+
32+
// Returns true when a cancellation flag has been registered and it currently reads true.
// Returns false when no flag was registered (load_cancellation_flag_ is null).
33+
bool IsLoadCancellationFlagSet() const noexcept {
34+
return load_cancellation_flag_ && *load_cancellation_flag_;
35+
}
36+
2737
// Register a transformer with a level.
2838
common::Status Register(std::unique_ptr<GraphTransformer> transformer, TransformerLevel level);
2939

@@ -38,5 +48,6 @@ class GraphTransformerManager {
3848

3949
InlinedHashMap<TransformerLevel, InlinedVector<std::unique_ptr<GraphTransformer>>> level_to_transformer_map_;
4050
InlinedHashMap<std::string, GraphTransformer*> transformers_info_;
51+
const bool* load_cancellation_flag_ = nullptr;
4152
};
4253
} // namespace onnxruntime

Diff for: onnxruntime/core/session/abi_session_options.cc

+6
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,9 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions*
340340
return nullptr;
341341
API_IMPL_END
342342
}
343+
344+
// C API entry point: writes is_cancel into the load cancellation flag of the given
// session options. Returns nullptr (success) after the flag has been written.
// The body is wrapped in API_IMPL_BEGIN / API_IMPL_END like the other API
// implementations in this file, so the entry point follows the same error-handling
// convention as its siblings.
// options is dereferenced without a null check, matching the existing behavior.
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
345+
_In_ bool is_cancel) {
API_IMPL_BEGIN
346+
*options->value.load_cancellation_flag = is_cancel;
347+
return nullptr;
API_IMPL_END
348+
}

0 commit comments

Comments
 (0)