Unofficial implementation of Self-Critical Sequence Training (SCST) and various multi-head attention mechanisms.
This repo contains experimental, unofficial implementations of image captioning frameworks, including:
- Self-Critical Sequence Training (SCST) [arxiv]
    - Sampling is done via beam search [arxiv]
- Multi-Head Visual Attention
- Graph-based Beam Search, Greedy Search and Sampling
These features may not be fully tested. For a more stable implementation, please refer to this repo.
- tensorflow 1.9.0
- python 2.7
- java 1.8.0
- tqdm >= 4.24.0
- Pillow >= 3.1.2
- requests >= 2.18.4
More examples are given in `example.sh`.
Run `./src/setup.sh`. This will download the required Stanford models and run all the dataset pre-processing scripts.
The training scheme is as follows:
- Start with `decoder` mode (freezing the CNN)
- Followed by `cnn_finetune` mode
- Finally, `scst` mode
```shell
# MS-COCO
for mode in 'decoder' 'cnn_finetune' 'scst'
do
    python train.py \
        --train_mode ${mode} \
        --token_type 'word' \
        --cnn_fm_projection 'tied' \
        --attn_num_heads 8
done
```
```shell
# InstaPIC
for mode in 'decoder' 'cnn_finetune' 'scst'
do
    python train.py \
        --train_mode ${mode} \
        --dataset_file_pattern 'insta_{}_v25595_s15' \
        --token_type 'word' \
        --cnn_fm_projection 'independent' \
        --attn_num_heads 8
done
```

Just point `infer.py` to the directory containing the checkpoints. Model configurations are loaded from `config.pkl`.
```shell
# MS-COCO
python infer.py \
    --infer_checkpoints_dir 'mscoco/word_add_softmax_h8_tie_lstm_run_01'
```
```shell
# InstaPIC
python infer.py \
    --infer_checkpoints_dir 'insta/word_add_softmax_h8_ind_lstm_run_01' \
    --dataset_dir '/path/to/insta/dataset' \
    --annotations_file 'insta_testval_raw.json'
```

- Main:
    - `train_mode`: The training regime. Choices are `decoder`, `cnn_finetune`, `scst`.
    - `token_type`: Language model. Choices are `word`, `radix`, `char`.
- CNN:
    - `cnn_name`: CNN model name.
    - `cnn_input_size`: CNN input size.
    - `cnn_fm_attention`: End point name of feature map for attention.
    - `cnn_fm_projection`: Feature map projection method. Choices are `none`, `independent`, `tied`.
- RNN:
    - `rnn_name`: Type of RNN. Choices are `LSTM`, `LN_LSTM`, `GRU`.
    - `rnn_size`: Number of RNN units.
    - `rnn_word_size`: Size of word embedding.
    - `rnn_init_method`: RNN init method. Choices are `project_hidden`, `first_input`.
    - `rnn_recurr_dropout`: If `True`, enable variational recurrent dropout.
- Attention:
    - `attn_num_heads`: Number of attention heads.
    - `attn_context_layer`: If `True`, add a linear projection after multi-head attention.
    - `attn_alignment_method`: Alignment / composition method. Choices are `add_LN`, `add`, `dot`.
    - `attn_probability_fn`: Attention map probability function. Choices are `softmax`, `sigmoid`.
- SCST:
    - `scst_beam_size`: The beam size for SCST sampling.
    - `scst_weight_ciderD`: The weight for the CIDEr-D metric during SCST training.
    - `scst_weight_bleu`: The weight for BLEU metrics during SCST training.
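As an illustrative sketch (the flag names are from the list above, but the values — beam size 7, CIDEr-D weight 1.0, BLEU weight 0 — are placeholders, not repository defaults), an SCST run that rewards CIDEr-D only might be assembled like this. The command is built into an array and printed for inspection rather than executed:

```shell
# Hypothetical SCST training invocation; values are illustrative.
CMD=(python train.py
     --train_mode 'scst'
     --scst_beam_size 7
     --scst_weight_ciderD 1.0
     --scst_weight_bleu 0)
# Print the assembled command so it can be checked before running.
echo "${CMD[@]}"
```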
- Main:
    - `infer_set`: The split to perform inference on. Choices are `test`, `valid`, `coco_test`, `coco_valid`. `coco_test` and `coco_valid` are for inference on the whole `test2014` and `val2014` sets respectively; these are used for MS-COCO online server evaluation.
    - `infer_checkpoints_dir`: Directory containing the checkpoint files.
    - `infer_checkpoints`: Comma-separated checkpoint numbers to be evaluated.
    - `annotations_file`: Annotations / reference file for calculating scores.
- Inference parameters:
    - `infer_beam_size`: Beam size of beam search. Pass `1` for greedy search.
    - `infer_length_penalty_weight`: Length penalty weight used in beam search.
    - `infer_max_length`: Maximum caption length allowed during inference.
    - `batch_size_infer`: Inference batch size for parallelism.
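For instance, since passing `1` for `infer_beam_size` selects greedy search, a greedy decoding run might look like the following sketch (the checkpoint number and maximum length are illustrative; the command is assembled and printed rather than executed):

```shell
# Hypothetical greedy-search inference; checkpoint number and
# --infer_max_length value are illustrative, not defaults.
CMD=(python infer.py
     --infer_set 'valid'
     --infer_checkpoints_dir 'mscoco/word_add_softmax_h8_tie_lstm_run_01'
     --infer_checkpoints '250000'
     --infer_beam_size 1
     --infer_max_length 30)
# Print the assembled command so it can be checked before running.
echo "${CMD[@]}"
```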
Re-downloading can be avoided by:
- Editing `setup.sh`
- Providing the path to the directory containing the dataset files:

```shell
python coco_prepro.py --dataset_dir /path/to/coco/dataset
python insta_prepro.py --dataset_dir /path/to/insta/dataset
```

In the same way, both `train.py` and `infer.py` accept alternative dataset paths.
```shell
python train.py --dataset_dir /path/to/dataset
python infer.py --dataset_dir /path/to/dataset
```

This code assumes the following dataset directory structures:
```
{coco-folder}
+-- captions
|   +-- {folder and files generated by coco_prepro.py}
+-- test2014
|   +-- {image files}
+-- train2014
|   +-- {image files}
+-- val2014
    +-- {image files}
```
```
{insta-folder}
+-- captions
|   +-- {folder and files generated by insta_prepro.py}
+-- images
|   +-- {image files}
+-- json
    +-- insta-caption-test1.json
    +-- insta-caption-train.json
```
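As a quick sanity check, the expected InstaPIC layout can be scaffolded in a throwaway directory before pointing `--dataset_dir` at the real data (a sketch only; the actual caption and image files must come from the dataset download):

```shell
# Sketch: recreate the expected InstaPIC directory skeleton in a temp
# dir, with empty placeholder JSON files standing in for the real ones.
ROOT="$(mktemp -d)/insta"
mkdir -p "${ROOT}/captions" "${ROOT}/images" "${ROOT}/json"
touch "${ROOT}/json/insta-caption-train.json" \
      "${ROOT}/json/insta-caption-test1.json"
# List the scaffold to compare against the tree above.
find "${ROOT}" -mindepth 1 | sort
```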
```
.
+-- common
|   +-- {shared libraries and utility functions}
+-- datasets
|   +-- preprocessing
|       +-- {dataset pre-processing scripts}
+-- pretrained
|   +-- {pre-trained checkpoints for some COMIC models. Details are provided in a separate README.}
+-- src
    +-- {main scripts}
```
Thanks to the developers of:
- [attend2u]
- [coco-caption]
- [ruotianluo/self-critical.pytorch]
- [ruotianluo/cider]
- [weili-ict/SelfCriticalSequenceTraining-tensorflow]
- [tensorflow]
This project is open source under the Apache-2.0 license (see the LICENSE file).