@@ -25,7 +25,9 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -53,6 +55,9 @@
 #include <cuda_runtime_api.h>
 #endif  // TRITON_ENABLE_GPU
 
+// Default forward method to call on PyTorch modules
+const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
+
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
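Note: this constant is the anchor for the whole change. The backend can now invoke any method a TorchScript module exports, not just forward(). Judging from the parameter parsing further down, a model would opt in through the standard parameters section of its config.pbtxt, e.g. parameters: { key: "MODULE_METHOD_NAME" value: { string_value: "predict" } }, where "predict" is an illustrative method name.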
@@ -103,6 +108,7 @@ class ModelState : public BackendModel {
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::string& ModuleMethodName() { return module_method_name_; }
 
  private:
  ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +151,10 @@ class ModelState : public BackendModel {
   // List of all the outputs specified in the output section of model
   // configuration.
   std::vector<std::string> output_names_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to DEFAULT_MODULE_METHOD_NAME.
+  std::string module_method_name_;
 };
 
 TRITONSERVER_Error*
@@ -180,7 +190,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
-      enable_nvfuser_pair_({false, false})
+      enable_nvfuser_pair_({false, false}),
+      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
 {
   output_names_.clear();
 
@@ -454,6 +465,30 @@ ModelState::ParseParameters()
            " for model instance '" + Name() + "'")
               .c_str());
     }
+
+    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
+    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
+    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
+    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("module_method_name is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      module_method_name_ = module_method_name;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("module_method_name is ") + module_method_name_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
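The absent-key handling above is a pattern worth naming: treat TRITONSERVER_ERROR_NOT_FOUND as "use the default" and propagate everything else. A minimal sketch of that pattern, assuming the GetParameterValue helper from backend_common.h; the function name ReadOptionalStringParam is hypothetical:

```cpp
#include <string>

#include "triton/backend/backend_common.h"

// Hypothetical helper: read an optional string parameter from the model
// config's 'parameters' object. Keeps 'fallback' when the key is absent
// and propagates any other error (e.g. a malformed value).
TRITONSERVER_Error*
ReadOptionalStringParam(
    triton::common::TritonJson::Value& params, const std::string& key,
    const std::string& fallback, std::string* value)
{
  *value = fallback;
  TRITONSERVER_Error* err =
      triton::backend::GetParameterValue(params, key, value);
  if (err == nullptr) {
    return nullptr;  // key present: 'value' holds the configured string
  }
  if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
    return err;  // real failure: hand it to the caller
  }
  TRITONSERVER_ErrorDelete(err);  // key absent: keep the fallback
  return nullptr;
}
```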
@@ -764,7 +799,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  const torch::jit::Method& method =
+      torch_model_->get_method(model_state_->ModuleMethodName());
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
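For readers unfamiliar with the libtorch calls used here: get_method() resolves any method the scripted module exposes, and its schema can be walked exactly as ValidateInputs() does. A standalone sketch outside the backend; the model path and method name are placeholders:

```cpp
#include <torch/script.h>

#include <iostream>

int main() {
  // Load a TorchScript model and look up an exported (non-forward) method.
  torch::jit::Module module = torch::jit::load("model.pt");
  const torch::jit::Method& method = module.get_method("predict");

  // Walk the method's schema; the first argument is the implicit 'self',
  // which callers typically skip when collecting input argument names.
  const c10::FunctionSchema& schema = method.function().getSchema();
  for (const c10::Argument& arg : schema.arguments()) {
    std::cout << arg.name() << "\n";
  }
  return 0;
}
```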
@@ -1312,30 +1348,36 @@ ModelInstanceState::Execute(
         torch::jit::overrideCanFuseOnCPU(false);
         torch::jit::overrideCanFuseOnGPU(false);
         torch::jit::setTensorExprFuserEnabled(false);
-	torch::jit::fuser::cuda::setEnabled(true);
+        torch::jit::fuser::cuda::setEnabled(true);
       } else {
         torch::jit::overrideCanFuseOnCPU(true);
         torch::jit::overrideCanFuseOnGPU(true);
         torch::jit::setTensorExprFuserEnabled(true);
-	torch::jit::fuser::cuda::setEnabled(false);
+        torch::jit::fuser::cuda::setEnabled(false);
       }
     }
 
     torch::NoGradGuard no_grad;
 
     // If input is a dictionary, prepare dictionary from 'input_tensors'.
+    std::string module_method_name = model_state_->ModuleMethodName();
+    std::vector<c10::IValue> inputs;
     if (is_dict_input_) {
-      torch::Dict<std::string, torch::Tensor> input_dict;
+      c10::Dict<std::string, at::Tensor> dict;
       for (auto& input_index : input_index_map_) {
         torch::jit::IValue ival = (*input_tensors)[input_index.second];
-        input_dict.insert(input_index.first, ival.toTensor());
+        dict.insert(input_index.first, ival.toTensor());
       }
-      std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-      model_outputs_ = torch_model_->forward(input_dict_ivalue);
+      inputs.push_back(dict);
     } else {
-      model_outputs_ = torch_model_->forward(*input_tensors);
+      for (auto& input_tensor : *input_tensors) {
+        inputs.push_back(input_tensor.toTensor());
+      }
     }
 
+    // Actually run the method on the model.
+    model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
+
     if (model_outputs_.isTuple()) {
       auto model_outputs_tuple = model_outputs_.toTuple();
       size_t op_index = 0;
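The net effect of this hunk: both the dict-input and positional-input paths now build a single std::vector<c10::IValue> and funnel it through one get_method() call, instead of two separate forward() call sites. The same pattern in isolation; a minimal sketch with assumed names:

```cpp
#include <torch/script.h>

#include <string>
#include <vector>

// Invoke an arbitrary exported method on a TorchScript module, mirroring
// what Execute() does above with the configured MODULE_METHOD_NAME.
c10::IValue
RunNamedMethod(
    torch::jit::Module& module, const std::string& method_name,
    std::vector<c10::IValue> inputs)
{
  // get_method() throws c10::Error if the method does not exist, so a
  // misconfigured method name fails loudly rather than silently.
  return module.get_method(method_name)(std::move(inputs));
}
```

A dict-shaped input goes through the same entry point: insert the tensors into a c10::Dict<std::string, at::Tensor> and push that dict as the single IValue, exactly as the dict branch above does.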
@@ -1761,9 +1803,9 @@ ModelInstanceState::SetInputTensors(
 
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1772,8 +1814,8 @@
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -1887,9 +1929,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                  : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1950,16 @@ ModelInstanceState::ReadOutputTensors(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-               "' is a scalar which is not supported.")
+                  "' is a scalar which is not supported.")
                   .c_str());
         }
 
         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-          name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...