-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmapping.py
90 lines (77 loc) · 2.78 KB
/
mapping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#----------------description----------------#
# Author : Zihao Zhao
# E-mail : [email protected]
# Company : Fudan University
# Date : 2020-12-20 11:52:22
# LastEditors : Zihao Zhao
# LastEditTime : 2021-05-08 12:37:45
# FilePath : /pytorch-asr-wavenet/mapping.py
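# Description : Export the tensors of a saved WaveNet state dict to text files
#               (weights re-laid-out for the DLA C model) and inspect them.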
#-------------------------------------------#
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
# import deepdish as dd
# import config_train as cfg
# from dataset import VCTK
# import dataset
# from wavenet import WaveNet
# from sparsity import *
# import utils
# import visualize as vis
# from ctcdecode import CTCBeamDecoder
# from tensorboardX import SummaryWriter
import os
import numpy as np
# import time
# import argparse
# from write_excel import *
# Input checkpoint, sparsity-pattern directory and output directory.
# Each can be overridden through an environment variable so the script can run
# on machines other than the author's; the defaults are the original paths.
model_pth = os.environ.get(
    "WAVENET_MODEL_PTH",
    "/Users/zzh/Nutstore Files/Server-Code/DLA-explorers/DLA-mapper/data/model/wavenet/wavenet_dense.pth",
)
pattern_dir = os.environ.get(
    "WAVENET_PATTERN_DIR",
    "/Users/zzh/Nutstore Files/Server-Code/DLA-explorers/DLA-c-model/tests/data/wavenet/pattern",
)
save_dir = os.environ.get(
    "WAVENET_SAVE_DIR",
    "/Users/zzh/Nutstore Files/Server-Code/DLA-explorers/DLA-c-model/tests/data/wavenet/weights",
)
def main():
    """Export every tensor of ``model_pth`` as a text file under ``save_dir``."""
    print("Mapping...")
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the race between an exists() check and mkdir().
    os.makedirs(save_dir, exist_ok=True)
    save_weight_txt(model_pth, save_dir)
def save_weight_txt(model_pth, folder):
    """Dump each tensor of a saved mapping to ``<folder>/<name>.txt``.

    Tensors that are neither biases nor parameters of a ``bn``/``bn2``/``bn3``
    sub-module have their axis order reversed before being flattened. For 3-D
    conv weights this turns PyTorch's (OC, IC, K) layout into the C model's
    (K, IC, OC) layout. All other tensors are flattened in their original
    layout. The path of every written file is printed.

    Args:
        model_pth: Path to a file loadable by ``torch.load`` that yields a
            mapping of names to tensors (e.g. a ``state_dict``).
        folder: Existing directory that receives one text file per entry.
    """
    # Names of sub-modules whose parameters are written without re-layout.
    untouched_modules = {"bn", "bn2", "bn3"}
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_weights = torch.load(model_pth, map_location=torch.device('cpu'))
    for name, raw_w in model_weights.items():
        raw_w_save = np.array(raw_w)
        parts = name.split(".")
        # Guard the [-2] lookup: a dot-less key previously raised IndexError.
        in_bn = len(parts) >= 2 and parts[-2] in untouched_modules
        if not in_bn and parts[-1] != "bias":
            # transpose() with no axes reverses all axes: exactly (2, 1, 0)
            # for 3-D tensors, and it no longer crashes on other ranks.
            raw_w_save = raw_w_save.transpose()
        out_path = os.path.join(folder, name + '.txt')
        print(out_path)
        np.savetxt(out_path, raw_w_save.flatten())
def read_txt(layer="module.resnet_block_0.0.conv_filter.dilation_conv1d.weight.txt",
             pattern_root=None, weight_root=None,
             pattern_shape=(16, 8, 8), weight_shape=(7, 128, 128), preview=8):
    """Print a sparsity pattern and a preview of the matching exported weight.

    Debug helper. It loads ``<pattern_root>/<layer>`` and
    ``<weight_root>/<layer>`` with ``np.loadtxt`` and reshapes them. It then
    prints every 2-D slice of the pattern, each followed by its slice index,
    then ``"w"``, then the top-left ``preview`` x ``preview`` corner of the
    first weight slice. Calling it with no arguments reproduces the original
    hard-coded behaviour.

    Args:
        layer: File name of the layer inside both directories; a leading
            "/" is ignored.
        pattern_root: Directory holding pattern files; defaults to the
            module-level ``pattern_dir``.
        weight_root: Directory holding weight files; defaults to the
            module-level ``save_dir``.
        pattern_shape: Shape the pattern file is reshaped to.
        weight_shape: Shape the weight file is reshaped to.
        preview: Number of rows/columns of the first weight slice to print.
    """
    if pattern_root is None:
        pattern_root = pattern_dir
    if weight_root is None:
        weight_root = save_dir
    # os.path.join would discard the root if the layer name began with "/".
    layer = layer.lstrip("/")
    pattern = np.loadtxt(os.path.join(pattern_root, layer)).reshape(pattern_shape)
    weight = np.loadtxt(os.path.join(weight_root, layer)).reshape(weight_shape)

    def _print_matrix(matrix):
        # Same layout as the former per-element prints: values separated by
        # single spaces, then two trailing spaces before the newline.
        for row in matrix:
            print(" ".join(str(v) for v in row), end="  \n")

    for i, plane in enumerate(pattern):
        _print_matrix(plane)
        print(i)
    print("w")
    _print_matrix(weight[0, :preview, :preview])
# Script entry point: run the weight export when executed directly.
if __name__ == "__main__":
    main()
    # Uncomment to print a pattern / weight preview instead (debug aid).
    # read_txt()