menu
- trival supervised learning (regression & classification)
- [TODO] seq2seq models (e.g. LLM)
- [TODO] reinforcement learning with human's feedback
supervised learning的三个重点:模型架构、数据处理、训练技巧,分别对应model.py, process.py和train.py
pytorch架构定义的类主要重写两个函数,分别是__init__()和forward(),分别为模型初始化的architecture和datapath
主要需要重写dataset, dataloader。可能包括data argument, padding和normalization(代码没写)
基本流程:
- 引入dataloader
- initialize model
- training looping a. training b. validating
可能包括动态学习率调度、指标计算保存、模型ckpt的保存