# Adding new models

Currently, you must clone the HELM repository and modify the HELM code locally to add a new model.

  1. Pick a new organization name that does not conflict with the organization of any existing model.
  2. Define a new `Model` describing your model and add it to `ALL_MODELS` in `helm.proxy.models`. Set `Model.group` and `Model.creator_organization` to your organization name, and set `Model.name` to `"your_organization_name/your_model_name"`.
  3. Implement a new `Client`. The `Client` is responsible for the core logic of your model and for responding to requests. In particular, it needs to implement three methods: `Client.make_request()`, `Client.tokenize()` and `Client.decode()`. Note that despite the name, a `Client` may also run model inference locally. Refer to `SimpleClient` for an example implementation.
  4. Modify both `AutoClient.get_client()` and `AutoClient.get_tokenizer_client()` to return your `Client` implementation for your organization name.
  5. Implement a `WindowService`. You will usually want to make your implementation a subclass of `LocalWindowService`, which reuses the `Client.tokenize()` and `Client.decode()` methods that you implemented earlier. The subclass should implement `LocalWindowService.tokenizer_name()` to return `"your_organization_name/your_model_name"`, and must also implement `LocalWindowService.max_sequence_length()`, `LocalWindowService.end_of_text_token()` and `LocalWindowService.prefix_token()`. Refer to `GPT2WindowService` for an example implementation.
  6. Modify `WindowServiceFactory.get_window_service()` to construct and return your `WindowService` implementation for your organization name.
  7. Add an entry to `schema.yaml` describing your model.
  8. Update `contamination.yaml` to indicate whether the training set of your model has been contaminated by data from any scenarios.
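Steps 7 and 8 are plain YAML edits. The fragment below is purely illustrative: the field names and values are assumptions, so copy the structure of the existing entries in `schema.yaml` rather than this sketch:

```yaml
# Illustrative schema.yaml entry; mirror the fields of existing entries.
- name: your_organization_name/your_model_name
  display_name: Your Model
  description: A short description of your model.
  creator_organization: your_organization_name
```

`contamination.yaml` is edited the same way: add an entry for your model following the format of the existing entries, listing any scenarios whose data appeared in your training set.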
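Step 2 amounts to appending one entry to a module-level list. A minimal sketch, where the `Model` dataclass below is a hypothetical stand-in for `helm.proxy.models.Model` (the real class defines additional fields):

```python
from dataclasses import dataclass
from typing import List

# Hypothetical, simplified stand-in for helm.proxy.models.Model;
# the real class defines more fields (e.g. tags, description).
@dataclass(frozen=True)
class Model:
    group: str
    creator_organization: str
    name: str

ALL_MODELS: List[Model] = []

# Register the new model under the chosen organization name (step 2).
ALL_MODELS.append(
    Model(
        group="your_organization_name",
        creator_organization="your_organization_name",
        name="your_organization_name/your_model_name",
    )
)
```

The `"organization/model"` naming convention matters because the organization prefix is what the dispatch code in later steps keys on.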
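The shape of the `Client` interface from step 3 can be sketched as follows. This is an illustrative simplification: the real HELM methods take and return structured request/result objects rather than plain strings, and the whitespace tokenizer here is a toy stand-in for your model's actual tokenizer:

```python
from abc import ABC, abstractmethod
from typing import List

# Hypothetical simplification of HELM's Client interface.
class Client(ABC):
    @abstractmethod
    def make_request(self, prompt: str) -> str:
        """Run inference (remotely or locally) and return the completion."""

    @abstractmethod
    def tokenize(self, text: str) -> List[str]:
        """Split text into tokens."""

    @abstractmethod
    def decode(self, tokens: List[str]) -> str:
        """Reassemble tokens into text."""


class YourClient(Client):
    """Toy client illustrating the three required methods."""

    def make_request(self, prompt: str) -> str:
        # A real client would call a model API or run local inference here;
        # "despite the name Client" nothing forces the work to be remote.
        return "echo: " + prompt

    def tokenize(self, text: str) -> List[str]:
        return text.split()

    def decode(self, tokens: List[str]) -> str:
        return " ".join(tokens)
```

`tokenize()` and `decode()` matter beyond serving requests: the `WindowService` from step 5 reuses them to fit prompts into the model's context window.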
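Steps 4 and 6 are both organization-name dispatch. A hedged sketch of the pattern, using a hypothetical `get_client()` (the real `AutoClient.get_client()` and `WindowServiceFactory.get_window_service()` also handle credentials and cache the constructed objects):

```python
class YourClient:
    """Placeholder for the Client implemented in step 3."""


def get_client(model_name: str) -> object:
    # Illustrative dispatch: the organization is the part of the model
    # name before the slash, e.g. "your_organization_name/your_model_name".
    organization = model_name.split("/")[0]
    if organization == "your_organization_name":
        return YourClient()
    raise ValueError(f"Unknown organization: {organization}")
```

`AutoClient.get_tokenizer_client()` and `WindowServiceFactory.get_window_service()` follow the same branch-on-organization shape, returning your tokenizer client and `WindowService` respectively.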
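The `WindowService` subclass from step 5 mostly supplies model-specific constants. A minimal sketch, where the empty `LocalWindowService` base, the property style, and every concrete value (context length, end-of-text token) are assumptions to replace with your model's real parameters:

```python
# Hypothetical stand-in for helm.benchmark.window_services'
# LocalWindowService, which in HELM delegates tokenize()/decode()
# to the Client registered under tokenizer_name.
class LocalWindowService:
    pass


class YourWindowService(LocalWindowService):
    @property
    def tokenizer_name(self) -> str:
        # Must match the name used when registering the Model (step 2).
        return "your_organization_name/your_model_name"

    @property
    def max_sequence_length(self) -> int:
        return 2048  # assumed context length; use your model's real limit

    @property
    def end_of_text_token(self) -> str:
        return "<|endoftext|>"  # assumed; use your tokenizer's real token

    @property
    def prefix_token(self) -> str:
        # Often the same as end_of_text_token, but model-dependent.
        return self.end_of_text_token
```

Compare with `GPT2WindowService`, which the step cites as a reference, for the values a real model uses.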