@@ -125,20 +125,14 @@ def __init__(
125
125
) -> None :
126
126
super ().__init__ (input_path , batch_size , selected_cols , drop_remainder )
127
127
self .schema = []
128
- self ._ordered_cols = None
129
128
reader = common_io .table .TableReader (
130
129
self ._input_path .split ("," )[0 ],
131
- selected_cols = "," .join (self ._selected_cols or []),
132
130
)
133
- if self ._selected_cols :
134
- self ._ordered_cols = []
135
- for field in reader .get_schema ():
136
- # pyre-ignore [58]
137
- if field ["colname" ] in self ._selected_cols :
138
- self .schema .append (field )
139
- self ._ordered_cols .append (field ["colname" ])
140
- else :
141
- self .schema = reader .get_schema ()
131
+ self ._ordered_cols = []
132
+ for field in reader .get_schema ():
133
+ if not selected_cols or field ["colname" ] in selected_cols :
134
+ self .schema .append (field )
135
+ self ._ordered_cols .append (field ["colname" ])
142
136
reader .close ()
143
137
144
138
def _iter_one_table (
@@ -148,7 +142,7 @@ def _iter_one_table(
148
142
input_path ,
149
143
slice_id = worker_id ,
150
144
slice_count = num_workers ,
151
- selected_cols = "," .join (self ._selected_cols or []),
145
+ selected_cols = "," .join (self ._ordered_cols or []),
152
146
)
153
147
while True :
154
148
try :
0 commit comments