Skip to content

Commit 5834978

Browse files
committed
ADD load ckpt weights and ckpt to npz
1 parent 03c1098 commit 5834978

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

tensorlayer/files/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,6 @@
7272
#'load_graph',
7373
#'save_graph_and_params',
7474
#'load_graph_and_params',
75+
'load_and_assign_ckpt',
76+
'ckpt_to_npz_dict'
7577
]

tensorlayer/files/utils.py

+90
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from tensorflow.python.platform import gfile
2727
from tensorflow.python.util import serialization
2828
from tensorflow.python.util.tf_export import keras_export
29+
from tensorflow.python import pywrap_tensorflow
30+
2931

3032
import progressbar
3133
import tensorlayer as tl
@@ -76,6 +78,8 @@
7678
'static_graph2net',
7779
# 'save_pkl_graph',
7880
# 'load_pkl_graph',
81+
'load_and_assign_ckpt',
82+
'ckpt_to_npz_dict',
7983
]
8084

8185

@@ -2775,3 +2779,89 @@ def load_hdf5_to_weights(filepath, network, skip=False):
27752779

27762780
f.close()
27772781
logging.info("[*] Load %s SUCCESS!" % filepath)
2782+
2783+
2784+
def load_and_assign_ckpt(model_dir, network=None, skip=True):
2785+
"""Load weights by name from a given file of ckpt format
2786+
2787+
Parameters
2788+
----------
2789+
model_dir : str
2790+
Filename to which the weights will be loaded, should be of ckpt format.
2791+
Examples: model_dir = /root/cnn_model/
2792+
network : Model
2793+
TL model.
2794+
skip : bool
2795+
If 'skip' == True, loaded weights whose name is not found in 'weights' will be skipped. If 'skip' is False,
2796+
error will be raised when mismatch is found. Default False.
2797+
2798+
Returns
2799+
-------
2800+
2801+
"""
2802+
model_dir = model_dir
2803+
model_path = None
2804+
for root, dirs, files in os.walk(model_dir):
2805+
for file in files:
2806+
filename, extension = os.path.splitext(file)
2807+
if extension in ['.data-00000-of-00001', '.index', '.meta']:
2808+
model_path = model_dir + '/' + filename
2809+
break
2810+
if model_path == None:
2811+
raise Exception('The ckpt file is not found')
2812+
2813+
reader = pywrap_tensorflow.NewCheckpointReader(model_path)
2814+
var_to_shape_map = reader.get_variable_to_shape_map()
2815+
2816+
net_weights_name = [w.name for w in network.all_weights]
2817+
2818+
for key in var_to_shape_map:
2819+
if key not in net_weights_name:
2820+
if skip:
2821+
logging.warning("Weights named '%s' not found in network. Skip it." % key)
2822+
else:
2823+
raise RuntimeError(
2824+
"Weights named '%s' not found in network. Hint: set argument skip=Ture "
2825+
"if you want to skip redundant or mismatch weights." % key
2826+
)
2827+
else:
2828+
assign_tf_variable(network.all_weights[net_weights_name.index(key)], reader.get_tensor(key))
2829+
logging.info("[*] Model restored from ckpt %s" % filename)
2830+
2831+
2832+
def ckpt_to_npz_dict(model_dir, save_name='model.npz'):
2833+
""" Save ckpt weights to npz file
2834+
2835+
Parameters
2836+
----------
2837+
model_dir : str
2838+
Filename to which the weights will be loaded, should be of ckpt format.
2839+
Examples: model_dir = /root/cnn_model/
2840+
save_name : str
2841+
The save_name of the `.npz` file.
2842+
2843+
Returns
2844+
-------
2845+
2846+
"""
2847+
model_dir = model_dir
2848+
model_path = None
2849+
for root, dirs, files in os.walk(model_dir):
2850+
for file in files:
2851+
filename, extension = os.path.splitext(file)
2852+
if extension in ['.data-00000-of-00001', '.index', '.meta']:
2853+
model_path = model_dir + '/' + filename
2854+
break
2855+
if model_path == None:
2856+
raise Exception('The ckpt file is not found')
2857+
2858+
reader = pywrap_tensorflow.NewCheckpointReader(model_path)
2859+
var_to_shape_map = reader.get_variable_to_shape_map()
2860+
2861+
parameters_dict = {}
2862+
for key in sorted(var_to_shape_map):
2863+
parameters_dict[key] = reader.get_tensor(key)
2864+
np.savez(save_name, **parameters_dict)
2865+
parameters_dict = None
2866+
del parameters_dict
2867+
logging.info("[*] Ckpt weights saved in npz_dict %s" % save_name)

tl

100755100644
File mode changed.

0 commit comments

Comments
 (0)