-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_reg_tree.cpp
86 lines (78 loc) · 2.99 KB
/
test_reg_tree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <time.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "RegTree.h"
// Reads tab-separated samples from `path` into `features` / `labels`.
// Each non-blank line holds the feature values followed by the label in the
// last column. Blank lines are skipped so that `features` and `labels` always
// have the same length. Returns false if the file cannot be opened.
static bool loadSamples(const char* path,
                        std::vector<std::vector<float> >& features,
                        std::vector<float>& labels) {
    std::ifstream fis(path);
    if (!fis) {
        return false;
    }
    std::string line;
    std::vector<std::string> fields;
    while (std::getline(fis, line)) {
        // Tolerate CRLF files: drop a trailing carriage return.
        if (!line.empty() && line[line.size() - 1] == '\r') {
            line.erase(line.size() - 1);
        }
        if (line.empty()) {
            continue;
        }
        fields.clear();
        mla::util::split(line, '\t', fields);
        if (fields.empty()) {
            continue;
        }
        // All columns except the last are features; the last is the label.
        std::vector<float> row;
        row.reserve(fields.size() - 1);
        for (size_t i = 0; i + 1 < fields.size(); ++i) {
            row.push_back(static_cast<float>(atof(fields[i].c_str())));
        }
        features.push_back(row);
        labels.push_back(static_cast<float>(atof(fields.back().c_str())));
    }
    return true;
}

// Test driver for mla::tree::RegressionTree.
//
// Usage: ./test_tree <training_file> <max_node_path> <max_depth> <multi_thread>
//   training_file : tab-separated samples, label in the last column
//   max_node_path : parsed from argv[2], passed as the tree's second ctor arg
//   max_depth     : parsed from argv[3], passed as the tree's first ctor arg
//   multi_thread  : non-zero integer enables multi-threaded training
//
// The first 80% of the samples are used for training; the tree is trained on
// them and its prediction for every training sample is printed.
// Returns 0 on success and 1 on bad arguments or unusable input data.
int main(int argc, char* argv[]) {
    std::cout << "Testing Regression Tree" << std::endl;
    std::cout << "Loading the tree" << std::endl;
    // Example: ./Bin/reg_tree DATA/data 20 200 0
    if (argc < 5) {
        // NOTE(review): the usage text previously listed max_depth before
        // max_node_path, contradicting the parsing below. The text now matches
        // the parsing; confirm which order is actually intended.
        std::cerr << "Usage: ./test_tree [training_file] [max_node_path] [max_depth] [multi thread or not]" << std::endl;
        return 1;
    }

    std::vector<std::vector<float> > vctFeature;
    std::vector<float> vctLabel;
    if (!loadSamples(argv[1], vctFeature, vctLabel) || vctFeature.empty()) {
        std::cerr << "Loading Feature Error!" << std::endl;
        return 1;
    }

    // Hold out the last 20% of the samples; train on the first 80%.
    const size_t splitPos = static_cast<size_t>(0.8 * vctFeature.size());
    if (splitPos == 0) {
        std::cerr << "Not enough samples to train: " << vctFeature.size() << std::endl;
        return 1;
    }
    std::vector<std::vector<float> > vctTrainFeature(vctFeature.begin(), vctFeature.begin() + splitPos);
    std::vector<std::vector<float> > vctTestFeature(vctFeature.begin() + splitPos, vctFeature.end());
    std::vector<float> vctTrainLabel(vctLabel.begin(), vctLabel.begin() + splitPos);
    std::vector<float> vctTestLabel(vctLabel.begin() + splitPos, vctLabel.end());

    const int maxNodePath = atoi(argv[2]);
    const int maxDepth = atoi(argv[3]);
    const bool isMultiThreadOn = atoi(argv[4]) != 0;
    std::cout << (isMultiThreadOn ? "True" : "False") << std::endl;

    // Train the regression tree. A stack object releases the tree on every exit path.
    mla::tree::RegressionTree regreTree(maxDepth, maxNodePath, isMultiThreadOn, false);
    regreTree.getMinSampleCnt() = 5;
    std::cout << (regreTree.getMultiThreadOn() ? "ON" : "NO") << std::endl;
    regreTree.setData(vctTrainFeature, vctTrainLabel);

    // clock() reports processor time, not wall-clock time. On POSIX systems it
    // accumulates across all threads, so multi-threaded runs may not look faster.
    const clock_t start = clock();
    regreTree.train();
    const clock_t end = clock();
    std::cout << "Cost:" << (end - start) << " ticks ("
              << static_cast<double>(end - start) / CLOCKS_PER_SEC << "s)" << std::endl;

    for (size_t i = 0; i < vctTrainFeature.size(); ++i) {
        std::cout << i << " " << vctTrainLabel[i] << " vs " << regreTree.predict(vctTrainFeature[i]) << '\n';
    }
    std::cout.flush();

    /*
    // Held-out evaluation (disabled):
    for (size_t i = 0; i < vctTestFeature.size(); ++i) {
        std::cout << vctTestLabel[i] << " vs " << regreTree.predict(vctTestFeature[i]) << '\n';
    }
    */
    return 0;
}