2323import paddle
2424from paddle .nn import Layer
2525# TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later
26- from paddlenlp .utils .downloader import get_path_from_url
26+ from paddlenlp .utils .downloader import get_path_from_url , COMMUNITY_MODEL_PREFIX
2727from paddlenlp .utils .env import MODEL_HOME
2828from paddlenlp .utils .log import logger
2929
@@ -105,7 +105,7 @@ class is a pretrained model class adding layers on top of the base model,
105105 """
106106 model_config_file = "model_config.json"
107107 pretrained_init_configuration = {}
108- # TODO: more flexible resource handle, namedtuple with fileds as:
108+ # TODO: more flexible resource handle, namedtuple with fields as:
109109 # resource_name, saved_file, handle_name_for_load(None for used as __init__
110110 # arguments), handle_name_for_save
111111 resource_files_names = {"model_state" : "model_state.pdparams" }
@@ -115,7 +115,7 @@ class is a pretrained model class adding layers on top of the base model,
115115 def _wrap_init (self , original_init , * args , ** kwargs ):
116116 """
117117 It would be hooked after `__init__` to add a dict including arguments of
118- `__init__` as a attribute named `config` of the prtrained model instance.
118+ `__init__` as a attribute named `config` of the pretrained model instance.
119119 """
120120 init_dict = fn_args_to_dict (original_init , * ((self , ) + args ), ** kwargs )
121121 self .config = init_dict
@@ -135,6 +135,7 @@ def model_name_list(self):
135135 list: Contains all supported built-in pretrained model names of the
136136 current PretrainedModel class.
137137 """
138+ # TODO: return all supported model names
138139 return list (self .pretrained_init_configuration .keys ())
139140
140141 def get_input_embeddings (self ):
@@ -150,14 +151,18 @@ def get_output_embeddings(self):
150151 @classmethod
151152 def from_pretrained (cls , pretrained_model_name_or_path , * args , ** kwargs ):
152153 """
153- Creates an instance of `PretrainedModel` and load pretrained model weights
154- for it according to a specific model name (such as `bert-base-uncased`)
154+ Creates an instance of `PretrainedModel`. Model weights are loaded
155+ by specifying name of a built-in pretrained model, or a community contributed model,
155156 or a local file directory path.
156157
157158 Args:
158- pretrained_model_name_or_path (str): Name of pretrained model
159- for built-in pretrained models loading, such as `bert-base-uncased`.
160- Or a local file directory path for local trained models loading.
159+ pretrained_model_name_or_path (str): Name of pretrained model or dir path
160+ to load from. The string can be:
161+
162+ - Name of a built-in pretrained model
163+ - Name of a community-contributed pretrained model.
164+ - Local directory path which contains the model weights file ("model_state.pdparams")
165+ and model config file ("model_config.json").
161166 *args (tuple): Position arguments for model `__init__`. If provided,
162167 use these as position argument values for model initialization.
163168 **kwargs (dict): Keyword arguments for model `__init__`. If provided,
@@ -174,38 +179,47 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
174179
175180 from paddlenlp.transformers import BertForSequenceClassification
176181
182+ # Name of built-in pretrained model
177183 model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
184+
185+ # Name of community-contributed pretrained model
186+ model = BertForSequenceClassification.from_pretrained('yingyibiao/bert-base-uncased-sst-2-finetuned')
187+
188+ # Load from local directory path
189+ model = BertForSequenceClassification.from_pretrained('./my_bert/')
178190 """
179191 pretrained_models = list (cls .pretrained_init_configuration .keys ())
180192 resource_files = {}
181193 init_configuration = {}
182194
195+ # From built-in pretrained models
183196 if pretrained_model_name_or_path in pretrained_models :
184197 for file_id , map_list in cls .pretrained_resource_files_map .items ():
185198 resource_files [file_id ] = map_list [
186199 pretrained_model_name_or_path ]
187200 init_configuration = copy .deepcopy (
188201 cls .pretrained_init_configuration [
189202 pretrained_model_name_or_path ])
203+ # From local dir path
204+ elif os .path .isdir (pretrained_model_name_or_path ):
205+ for file_id , file_name in cls .resource_files_names .items ():
206+ full_file_name = os .path .join (pretrained_model_name_or_path ,
207+ file_name )
208+ resource_files [file_id ] = full_file_name
209+ resource_files ["model_config_file" ] = os .path .join (
210+ pretrained_model_name_or_path , cls .model_config_file )
190211 else :
191- if os .path .isdir (pretrained_model_name_or_path ):
192- for file_id , file_name in cls .resource_files_names .items ():
193- full_file_name = os .path .join (pretrained_model_name_or_path ,
194- file_name )
195- resource_files [file_id ] = full_file_name
196- resource_files ["model_config_file" ] = os .path .join (
197- pretrained_model_name_or_path , cls .model_config_file )
198- else :
199- raise ValueError (
200- "Calling {}.from_pretrained() with a model identifier or the "
201- "path to a directory instead. The supported model "
202- "identifiers are as follows: {}, but got: {}" .format (
203- cls .__name__ ,
204- cls .pretrained_init_configuration .keys (
205- ), pretrained_model_name_or_path ))
212+ # Assuming from community-contributed pretrained models
213+ for file_id , file_name in cls .resource_files_names .items ():
214+ full_file_name = os .path .join (COMMUNITY_MODEL_PREFIX ,
215+ pretrained_model_name_or_path ,
216+ file_name )
217+ resource_files [file_id ] = full_file_name
218+ resource_files ["model_config_file" ] = os .path .join (
219+ COMMUNITY_MODEL_PREFIX , pretrained_model_name_or_path ,
220+ cls .model_config_file )
206221
207222 default_root = os .path .join (MODEL_HOME , pretrained_model_name_or_path )
208-
209223 resolved_resource_files = {}
210224 for file_id , file_path in resource_files .items ():
211225 path = os .path .join (default_root , file_path .split ('/' )[- 1 ])
@@ -217,8 +231,18 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
217231 else :
218232 logger .info ("Downloading %s and saved to %s" %
219233 (file_path , default_root ))
220- resolved_resource_files [file_id ] = get_path_from_url (
221- file_path , default_root )
234+ try :
235+ resolved_resource_files [file_id ] = get_path_from_url (
236+ file_path , default_root )
237+ except RuntimeError as err :
238+ logger .error (err )
239+ raise RuntimeError (
240+ f"Can't load weights for '{ pretrained_model_name_or_path } '.\n "
241+ f"Please make sure that '{ pretrained_model_name_or_path } ' is:\n "
242+ "- a correct model-identifier of built-in pretrained models,\n "
243+ "- or a correct model-identifier of community-contributed pretrained models,\n "
244+ "- or the correct path to a directory containing relevant modeling files(model_weights and model_config).\n "
245+ )
222246
223247 # Prepare model initialization kwargs
224248 # Did we saved some inputs and kwargs to reload ?
@@ -292,7 +316,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
292316 model = cls (* derived_args , ** derived_kwargs )
293317
294318 # Maybe need more ways to load resources.
295- weight_path = list ( resolved_resource_files . values ())[ 0 ]
319+ weight_path = resolved_resource_files [ "model_state" ]
296320 assert weight_path .endswith (
297321 ".pdparams" ), "suffix of weight must be .pdparams"
298322 state_dict = paddle .load (weight_path )
0 commit comments