Skip to content

Commit b466d68

Browse files
committed
add model
1 parent 2858330 commit b466d68

File tree

7 files changed

+5591
-35
lines changed

7 files changed

+5591
-35
lines changed

README.md

+6-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Mô hình được train bởi hàm tối ưu Adam với learning rate = 0.0001,
4747

4848
*Dự đoán*
4949

50-
Mỗi một sample được chia thành mỗi 250 frames, sử dụng phương pháp trích rút đặc trưng như mô tả ở trên, rồi đưa vào mạng CNN. Nhãn của file âm thanh được chọn bởi phương pháp Majority voting.
50+
Mỗi một sample được chia thành mỗi 250 frames, sử dụng phương pháp trích rút đặc trưng như mô tả ở trên, rồi đưa vào mạng CNN. Nhãn của file âm thanh được chọn bởi chiến thuật majority voting.
5151

5252

5353
# Cách sử dụng
@@ -62,10 +62,12 @@ pip install requirements.txt
6262

6363
## Huấn luyện mô hình
6464

65-
Để huấn luyện mô hình, chạy script `make_sample.py``train.py`
65+
Để huấn luyện mô hình, chạy script `preprocessing.py``train.py`
66+
67+
Chú ý: Dữ liệu train gồm có folder `train` cần đặt vào thư mục `data`
6668

6769
```
68-
python make_sample.py
70+
python preprocessing.py train
6971
python train.py
7072
```
7173

@@ -74,6 +76,7 @@ python train.py
7476
Để dự đoán, chạy script `predict.py`
7577

7678
```
79+
python preprocessing.py test
7780
python predict.py
7881
```
7982

__init__.py

Whitespace-only changes.

cnn.py

+5-23
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def reverse_transform(self, y):
3434
return
3535

3636

37-
train_data = joblib.load("tmp/zalo_data/train_full.data.bin")
38-
test_data = joblib.load("tmp/zalo_data/test.data.bin")
37+
train_data = joblib.load("tmp/train_full.data.bin")
38+
test_data = joblib.load("tmp/test.data.bin")
3939

4040
labels = []
4141
is_first = True
@@ -55,7 +55,7 @@ def reverse_transform(self, y):
5555
input_shape = X.shape[1:]
5656
num_classes = 6
5757
batch_size = 32
58-
epochs = 30
58+
epochs = 10
5959

6060
model = Sequential()
6161
model.add(Conv2D(64, kernel_size=(7, 7), strides=(1, 1), activation='relu', input_shape=input_shape, padding='same'))
@@ -81,26 +81,8 @@ def reverse_transform(self, y):
8181
validation_data=(X_test, y_test),
8282
callbacks=[early_stopping])
8383

84-
# Predictions
85-
import os
86-
prediction_filename = "submission.csv"
87-
try:
88-
os.remove(prediction_filename)
89-
except Exception:
90-
pass
91-
map_values = [(0, 1), (0, 0), (0, 2), (1, 1), (1, 0), (1, 2)]
92-
prediction_file = open(prediction_filename, "a")
93-
prediction_file.write("id,gender,accent\n")
94-
count_error_file = 0
95-
for label, X in test_data:
96-
try:
97-
value = np.bincount(np.argmax(model.predict(X), axis=1)).argmax()
98-
except:
99-
print(f"Cannot detect file {label}")
100-
value = 0
101-
count_error_file += 1
102-
gender, accent = map_values[value]
103-
prediction_file.write(f"{label},{gender},{accent}\n")
84+
model.save('model.h5')
85+
10486

10587
# evaluation("submission.csv", "data/public_test_gt.csv")
10688

model.h5

8.08 MB
Binary file not shown.

predict.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import os
22
import numpy as np
3-
43
import joblib
4+
from keras.models import load_model
5+
56

67
prediction_filename = "submission.csv"
78
try:
89
os.remove(prediction_filename)
910
except Exception:
1011
pass
1112

12-
model = None
13-
test_data = joblib.load("tmp/zalo_data/test.data.bin")
13+
model = load_model("model.h5")
14+
test_data = joblib.load("tmp/test.data.bin")
1415

1516
map_values = [(0, 1), (0, 0), (0, 2), (1, 1), (1, 0), (1, 2)]
1617
prediction_file = open(prediction_filename, "a")
@@ -24,4 +25,5 @@
2425
value = 0
2526
count_error_file += 1
2627
gender, accent = map_values[value]
27-
prediction_file.write(f"{label},{gender},{accent}\n")
28+
prediction_file.write(f"{label},{gender},{accent}\n")
29+
print(f"Results is saved in file {prediction_filename}")

preprocess.py preprocessing.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
from multiprocessing.pool import Pool
23
from os import listdir
34
import numpy as np
@@ -57,12 +58,13 @@ def make_train_data():
5758
n = len(files)
5859
features = list(tqdm.tqdm(p.imap(extract_features, files), total=n))
5960

60-
joblib.dump(features, "tmp/zalo_data/train_full.data.bin")
61+
joblib.dump(features, "tmp/train_full.data.bin")
6162
print(len(features))
6263

6364

6465
def make_test_data():
65-
TEST_FOLDER = "data/public_test"
66+
# TEST_FOLDER = "data/public_test"
67+
TEST_FOLDER = "/data"
6668
tmp = listdir(TEST_FOLDER)
6769
files = []
6870
for label in tmp:
@@ -72,9 +74,16 @@ def make_test_data():
7274
n = len(files)
7375
features = list(tqdm.tqdm(p.imap(extract_features, files), total=n))
7476

75-
joblib.dump(features, "tmp/zalo_data/test.data.bin")
77+
joblib.dump(features, "tmp/test.data.bin")
7678
print(len(features))
7779

7880

79-
make_train_data()
80-
make_test_data()
81+
parser = argparse.ArgumentParser("preprocessing.py")
82+
parser.add_argument("option", nargs="+", help="train or test")
83+
84+
args = parser.parse_args()
85+
mode = args.mode
86+
if mode == "train":
87+
make_train_data()
88+
elif mode == "test":
89+
make_test_data()

0 commit comments

Comments
 (0)