|
61 | 61 | // https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133 |
62 | 62 | #include <ATen/Parallel.h> |
63 | 63 |
|
64 | | -// Default forward method to call on PyTorch modules |
65 | | -const std::string DEFAULT_MODULE_METHOD_NAME = "forward"; |
66 | 64 |
|
67 | 65 | // |
68 | 66 | // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API. |
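
For reference, ATen/Parallel.h (linked above) exposes the intra-op and inter-op thread-pool controls this backend configures. A minimal standalone sketch, assuming only that LibTorch is on the include and link path:

    #include <ATen/Parallel.h>
    #include <iostream>

    int main() {
      // Intra-op threads parallelize work inside one operator (e.g. a large GEMM);
      // inter-op threads run independent operators from the graph concurrently.
      at::set_num_threads(4);          // intra-op pool size
      at::set_num_interop_threads(2);  // inter-op pool; must be set before first use
      std::cout << at::get_num_threads() << " intra-op threads, "
                << at::get_num_interop_threads() << " inter-op threads\n";
      return 0;
    }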
@@ -113,7 +111,6 @@ class ModelState : public BackendModel { |
113 | 111 | { |
114 | 112 | return model_outputs_; |
115 | 113 | } |
116 | | - const std::string& ModuleMethodName() { return module_method_name_; } |
117 | 114 |
|
118 | 115 | private: |
119 | 116 | ModelState(TRITONBACKEND_Model* triton_model); |
@@ -156,10 +153,6 @@ class ModelState : public BackendModel { |
156 | 153 | // is specified both in the output section and state section, it indicates |
157 | 154 | // that the backend must return the output state to the client too. |
158 | 155 | std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_; |
159 | | - |
160 | | - // Method to call on PyTorch Module. |
161 | | - // Defaults to DEFAULT_MODULE_METHOD_NAME. |
162 | | - std::string module_method_name_; |
163 | 156 | }; |
164 | 157 |
|
165 | 158 | TRITONSERVER_Error* |
@@ -237,8 +230,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) |
237 | 230 | enable_inference_mode_(true), enable_cache_cleaning_(false), |
238 | 231 | enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}), |
239 | 232 | enable_jit_profiling_pair_({false, true}), |
240 | | - enable_jit_executor_pair_({false, true}), |
241 | | - module_method_name_(DEFAULT_MODULE_METHOD_NAME) |
| 233 | + enable_jit_executor_pair_({false, true}) |
242 | 234 | { |
243 | 235 | } |
244 | 236 |
|
@@ -527,30 +519,6 @@ ModelState::ParseParameters() |
527 | 519 | .c_str()); |
528 | 520 | } |
529 | 521 | } |
530 | | - |
531 | | - // If 'MODULE_METHOD_NAME' is not present in 'parameters' then |
532 | | - // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward'). |
533 | | - std::string module_method_name = DEFAULT_MODULE_METHOD_NAME; |
534 | | - err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name); |
535 | | - if (err != nullptr) { |
536 | | - if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { |
537 | | - return err; |
538 | | - } else { |
539 | | - LOG_MESSAGE( |
540 | | - TRITONSERVER_LOG_INFO, |
541 | | - (std::string("module_method_name is not specified") + |
542 | | - " for model instance '" + Name() + "'") |
543 | | - .c_str()); |
544 | | - TRITONSERVER_ErrorDelete(err); |
545 | | - } |
546 | | - } else { |
547 | | - module_method_name_ = module_method_name; |
548 | | - LOG_MESSAGE( |
549 | | - TRITONSERVER_LOG_INFO, |
550 | | - (std::string("module_method_name is ") + module_method_name_ + |
551 | | - " for model instance '" + Name() + "'") |
552 | | - .c_str()); |
553 | | - } |
554 | 522 | } |
555 | 523 |
|
556 | 524 | return nullptr; |
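
The block removed above followed the usual Triton pattern for optional config parameters: treat TRITONSERVER_ERROR_NOT_FOUND as "fall back to the default" and propagate any other error. A condensed sketch of that pattern, assuming Triton's backend-common GetParameterValue helper is in scope and using MY_PARAM as a placeholder name:

    std::string value = "default";  // keep the default unless the config overrides it
    TRITONSERVER_Error* err = GetParameterValue(params, "MY_PARAM", &value);
    if (err != nullptr) {
      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
        return err;  // real failure (e.g. malformed parameter), so surface it
      }
      // Parameter absent: not an error, so discard it and keep the default.
      TRITONSERVER_ErrorDelete(err);
    }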
@@ -972,20 +940,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt) |
972 | 940 | // configuration specifies only those. |
973 | 941 | std::vector<std::string> allowed_inputs; |
974 | 942 |
|
975 | | - // First check if method exists in the model and throw an error if absent |
976 | | - const auto methodNameToExecute = model_state_->ModuleMethodName(); |
977 | | - const auto optionalMethodHandle = |
978 | | - torch_model_->find_method(methodNameToExecute); |
979 | | - if (!optionalMethodHandle.has_value()) { |
980 | | - return TRITONSERVER_ErrorNew( |
981 | | - TRITONSERVER_ERROR_INVALID_ARG, |
982 | | - (std::string("unable to find method '") + methodNameToExecute + |
983 | | - "' in model '" + model_path_ + "'") |
984 | | - .c_str()); |
985 | | - } |
986 | | - |
987 | | - // Get the method schema and validate the inputs |
988 | | - const torch::jit::Method& method = optionalMethodHandle.value(); |
| 943 | + const torch::jit::Method& method = torch_model_->get_method("forward"); |
989 | 944 | const auto& schema = method.function().getSchema(); |
990 | 945 | const std::vector<c10::Argument>& arguments = schema.arguments(); |
991 | 946 |
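
Note the behavioral difference here: get_method("forward") throws if the method is absent, while the removed find_method path returned a c10::optional and let the backend report a descriptive TRITONSERVER error instead. A minimal sketch of inspecting a TorchScript method's schema with LibTorch (the model path is a placeholder):

    #include <torch/script.h>
    #include <iostream>

    int main() {
      torch::jit::Module module = torch::jit::load("model.pt");  // placeholder path
      // find_method returns c10::optional<Method>; get_method would throw on a miss.
      if (auto method = module.find_method("forward")) {
        const auto& schema = method->function().getSchema();
        // arguments()[0] is the implicit 'self'; real model inputs start at index 1.
        for (const auto& arg : schema.arguments()) {
          std::cout << arg.name() << "\n";
        }
      }
      return 0;
    }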
|
@@ -1628,24 +1583,18 @@ ModelInstanceState::Execute( |
1628 | 1583 | torch::NoGradGuard no_grad; |
1629 | 1584 |
|
1630 | 1585 | // If input is a dictionary, prepare dictionary from 'input_tensors'. |
1631 | | - std::string module_method_name = model_state_->ModuleMethodName(); |
1632 | | - std::vector<c10::IValue> inputs; |
1633 | 1586 | if (is_dict_input_) { |
1634 | | - c10::Dict<std::string, at::Tensor> dict; |
| 1587 | + torch::Dict<std::string, torch::Tensor> input_dict; |
1635 | 1588 | for (auto& input_index : input_index_map_) { |
1636 | 1589 | torch::jit::IValue ival = (*input_tensors)[input_index.second]; |
1637 | | - dict.insert(input_index.first, ival.toTensor()); |
| 1590 | + input_dict.insert(input_index.first, ival.toTensor()); |
1638 | 1591 | } |
1639 | | - inputs.push_back(dict); |
| 1592 | + std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict}; |
| 1593 | + model_outputs_ = torch_model_->forward(input_dict_ivalue); |
1640 | 1594 | } else { |
1641 | | - for (auto& input_tensor : *input_tensors) { |
1642 | | - inputs.push_back(input_tensor.toTensor()); |
1643 | | - } |
| 1595 | + model_outputs_ = torch_model_->forward(*input_tensors); |
1644 | 1596 | } |
1645 | 1597 |
|
1646 | | - // Actually run the method on the model. |
1647 | | - model_outputs_ = torch_model_->get_method(module_method_name)(inputs); |
1648 | | - |
1649 | 1598 | if (model_outputs_.isTuple()) { |
1650 | 1599 | auto model_outputs_tuple = model_outputs_.toTuple(); |
1651 | 1600 | size_t op_index = 0; |
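
With the custom-method indirection gone, Execute calls Module::forward directly for both input styles. A condensed standalone sketch of the two call shapes, with placeholder tensor and input names:

    #include <torch/script.h>

    // 'module' is assumed to be a loaded torch::jit::Module.
    torch::jit::IValue run(torch::jit::Module& module, bool dict_input) {
      torch::Tensor t = torch::ones({2, 2});
      if (dict_input) {
        // Dictionary-style model: forward(Dict[str, Tensor])
        c10::Dict<std::string, torch::Tensor> inputs;
        inputs.insert("INPUT__0", t);  // placeholder input name
        return module.forward({inputs});
      }
      // Positional style: forward(Tensor, ...)
      return module.forward({t});
    }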
|