Skip to content
This repository was archived by the owner on Mar 22, 2024. It is now read-only.

Commit 3a59d74

Browse files
authored
Update hdf5_to_csv.py
1 parent b866326 commit 3a59d74

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

deeprank_gnn/tools/hdf5_to_csv.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,30 @@ def hdf5_to_csv(hdf5_path):
2020
if len(targets) == 0:
2121
targets = 'n'*len(mol)
2222

23+
bin=False
2324

2425
# This section is specific to the classes
2526
# it adds the raw output, i.e. probabilities to belong to the class 0, the class 1, etc., to the prediction hdf5
2627
# This way, binary information can be transformed back to continuous data and used for ranking
2728
if 'raw_outputs' in hdf5['{}/{}'.format(epoch, dataset)].keys():
28-
if first :
29-
header = ['epoch', 'set', 'model', 'targets', 'prediction']
30-
output_file = open('{}.csv'.format(name), 'w')
31-
output_file.write(','+','.join(header)+'\n')
32-
output_file.close()
33-
first = False
34-
data_to_save = [epoch_lst, dataset_lst, mol, targets, outputs]
35-
for target_class in range(0,len(hdf5['{}/{}/raw_outputs'.format(epoch, dataset)][()][0,:])):
36-
# probability of getting 0
37-
outputs_per_class = hdf5['{}/{}/raw_outputs'.format(epoch, dataset)][()][:,target_class]
38-
data_to_save.append(outputs_per_class)
39-
header.append(f'raw_prediction_{target_class}')
40-
dataset_df = pd.DataFrame(list(zip(*data_to_save)), columns = header)
41-
42-
else:
29+
if len(hdf5['{}/{}/raw_outputs'.format(epoch, dataset)][()].shape) > 1:
30+
bin=True
31+
if first :
32+
header = ['epoch', 'set', 'model', 'targets', 'prediction']
33+
output_file = open('{}.csv'.format(name), 'w')
34+
output_file.write(','+','.join(header)+'\n')
35+
output_file.close()
36+
first = False
37+
data_to_save = [epoch_lst, dataset_lst, mol, targets, outputs]
38+
39+
for target_class in range(0,len(hdf5['{}/{}/raw_outputs'.format(epoch, dataset)][()])):
40+
# probability of getting 0
41+
outputs_per_class = hdf5['{}/{}/raw_outputs'.format(epoch, dataset)][()][:,target_class]
42+
data_to_save.append(outputs_per_class)
43+
header.append(f'raw_prediction_{target_class}')
44+
dataset_df = pd.DataFrame(list(zip(*data_to_save)), columns = header)
45+
46+
if bin==False:
4347
if first :
4448
header = ['epoch', 'set', 'model', 'targets', 'prediction']
4549
output_file = open('{}.csv'.format(name), 'w')

0 commit comments

Comments
 (0)