
Periodic validation during training #952

Open
ChenJian7578 opened this issue Mar 12, 2025 · 2 comments

Comments

@ChenJian7578

Does the current code support periodically evaluating on a separate, held-out validation set during LoRA training, and saving the checkpoint that performs best on that validation set?

@yuecao0119
Collaborator

Hi,

Our training logic is built on the Trainer from the Transformers library, but validation-set support has not been added yet. If you want to implement this yourself, see: https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.eval_dataset
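For context, here is a minimal sketch of what the linked docs describe with a stock Trainer: periodic evaluation is driven by `eval_strategy`/`eval_steps`, and the validation set is passed as `eval_dataset` at construction. The `model`, `train_ds`, and `eval_ds` names below are placeholders, not InternVL code.

```python
from transformers import Trainer, TrainingArguments

# Minimal sketch of periodic evaluation with a stock Trainer.
# `model`, `train_ds`, and `eval_ds` are placeholders for your own objects.
args = TrainingArguments(
    output_dir='./out',
    eval_strategy='steps',        # run evaluation every `eval_steps` steps
    eval_steps=500,
    per_device_eval_batch_size=4, # set explicitly to control eval memory use
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,         # the argument described in the docs linked above
)
trainer.train()
```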

@2kxx
2kxx commented Mar 15, 2025

> Hi,
>
> Our training logic is built on the Trainer from the Transformers library, but validation-set support has not been added yet. If you want to implement this yourself, see: https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.eval_dataset

Their implementation defines custom collators such as `pad_data_collator` and `concat_pad_data_collator`, so you can't just modify the Trainer directly; the eval dataset has to be built the same way as `train_dataset`. You can use my implementation below as a reference. Remember to adjust `--per_device_eval_batch_size`, otherwise it defaults to 4x the training batch size and you will run out of GPU memory.
```python
# Added next to the original `build_datasets` in the fine-tuning script and
# reusing its imports; the only change is that `meta_path` is a parameter,
# so it can point at a separate validation meta file.
def build_datasets1(
    data_args,
    tokenizer,
    tcs_loader,
    model,
    group_by_length=False,
    dynamic_image_size=False,
    use_thumbnail=False,
    min_dynamic_patch=1,
    max_dynamic_patch=12,
    min_num_frame=8,
    max_num_frame=32,
    normalize_type='imagenet',
    meta_path=None,
):
    datasets = []
    lengths = []
    data_rank = dist.get_rank()
    data_world_size = dist.get_world_size()
    ds_collections = json.loads(open(meta_path).read())
    for ds_idx, ds_name in enumerate(ds_collections.keys()):
        repeat_time = ds_collections[ds_name]['repeat_time']
        if 'max_dynamic_patch' in ds_collections[ds_name]:
            max_num = ds_collections[ds_name]['max_dynamic_patch']
            logger.info(f'max_dynamic_patch is set to {max_num} according to the meta file')
        else:
            max_num = max_dynamic_patch
        dataset = LazySupervisedDataset(
            data_args.conv_style, ds_collections[ds_name],
            tokenizer,
            tcs_loader,
            ds_name=ds_name,
            num_image_token=model.num_image_token,
            image_size=data_args.force_image_size,
            is_train=ds_collections[ds_name]['data_augment'],
            pad2square=data_args.pad2square,
            group_by_length=group_by_length and not data_args.use_packed_ds,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_num,
            min_num_frame=min_num_frame,
            max_num_frame=max_num_frame,
            repeat_time=repeat_time,
            normalize_type=normalize_type,
            # hyperparameters for packed training
            use_packed_ds=data_args.use_packed_ds,
            data_rank=data_rank,
            data_world_size=data_world_size,
            distributed_mode=data_args.use_packed_ds,
            force_shuffle=data_args.use_packed_ds,
            random_seed=ds_idx,
        )
        logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
        datasets.append(dataset)
        if data_args.use_data_resampling:
            lengths.append(math.sqrt(len(dataset)))
        else:
            lengths.append(len(dataset))

    if data_args.use_packed_ds:
        total_length = sum(lengths)
        train_dataset = PackedDataset(
            tokenizer=tokenizer,
            data_rank=data_rank,
            data_world_size=data_world_size,
            datasets=datasets,
            dataset_weight=[l / total_length for l in lengths],
            num_images_expected=data_args.num_images_expected,
            max_packed_tokens=data_args.max_packed_tokens,
            max_buffer_size=data_args.max_buffer_size,
            log_freq=data_args.log_freq,
            strict_mode=data_args.strict_mode,
            replacement=data_args.replacement,
            allow_overflow=data_args.allow_overflow,
            allow_deduplicated_ds_name=False,
        )
    elif data_args.use_data_resampling:
        total_length = sum(lengths)
        weights = [l / total_length for l in lengths]
        train_dataset = WeightedConcatDataset(datasets, weights)
    else:
        train_dataset = ConcatDataset(datasets)
    return train_dataset


eval_dataset = build_datasets1(
    data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
    dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
    min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
    normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
    max_num_frame=data_args.max_num_frame, meta_path="/hd2/tzc/project/EvalMuse-internvl/internvl_chat/shell/data/evalmuse_eval.json")
```
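To cover the rest of the original question (keeping the best weights), the eval dataset built above still has to be handed to the script's Trainer, and best-checkpoint tracking comes from standard `TrainingArguments` options. A hedged sketch of that wiring, assumed to live inside the fine-tuning script where `Trainer`, `model`, `training_args`, `train_dataset`, and a `collator` variable already exist; the exact construction in the InternVL script may differ:

```python
# Sketch only: pass the eval dataset into the existing Trainer construction.
# Best-checkpoint tracking is standard transformers behavior, configurable
# via launch flags, e.g.:
#   --eval_strategy steps --eval_steps 500 --save_strategy steps --save_steps 500
#   --load_best_model_at_end True --metric_for_best_model eval_loss
#   --greater_is_better False --per_device_eval_batch_size 1
# (`save_steps` must be a multiple of `eval_steps` when
#  load_best_model_at_end is enabled.)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,   # built by build_datasets1 above
    data_collator=collator,      # same custom collator as used for training
)
trainer.train()
```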
