Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 263 additions & 0 deletions src/jac-backend/FT_script.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "8321d173",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"!pip install virtualenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a5cd32a",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"!virtualenv myenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2508b93d",
"metadata": {},
"outputs": [],
"source": [
"!myenv/bin/pip install ipykernel"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b887351a",
"metadata": {},
"outputs": [],
"source": [
"!myenv/bin/python -m ipykernel install --user --name=myenv --display-name \"Python (myenv)\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef7f058c",
"metadata": {},
"outputs": [],
"source": [
"!python -m pip install -U jaclang\n",
"!pip install torch\n",
"!pip install datasets\n",
"!pip install transformers\n",
"!pip install trl\n",
"!pip install bert_score\n",
"!pip install unsloth\n",
"!pip install -U typing_extensions\n",
"!pip install tomli"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3027836",
"metadata": {},
"outputs": [],
"source": [
"%%writefile main.jac\n",
"import:py from unsloth, FastVisionModel ;\n",
"import:py torch ;\n",
"import:py from unsloth, is_bf16_supported ;\n",
"import:py from unsloth.trainer, UnslothVisionDataCollator ;\n",
"import:py from trl, SFTTrainer, SFTConfig ;\n",
"import:py from datasets, load_dataset ;\n",
"import:py from bert_score, score ;\n",
"import:py numpy as np ;\n",
"import:py from transformers, TextStreamer ;\n",
"\n",
"can load_model() {\n",
" # Load Qwen2-VL-7B-Instruct in 4-bit via Unsloth with its gradient\n",
" # checkpointing, then attach LoRA adapters (r=16, alpha=16, dropout=0)\n",
" # to the vision, language, attention and MLP layers.\n",
" # Returns the (model, tokenizer) pair ready for fine-tuning.\n",
"\n",
" (model, tokenizer) = FastVisionModel.from_pretrained(\n",
" 'unsloth/Qwen2-VL-7B-Instruct',\n",
" load_in_4bit=True,\n",
" use_gradient_checkpointing='unsloth'\n",
" ); \n",
" model = FastVisionModel.get_peft_model(\n",
" model,\n",
" finetune_vision_layers=True,\n",
" finetune_language_layers=True,\n",
" finetune_attention_modules=True,\n",
" finetune_mlp_modules=True,\n",
" r=16,\n",
" lora_alpha=16,\n",
" lora_dropout=0,\n",
" bias='none',\n",
" random_state=3407,\n",
" use_rslora=False,\n",
" loftq_config=None\n",
" ); \n",
" return (model, tokenizer) ;\n",
"}\n",
"\n",
"can prep_train_dataset(num_img: Any) {\n",
" # Build chat-format training samples from the first num_img rows of the\n",
" # unsloth/LaTeX_OCR train split. Each sample is a {'messages': ...} dict\n",
" # pairing the image (plus a fixed instruction) with its ground-truth\n",
" # LaTeX string as the assistant reply.\n",
"\n",
" instruction = 'Write the LaTeX representation for this image.'; \n",
" # Wrap one dataset row as a user(image+prompt) / assistant(LaTeX) exchange.\n",
" can convert_to_conversation(sample: Any) {\n",
"\n",
" conversation = [\n",
" {\n",
" 'role': 'user',\n",
" 'content': [{'type': 'text', 'text': instruction}, {'type': 'image', 'image': sample['image']}]\n",
" },\n",
" {\n",
" 'role': 'assistant',\n",
" 'content': [{'type': 'text', 'text': sample['text']}]\n",
" }\n",
" ]; \n",
" return {'messages': conversation} ;\n",
" }\n",
" \n",
" dataset = load_dataset('unsloth/LaTeX_OCR', split='train'); \n",
" dataset = dataset.select(range(num_img)); \n",
" converted_dataset = [convert_to_conversation(sample) for sample in dataset]; \n",
" return converted_dataset ;\n",
"}\n",
"\n",
"can prep_train_model(model: Any, tokenizer: Any, converted_dataset: Any, num_epochs: Any=5) {\n",
" # Configure an SFTTrainer for vision fine-tuning on the converted chat\n",
" # dataset. Effective batch size is 2 x 4 = 8 (per-device batch size x\n",
" # gradient accumulation). Uses bf16 when the hardware supports it, else\n",
" # fp16. Returns the trainer; the caller invokes .train().\n",
"\n",
" FastVisionModel.for_training(model) ; \n",
" trainer = SFTTrainer(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" # the Unsloth collator handles image+text batching, so the default\n",
" # text-field dataset preparation is skipped below\n",
" data_collator=UnslothVisionDataCollator(model, tokenizer),\n",
" train_dataset=converted_dataset,\n",
" args=SFTConfig(\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=4,\n",
" warmup_steps=5,\n",
" #max_steps=30,\n",
" num_train_epochs=num_epochs,\n",
" learning_rate=0.0002,\n",
" fp16=not is_bf16_supported(),\n",
" bf16=is_bf16_supported(),\n",
" logging_steps=1,\n",
" optim='adamw_8bit',\n",
" weight_decay=0.01,\n",
" lr_scheduler_type='linear',\n",
" seed=3407,\n",
" output_dir='outputs',\n",
" report_to='none',\n",
" remove_unused_columns=False,\n",
" dataset_text_field='',\n",
" dataset_kwargs={'skip_prepare_dataset': True},\n",
" dataset_num_proc=4,\n",
" max_seq_length=2048\n",
" )\n",
" ); \n",
" return trainer ;\n",
"}\n",
"\n",
"can get_response(test_dataset: Any, model: Any, tokenizer: Any, n: Any) {\n",
" # Run inference on test_dataset[n]: prompt the model with the image and\n",
" # the fixed LaTeX instruction, stream the generation to stdout, and\n",
" # return the decoded response string.\n",
" # NOTE(review): assumes a CUDA device is available (.to('cuda')).\n",
"\n",
" FastVisionModel.for_inference(model) ; \n",
" image = test_dataset[n]['image']; \n",
" instruction = 'Write the LaTeX representation for this image.'; \n",
" messages = [\n",
" {\n",
" 'role': 'user',\n",
" 'content': [{'type': 'image'}, {'type': 'text', 'text': instruction}]\n",
" }\n",
" ]; \n",
" input_text = tokenizer.apply_chat_template(\n",
" messages,\n",
" add_generation_prompt=True\n",
" ); \n",
" inputs = tokenizer(\n",
" image,\n",
" input_text,\n",
" add_special_tokens=False,\n",
" return_tensors='pt'\n",
" ).to('cuda'); \n",
" text_streamer = TextStreamer(tokenizer, skip_prompt=True); \n",
" output_id = model.generate(\n",
" **inputs,\n",
" streamer=text_streamer,\n",
" max_new_tokens=128,\n",
" use_cache=True,\n",
" temperature=1.5,\n",
" min_p=0.1 # min-p sampling threshold, paired with the high temperature above\n",
" );\n",
" response = tokenizer.decode(output_id[0], skip_special_tokens=True); \n",
" return response ;\n",
"}\n",
"\n",
"can evaluate(response: Any, reference: Any) {\n",
" # Score the generated response against the reference with BERTScore\n",
" # (bert-base-uncased) and return the mean precision/recall/F1 as a dict\n",
" # of plain floats.\n",
"\n",
" (P, R, F1) = score(\n",
" [response],\n",
" [reference],\n",
" model_type='bert-base-uncased',\n",
" lang='en'\n",
" ); \n",
" return {\n",
" 'precision': P.mean().item(),\n",
" 'recall': R.mean().item(),\n",
" 'f1': F1.mean().item()\n",
" } ;\n",
"}\n",
"\n",
"with entry {\n",
" # Pipeline: load the LoRA-wrapped model, fine-tune on the first 100\n",
" # LaTeX_OCR train images, then average BERTScore precision/recall/F1\n",
" # over the first 20 test images.\n",
" # NOTE(review): the printed label says 'VQA Accuracy' but the task here\n",
" # is LaTeX OCR; the 20 in the divisors must match the loop bound below.\n",
"\n",
" (model, tokenizer) = load_model(); \n",
" train_dataset = prep_train_dataset(100); \n",
" trainer = prep_train_model(model, tokenizer, train_dataset); \n",
" trainer.train() ; \n",
" test_dataset = load_dataset('unsloth/LaTeX_OCR', split='test'); \n",
" response_dict = {'precision': 0, 'recall': 0, 'f1': 0}; \n",
" \n",
" for img in range(20) {\n",
" response = get_response(test_dataset, model, tokenizer, img); \n",
" reference = test_dataset[img]['text']; \n",
" accuracy = evaluate(response, reference); \n",
" response_dict['precision'] += accuracy['precision']; \n",
" response_dict['recall'] += accuracy['recall']; \n",
" response_dict['f1'] += accuracy['f1'];\n",
" }\n",
" \n",
" precision = (response_dict['precision'] / 20);\n",
" recall = (response_dict['recall'] / 20);\n",
" f1 = (response_dict['f1'] / 20);\n",
" \n",
" print(f\"{'VQA Accuracy: Precision: '}{precision}{', Recall: '}{recall}{', F1: '}{f1}\");\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3a7c4ec",
"metadata": {},
"outputs": [],
"source": [
"!jac run main.jac"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading