Skip to content

Commit 3256e10

Browse files
mjyun01gunho1123
andauthored
RNGDet implementation (#11145)
* Add files via upload * Rename train.py to train.py * Delete official/projects/RNGDET directory * Add files via upload * Update run_test.py * Update run_test_all.py * Update README.md * Update run_test.py * Delete official/projects/rngdet/only_eval_metric.py * Update run_test_all.py * Update README.md * Update README.md * Add files via upload * Update rngdet_input.py * Update run_test_all.py * Update do_train.sh * Update run_test_all.py * Clean for PR * Update rngdet.py * clean up for PR * clean up for PR * Update rngdet.py * Update README.md * Update README.md * Update README.md * Update rngdet_test.py * Update rngdet_test.py * Delete official/projects/rngdet/tasks/__pycache__ directory * Delete official/projects/rngdet/configs/__pycache__ directory * Delete official/projects/rngdet/dataloaders/__pycache__ directory * Delete official/projects/rngdet/eval/__pycache__ directory * Delete official/projects/rngdet/modeling/__pycache__ directory * Delete official/projects/rngdet/metric directory * Update run_test_all.py * Create requirements.txt --------- Co-authored-by: Gunho Park <[email protected]>
1 parent 6268053 commit 3256e10

23 files changed

+3703
-0
lines changed

official/projects/rngdet/README.md

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Road Network Graph Detection by Transformer
2+
3+
[![RNGDet](https://img.shields.io/badge/RNGDet-arXiv.2202.07824-B3181B?)](https://arxiv.org/abs/2202.07824)
4+
[![RNGDet++](https://img.shields.io/badge/RNGDet++-arXiv.2209.10150-B3181B?)](https://arxiv.org/abs/2209.10150)
5+
6+
## Environment setup
7+
The code can be run on multiple GPUs or TPUs with different distribution
8+
strategies. See the TensorFlow distributed training
9+
[guide](https://www.tensorflow.org/guide/distributed_training) for an overview
10+
of `tf.distribute`.
11+
12+
## Data preparation
13+
To download the dataset and generate labels, try the following command:
14+
15+
```
16+
cd data
17+
./prepare_dataset.bash
18+
```
19+
20+
To generate training samples, try the following command:
21+
22+
```
23+
python create_cityscale_tf_record.py \
24+
--dataroot ./dataset/ \
25+
--roi_size 128 \
26+
--image_size 2048 \
27+
--edge_move_ahead_length 30 \
28+
--num_queries 10 \
29+
--noise 8 \
30+
--max_num_frame 10000 \
31+
--num_shards 32
32+
```
33+
## Training
34+
To edit training options of RNGDet, you can edit following commands in do_train.sh :
35+
36+
```
37+
CUDA_VISIBLE_DEVICES=4 python3 train.py \
38+
--mode=train \
39+
--experiment=rngdet_cityscale \
40+
--model_dir=./CKPT_DIR_NAME \
41+
--config_file=./configs/experiments/cityscale_rngdet_r50_gpu.yaml \
42+
```
43+
44+
To start training, try the following command :
45+
```
46+
sh do_train.sh
47+
```
48+
49+
## Evaluation
50+
To evaluate one image with internal step visualization,
51+
52+
```
53+
python run_test.py -ckpt ./CKPT_DIR_NAME
54+
```
55+
56+
To evaluate all images in the test dataset, and see score(P-P, P-R, R-F) for each images,
57+
58+
```
59+
python run_test_all.py -ckpt ./CKPT_DIR_NAME
60+
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
runtime:
2+
distribution_strategy: 'mirrored'
3+
mixed_precision_dtype: 'float32'
4+
num_gpus: 1
5+
task:
6+
train_data:
7+
dtype: 'float32'
8+
validation_data:
9+
dtype: 'float32'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
runtime:
2+
distribution_strategy: 'tpu'
3+
mixed_precision_dtype: 'float32'
4+
task:
5+
train_data:
6+
dtype: 'float32'
7+
validation_data:
8+
dtype: 'float32'
+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""DETR configurations."""
16+
17+
import dataclasses
18+
import os
19+
from typing import List, Optional, Union
20+
21+
from official.core import config_definitions as cfg
22+
from official.core import exp_factory
23+
from official.modeling import hyperparams
24+
from official.modeling import optimization
25+
from official.vision.configs import common
26+
from official.vision.configs import decoders
27+
from official.vision.configs import backbones
28+
#from official.projects.rngdet import optimization as optimization_detr
29+
30+
31+
@dataclasses.dataclass
32+
class DataConfig(cfg.DataConfig):
33+
"""Input config for training."""
34+
input_path: str = ''
35+
tfds_name: str = ''
36+
tfds_split: str = 'train'
37+
global_batch_size: int = 0
38+
is_training: bool = False
39+
dtype: str = 'float32'
40+
decoder: common.DataDecoder = dataclasses.field(default_factory=common.DataDecoder)
41+
shuffle_buffer_size: int = 10000
42+
file_type: str = 'tfrecord'
43+
drop_remainder: bool = True
44+
45+
46+
@dataclasses.dataclass
47+
class Losses(hyperparams.Config):
48+
lambda_cls: float = 1.0
49+
lambda_box: float = 5.0
50+
background_cls_weight: float = 0.2
51+
52+
@dataclasses.dataclass
53+
class Rngdet(hyperparams.Config):
54+
"""Rngdet model definations."""
55+
num_queries: int = 10
56+
hidden_size: int = 256
57+
num_classes: int = 2 # 0: vertices, 1: background
58+
num_encoder_layers: int = 6
59+
num_decoder_layers: int = 6
60+
input_size: List[int] = dataclasses.field(default_factory=list)
61+
roi_size: int = 128
62+
backbone: backbones.Backbone = dataclasses.field(default_factory=lambda:backbones.Backbone(
63+
type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False)))
64+
decoder: decoders.Decoder = dataclasses.field(
65+
default_factory=lambda: decoders.Decoder(type='fpn', fpn=decoders.FPN())
66+
)
67+
min_level: int = 2
68+
max_level: int = 5
69+
norm_activation: common.NormActivation = dataclasses.field(default_factory=common.NormActivation)
70+
backbone_endpoint_name: str = '5'
71+
72+
73+
@dataclasses.dataclass
74+
class RngdetTask(cfg.TaskConfig):
75+
model: Rngdet = dataclasses.field(default_factory=Rngdet)
76+
train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
77+
validation_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
78+
losses: Losses = dataclasses.field(default_factory=Losses)
79+
init_checkpoint: Optional[str] = None
80+
init_checkpoint_modules: Union[str, List[str]] = 'all' # all, backbone
81+
per_category_metrics: bool = False
82+
83+
84+
#CITYSCALE_INPUT_PATH_BASE = 'gs://ghpark-tfrecords/cityscale'
85+
CITYSCALE_TRAIN_EXAMPLES = 420140
86+
#CITYSCALE_TRAIN_EXAMPLES = 10140
87+
CITYSCALE_INPUT_PATH_BASE = '/data2/cityscale/tfrecord'
88+
#CITYSCALE_TRAIN_EXAMPLES = 1900
89+
CITYSCALE_VAL_EXAMPLES = 5000
90+
91+
@exp_factory.register_config_factory('rngdet_cityscale')
92+
def rngdet_cityscale() -> cfg.ExperimentConfig:
93+
"""Config to get results that matches the paper."""
94+
train_batch_size = 64
95+
eval_batch_size = 64
96+
steps_per_epoch = CITYSCALE_TRAIN_EXAMPLES // train_batch_size
97+
train_steps = 50 * steps_per_epoch # 50 epochs
98+
config = cfg.ExperimentConfig(
99+
task=RngdetTask(
100+
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
101+
init_checkpoint_modules='backbone',
102+
model=Rngdet(
103+
input_size=[128, 128, 3],
104+
roi_size=128,
105+
norm_activation=common.NormActivation()),
106+
losses=Losses(),
107+
train_data=DataConfig(
108+
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
109+
#input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise-8-00000-of-00032.tfrecord*'),
110+
is_training=True,
111+
global_batch_size=train_batch_size,
112+
shuffle_buffer_size=1000,
113+
),
114+
validation_data=DataConfig(
115+
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
116+
is_training=False,
117+
global_batch_size=eval_batch_size,
118+
drop_remainder=False,
119+
)),
120+
trainer=cfg.TrainerConfig(
121+
train_steps=train_steps,
122+
validation_steps=CITYSCALE_VAL_EXAMPLES // eval_batch_size,
123+
steps_per_loop=steps_per_epoch,
124+
summary_interval=steps_per_epoch,
125+
checkpoint_interval=1*steps_per_epoch,
126+
validation_interval=1*steps_per_epoch,
127+
max_to_keep=1,
128+
best_checkpoint_export_subdir='best_ckpt',
129+
best_checkpoint_eval_metric='AP',
130+
optimizer_config=optimization.OptimizationConfig({
131+
'optimizer': {
132+
'type': 'adamw_experimental',
133+
'adamw_experimental': {
134+
'epsilon': 1.0e-08,
135+
'weight_decay': 1.0e-05,
136+
'global_clipnorm': -1.0,
137+
},
138+
},
139+
'learning_rate': {
140+
'type': 'polynomial',
141+
'polynomial': {
142+
'initial_learning_rate': 0.0001,
143+
'end_learning_rate': 0.000001,
144+
'offset': 0,
145+
'power': 1.0,
146+
'decay_steps': 50 * steps_per_epoch,
147+
},
148+
},
149+
'warmup': {
150+
'type': 'linear',
151+
'linear': {
152+
'warmup_steps': 2 * steps_per_epoch,
153+
'warmup_learning_rate': 0,
154+
},
155+
},
156+
})),
157+
restrictions=[
158+
'task.train_data.is_training != None',
159+
])
160+
return config
161+
162+
163+
164+
@exp_factory.register_config_factory('rngdet_cityscale_detr')
165+
def rngdet_cityscale() -> cfg.ExperimentConfig:
166+
"""Config to get results that matches the paper."""
167+
train_batch_size = 16
168+
eval_batch_size = 64
169+
steps_per_epoch = CITYSCALE_TRAIN_EXAMPLES // train_batch_size
170+
train_steps = 50 * steps_per_epoch # 50 epochs
171+
config = cfg.ExperimentConfig(
172+
task=RngdetTask(
173+
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
174+
init_checkpoint_modules='backbone',
175+
model=Rngdet(
176+
input_size=[128, 128, 3],
177+
roi_size=128,
178+
norm_activation=common.NormActivation()),
179+
losses=Losses(),
180+
train_data=DataConfig(
181+
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
182+
#input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise-8-00000-of-00032.tfrecord*'),
183+
is_training=True,
184+
global_batch_size=train_batch_size,
185+
shuffle_buffer_size=1000,
186+
),
187+
validation_data=DataConfig(
188+
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train_noise*'),
189+
is_training=False,
190+
global_batch_size=eval_batch_size,
191+
drop_remainder=False,
192+
)),
193+
trainer=cfg.TrainerConfig(
194+
train_steps=train_steps,
195+
validation_steps=CITYSCALE_VAL_EXAMPLES // eval_batch_size,
196+
steps_per_loop=steps_per_epoch,
197+
summary_interval=steps_per_epoch,
198+
checkpoint_interval=1*steps_per_epoch,
199+
validation_interval=1*steps_per_epoch,
200+
max_to_keep=1,
201+
best_checkpoint_export_subdir='best_ckpt',
202+
best_checkpoint_eval_metric='AP',
203+
optimizer_config=optimization.OptimizationConfig({
204+
'optimizer': {
205+
'type': 'adamw',
206+
'adamw': {
207+
'weight_decay_rate': 1e-5,
208+
'epsilon': 1e-08,
209+
'global_clipnorm': 0.1,
210+
# Avoid AdamW legacy behavior.
211+
'gradient_clip_norm': 0.0
212+
}
213+
},
214+
'learning_rate': {
215+
'type': 'stepwise',
216+
'stepwise': {
217+
'boundaries': [20 * steps_per_epoch,
218+
30 * steps_per_epoch,
219+
40 * steps_per_epoch],
220+
'values': [1.0e-05, 1.0e-05, 1.0e-06, 1.0e-07]
221+
}
222+
},
223+
})),
224+
restrictions=[
225+
'task.train_data.is_training != None',
226+
])
227+
return config
228+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for detr."""
16+
17+
# pylint: disable=unused-import
18+
from absl.testing import parameterized
19+
import tensorflow as tf
20+
21+
from official.core import config_definitions as cfg
22+
from official.core import exp_factory
23+
from official.projects.detr.configs import detr as exp_cfg
24+
from official.projects.detr.dataloaders import coco
25+
26+
27+
class DetrTest(tf.test.TestCase, parameterized.TestCase):
28+
29+
@parameterized.parameters(('detr_coco',))
30+
def test_detr_configs_tfds(self, config_name):
31+
config = exp_factory.get_exp_config(config_name)
32+
self.assertIsInstance(config, cfg.ExperimentConfig)
33+
self.assertIsInstance(config.task, exp_cfg.DetrTask)
34+
self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
35+
config.task.train_data.is_training = None
36+
with self.assertRaises(KeyError):
37+
config.validate()
38+
39+
@parameterized.parameters(('detr_coco_tfrecord'), ('detr_coco_tfds'))
40+
def test_detr_configs(self, config_name):
41+
config = exp_factory.get_exp_config(config_name)
42+
self.assertIsInstance(config, cfg.ExperimentConfig)
43+
self.assertIsInstance(config.task, exp_cfg.DetrTask)
44+
self.assertIsInstance(config.task.train_data, cfg.DataConfig)
45+
config.task.train_data.is_training = None
46+
with self.assertRaises(KeyError):
47+
config.validate()
48+
49+
50+
if __name__ == '__main__':
51+
tf.test.main()

0 commit comments

Comments
 (0)