Skip to content

Commit d25aefa

Browse files
committed
combined cols to create a new col
1 parent 929c1cf commit d25aefa

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

fast_llm/data/preparator/gpt_memmap/config.py

+30
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,31 @@ def _validate(self) -> None:
109109
super()._validate()
110110
Assert.in_range(self.rank, 0, self.world_size)
111111

112+
@config_class()
class FieldCombinePreparatorConfig(Config):
    """Configuration for concatenating several dataset columns into one new column.

    The listed `fields` are joined in order with `delimiter` and stored under
    `new_field_name`.
    """

    fields: list[str] = Field(
        # A mutable literal (`default=[]`) must not be used as a dataclass field
        # default; a factory gives each config instance its own list.
        default_factory=list,
        desc="Fields of the dataset to combine.",
        hint=FieldHint.core,
    )
    delimiter: str = Field(
        default=" ",
        desc="Delimiter to use when combining fields.",
        hint=FieldHint.optional,
    )
    new_field_name: str = Field(
        default="fast_llm_combined_field",
        desc="Name of the new field to create.",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        """Check that at least one field is configured and all options are well-typed."""
        # Combining zero fields would produce an empty column; require at least one.
        Assert.gt(len(self.fields), 0)
        assert all(isinstance(field, str) for field in self.fields), "All fields must be strings."
        assert isinstance(self.delimiter, str), "Delimiter must be a string."
        assert isinstance(self.new_field_name, str), "New field name must be a string."
        super()._validate()
112137

113138
@config_class()
114139
class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
@@ -164,6 +189,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
164189
" Does not shuffle samples.",
165190
hint=FieldHint.optional,
166191
)
192+
combine_fields: FieldCombinePreparatorConfig = Field(
193+
default=None,
194+
desc="Combine all files into a single file.",
195+
hint=FieldHint.optional,
196+
)
167197

168198
def _validate(self) -> None:
169199
assert self.tokenizer.path is not None

fast_llm/data/preparator/gpt_memmap/prepare.py

+16
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,22 @@ def run(self) -> None:
208208
torch.distributed.barrier()
209209

210210
assert isinstance(dataset, datasets.Dataset)
211+
212+
        # Check for combining fields: optionally concatenate several source
        # columns into one new column before sharding/tokenization.
        if self._config.combine_fields:
            # Every requested source field must be an existing dataset column
            # (the intersection with column_names must cover all requested fields).
            Assert.eq(len(set(self._config.combine_fields.fields).intersection(dataset.column_names)), len(self._config.combine_fields.fields))
            dataset = dataset.map(
                # Join the (stringified) values of the configured columns with
                # the configured delimiter, stored under the new field name.
                lambda example: {
                    self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
                        str(example[column]) for column in self._config.combine_fields.fields
                    )
                },
                batched=False,
                desc="Combining fields",
            )
            # Set the new field name in the config for following operations
            # NOTE(review): this mutates the (already validated) config in place so
            # downstream steps tokenize the combined column — confirm this is intended.
            self._config.dataset.field = self._config.combine_fields.new_field_name
226+
211227
dataset = dataset.shard(
212228
num_shards=self._config.distributed.world_size,
213229
index=self._config.distributed.rank,

0 commit comments

Comments
 (0)