diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 737c69072b64..8a4bef2da984 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -83,10 +83,19 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective // load forced_splits file if (!config->forcedsplits_filename.empty()) { std::ifstream forced_splits_file(config->forcedsplits_filename.c_str()); - std::stringstream buffer; - buffer << forced_splits_file.rdbuf(); - std::string err; - forced_splits_json_ = Json::parse(buffer.str(), &err); + if (!forced_splits_file.good()) { + Log::Warning("Forced splits file '%s' does not exist. Forced splits will be ignored.", + config->forcedsplits_filename.c_str()); + } else { + std::stringstream buffer; + buffer << forced_splits_file.rdbuf(); + std::string err; + forced_splits_json_ = Json::parse(buffer.str(), &err); + if (!err.empty()) { + Log::Fatal("Failed to parse forced splits file '%s': %s", + config->forcedsplits_filename.c_str(), err.c_str()); + } + } } objective_function_ = objective_function; @@ -823,13 +832,23 @@ void GBDT::ResetConfig(const Config* config) { if (config_.get() != nullptr && config_->forcedsplits_filename != new_config->forcedsplits_filename) { // load forced_splits file if (!new_config->forcedsplits_filename.empty()) { - std::ifstream forced_splits_file( - new_config->forcedsplits_filename.c_str()); - std::stringstream buffer; - buffer << forced_splits_file.rdbuf(); - std::string err; - forced_splits_json_ = Json::parse(buffer.str(), &err); - tree_learner_->SetForcedSplit(&forced_splits_json_); + std::ifstream forced_splits_file(new_config->forcedsplits_filename.c_str()); + if (!forced_splits_file.good()) { + Log::Warning("Forced splits file '%s' does not exist. Forced splits will be ignored.", + new_config->forcedsplits_filename.c_str()); + forced_splits_json_ = Json(); + tree_learner_->SetForcedSplit(nullptr); + } else { + std::stringstream buffer; + buffer << forced_splits_file.rdbuf(); + std::string err; + forced_splits_json_ = Json::parse(buffer.str(), &err); + if (!err.empty()) { + Log::Fatal("Failed to parse forced splits file '%s': %s", + new_config->forcedsplits_filename.c_str(), err.c_str()); + } + tree_learner_->SetForcedSplit(&forced_splits_json_); + } } else { forced_splits_json_ = Json(); tree_learner_->SetForcedSplit(nullptr);