|
26 | 26 | from tensorflow.python.platform import gfile
|
27 | 27 | from tensorflow.python.util import serialization
|
28 | 28 | from tensorflow.python.util.tf_export import keras_export
|
| 29 | +from tensorflow.python import pywrap_tensorflow |
| 30 | + |
29 | 31 |
|
30 | 32 | import progressbar
|
31 | 33 | import tensorlayer as tl
|
|
76 | 78 | 'static_graph2net',
|
77 | 79 | # 'save_pkl_graph',
|
78 | 80 | # 'load_pkl_graph',
|
| 81 | + 'load_and_assign_ckpt', |
| 82 | + 'ckpt_to_npz_dict', |
79 | 83 | ]
|
80 | 84 |
|
81 | 85 |
|
@@ -2775,3 +2779,89 @@ def load_hdf5_to_weights(filepath, network, skip=False):
|
2775 | 2779 |
|
2776 | 2780 | f.close()
|
2777 | 2781 | logging.info("[*] Load %s SUCCESS!" % filepath)
|
| 2782 | + |
| 2783 | + |
| 2784 | +def load_and_assign_ckpt(model_dir, network=None, skip=True): |
| 2785 | + """Load weights by name from a given file of ckpt format |
| 2786 | +
|
| 2787 | + Parameters |
| 2788 | + ---------- |
| 2789 | + model_dir : str |
| 2790 | + Filename to which the weights will be loaded, should be of ckpt format. |
| 2791 | + Examples: model_dir = /root/cnn_model/ |
| 2792 | + network : Model |
| 2793 | + TL model. |
| 2794 | + skip : bool |
| 2795 | + If 'skip' == True, loaded weights whose name is not found in 'weights' will be skipped. If 'skip' is False, |
| 2796 | + error will be raised when mismatch is found. Default False. |
| 2797 | +
|
| 2798 | + Returns |
| 2799 | + ------- |
| 2800 | +
|
| 2801 | + """ |
| 2802 | + model_dir = model_dir |
| 2803 | + model_path = None |
| 2804 | + for root, dirs, files in os.walk(model_dir): |
| 2805 | + for file in files: |
| 2806 | + filename, extension = os.path.splitext(file) |
| 2807 | + if extension in ['.data-00000-of-00001', '.index', '.meta']: |
| 2808 | + model_path = model_dir + '/' + filename |
| 2809 | + break |
| 2810 | + if model_path == None: |
| 2811 | + raise Exception('The ckpt file is not found') |
| 2812 | + |
| 2813 | + reader = pywrap_tensorflow.NewCheckpointReader(model_path) |
| 2814 | + var_to_shape_map = reader.get_variable_to_shape_map() |
| 2815 | + |
| 2816 | + net_weights_name = [w.name for w in network.all_weights] |
| 2817 | + |
| 2818 | + for key in var_to_shape_map: |
| 2819 | + if key not in net_weights_name: |
| 2820 | + if skip: |
| 2821 | + logging.warning("Weights named '%s' not found in network. Skip it." % key) |
| 2822 | + else: |
| 2823 | + raise RuntimeError( |
| 2824 | + "Weights named '%s' not found in network. Hint: set argument skip=Ture " |
| 2825 | + "if you want to skip redundant or mismatch weights." % key |
| 2826 | + ) |
| 2827 | + else: |
| 2828 | + assign_tf_variable(network.all_weights[net_weights_name.index(key)], reader.get_tensor(key)) |
| 2829 | + logging.info("[*] Model restored from ckpt %s" % filename) |
| 2830 | + |
| 2831 | + |
| 2832 | +def ckpt_to_npz_dict(model_dir, save_name='model.npz'): |
| 2833 | + """ Save ckpt weights to npz file |
| 2834 | +
|
| 2835 | + Parameters |
| 2836 | + ---------- |
| 2837 | + model_dir : str |
| 2838 | + Filename to which the weights will be loaded, should be of ckpt format. |
| 2839 | + Examples: model_dir = /root/cnn_model/ |
| 2840 | + save_name : str |
| 2841 | + The save_name of the `.npz` file. |
| 2842 | +
|
| 2843 | + Returns |
| 2844 | + ------- |
| 2845 | +
|
| 2846 | + """ |
| 2847 | + model_dir = model_dir |
| 2848 | + model_path = None |
| 2849 | + for root, dirs, files in os.walk(model_dir): |
| 2850 | + for file in files: |
| 2851 | + filename, extension = os.path.splitext(file) |
| 2852 | + if extension in ['.data-00000-of-00001', '.index', '.meta']: |
| 2853 | + model_path = model_dir + '/' + filename |
| 2854 | + break |
| 2855 | + if model_path == None: |
| 2856 | + raise Exception('The ckpt file is not found') |
| 2857 | + |
| 2858 | + reader = pywrap_tensorflow.NewCheckpointReader(model_path) |
| 2859 | + var_to_shape_map = reader.get_variable_to_shape_map() |
| 2860 | + |
| 2861 | + parameters_dict = {} |
| 2862 | + for key in sorted(var_to_shape_map): |
| 2863 | + parameters_dict[key] = reader.get_tensor(key) |
| 2864 | + np.savez(save_name, **parameters_dict) |
| 2865 | + parameters_dict = None |
| 2866 | + del parameters_dict |
| 2867 | + logging.info("[*] Ckpt weights saved in npz_dict %s" % save_name) |
0 commit comments