Skip to content

Latest commit

 

History

History
121 lines (78 loc) · 5.06 KB

File metadata and controls

121 lines (78 loc) · 5.06 KB

Attention Cluster 视频分类模型


目录

模型简介

Attention Cluster模型为ActivityNet Kinetics Challenge 2017中最佳序列模型。该模型通过带Shifting Opeation的Attention Clusters处理已抽取好的RGB、Flow、Audio特征数据,Attention Cluster结构如下图所示。


Multimodal Attention Cluster with Shifting Operation

Shifting Operation通过对每一个attention单元的输出添加一个独立可学习的线性变换处理后进行L2-normalization,使得各attention单元倾向于学习特征的不同成分,从而让Attention Cluster能更好地学习不同分布的数据,提高整个网络的学习表征能力。

详细内容请参考Attention Clusters: Purely Attention Based Local Feature Integration for Video Classification

数据准备

Attention Cluster模型使用2nd-Youtube-8M数据集, 数据下载及准备请参考数据说明

模型训练

数据准备完毕后,可以通过如下两种方式启动训练:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python train.py --model_name=AttentionCluster \
                --config=./configs/attention_cluster.yaml \
                --log_interval=10 \
                --valid_interval=1 \
                --use_gpu=True \
                --save_dir=./data/checkpoints \
                --fix_random_seed=False

bash run.sh train AttentionCluster ./configs/attention_cluster.yaml
  • 可下载已发布模型model通过--resume指定权重存放路径进行finetune等开发,或者在run.sh脚本中修改resume为解压之后的权重文件存放路径。

数据读取器说明: 模型读取Youtube-8M数据集中已抽取好的rgbaudio数据,对于每个视频的数据,均匀采样100帧,该值由配置文件中的seg_num参数指定。

模型设置: 模型主要可配置参数为cluster_numsseg_num参数,当配置cluster_nums为32, seg_num为100时,在Nvidia Tesla P40上单卡可跑batch_size=256

训练策略:

  • 采用Adam优化器,初始learning_rate=0.001。
  • 训练过程中不使用权重衰减。
  • 参数主要使用MSRA初始化

模型评估

可通过如下两种方式进行模型评估:

python eval.py --model_name=AttentionCluster \
               --config=./configs/attention_cluster.yaml \
               --log_interval=1 \
               --weights=$PATH_TO_WEIGHTS \
               --use_gpu=True

bash run.sh eval AttentionCluster ./configs/attention_cluster.yaml
  • 使用run.sh进行评估时,需要修改脚本中的weights参数指定需要评估的权重。

  • 若未指定--weights参数,脚本会下载已发布模型model进行评估

  • 评估结果以log的形式直接打印输出GAP、Hit@1等精度指标

  • 使用CPU进行评估时,请将use_gpu设置为False

当取如下参数时:

参数 取值
cluster_nums 32
seg_num 100
batch_size 2048
num_gpus 8

在2nd-YouTube-8M数据集下评估精度如下:

精度指标 模型精度
Hit@1 0.87
PERR 0.78
GAP 0.84

模型推断

可通过如下两种方式启动模型推断:

python predict.py --model_name=AttentionCluster \
                  --config=configs/attention_cluster.yaml \
                  --log_interval=1 \
                  --weights=$PATH_TO_WEIGHTS \
                  --filelist=$FILELIST \
                  --use_gpu=True

bash run.sh predict AttentionCluster ./configs/attention_cluster.yaml
  • 使用python命令行启动程序时,--filelist参数指定待推断的文件列表,如果不设置,默认为data/dataset/youtube8m/infer.list。--weights参数为训练好的权重参数,如果不设置,程序会自动下载已训练好的权重。这两个参数如果不设置,请不要写在命令行,将会自动使用默认值。

  • 使用run.sh进行评估时,请修改脚本中的weights参数指定需要用到的权重。

  • 若未指定--weights参数,脚本会下载已发布模型model进行推断

  • 模型推断结果以log的形式直接打印输出,可以看到每个测试样本的分类预测概率。

  • 使用CPU进行预测时,请将use_gpu设置为False

参考论文