Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 243 additions & 0 deletions configs/gta_human/hmr/resnet50_hmr+_gta_bt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
_base_ = ['../_base_/default_runtime.py']
use_adversarial_train = True
# dist_params = dict(backend='nccl', port=29488)
dist_params = dict(backend='gloo')

# evaluate
evaluation = dict(interval=1, metric=['pa-mpjpe', 'mpjpe'])
# optimizer
optimizer = dict(
backbone=dict(type='Adam', lr=2.5e-5), head=dict(type='Adam', lr=2.5e-5))
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='Fixed', by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=150)

log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# checkpoint_config = dict(interval=5)
img_res = 224

# model settings

body_model = dict(
type='SMPL',
keypoint_src='smpl_54',
keypoint_dst='smpl_49',
keypoint_approximate=True,
model_path='data/body_models/smpl',
extra_joints_regressor='data/body_models/J_regressor_extra.npy')

registrant = dict(
type='SMPLify',
body_model=body_model,
num_epochs=1,
stages=[],
keypoints2d_loss=dict(
type='KeypointMSELoss', loss_weight=1.0, reduction='sum', sigma=100),
shape_prior_loss=dict(
type='ShapePriorLoss', loss_weight=5.0**2, reduction='sum'),
joint_prior_loss=dict(
type='JointPriorLoss', loss_weight=15.2**2, reduction='sum'),
pose_prior_loss=dict(
type='MaxMixturePrior',
prior_folder='data',
num_gaussians=8,
loss_weight=4.78**2,
reduction='sum'),
ignore_keypoints=[
'neck_openpose', 'right_hip_openpose', 'left_hip_openpose',
'right_hip_extra', 'left_hip_extra'
],
camera=dict(
type='PerspectiveCameras',
convention='opencv',
in_ndc=False,
focal_length=5000,
image_size=(img_res, img_res),
principal_point=(img_res / 2, img_res / 2)))
registration = dict(mode='static', registrant=registrant)

model = dict(
type='ImageBodyModelEstimator',
backbone=dict(
type='ResNet',
depth=50,
out_indices=[3],
norm_eval=False,
norm_cfg=dict(type='BN', requires_grad=True),
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
head=dict(
type='HMRHead',
feat_dim=2048,
smpl_mean_params='data/body_models/smpl_mean_params.npz'),
body_model_train=body_model,
body_model_test=dict(
type='SMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
joints_regressor='data/body_models/J_regressor_h36m.npy'),
registration=registration,
convention='smpl_49',
loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100),
loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10),
loss_vertex=dict(type='L1Loss', loss_weight=2),
loss_smpl_pose=dict(type='MSELoss', loss_weight=3),
loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02),
loss_camera=dict(type='CameraPriorLoss', loss_weight=60),
)
# dataset settings
dataset_type = 'HumanImageDataset'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data_keys = [
'has_smpl', 'smpl_body_pose', 'smpl_global_orient', 'smpl_betas',
'smpl_transl', 'keypoints2d', 'keypoints3d', 'is_flipped', 'center',
'scale', 'rotation', 'sample_idx'
]

file_client_args = dict(
backend='petrel', path_mapping=dict({'data/': 's3://mmhuman3d_datasets/'}))
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='RandomChannelNoise', noise_factor=0.4),
dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_49'),
dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25),
dict(type='MeshAffine', img_res=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=data_keys),
dict(
type='Collect',
keys=['img', *data_keys],
meta_keys=['image_path', 'center', 'scale', 'rotation'])
]
data_keys.remove('is_flipped')
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0),
dict(type='MeshAffine', img_res=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=data_keys),
dict(
type='Collect',
keys=['img', *data_keys],
meta_keys=['image_path', 'center', 'scale', 'rotation'])
]

inference_pipeline = [
dict(type='MeshAffine', img_res=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(
type='Collect',
keys=['img', 'sample_idx'],
meta_keys=['image_path', 'center', 'scale', 'rotation'])
]

cache_files = {
'h36m': 'data/cache/h36m_mosh_train_smpl_49.npz',
'mpi_inf_3dhp': 'data/cache/spin_mpi_inf_3dhp_train_smpl_49.npz',
'lsp': 'data/cache/spin_lsp_train_smpl_49.npz',
'lspet': 'data/cache/spin_lspet_train_smpl_49.npz',
'mpii': 'data/cache/spin_mpii_train_smpl_49.npz',
'coco': 'data/cache/spin_coco_2014_train_smpl_49.npz',
'gta': 'data/cache/gta_human_4x_smpl_49.npz'
}

data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
train=dict(
type='MixedDataset',
configs=[
dict(
type=dataset_type,
dataset_name='h36m',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
cache_data_path=cache_files['h36m'],
ann_file='h36m_mosh_train.npz'),
dict(
type=dataset_type,
dataset_name='mpi_inf_3dhp',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
cache_data_path=cache_files['mpi_inf_3dhp'],
ann_file='spin_mpi_inf_3dhp_train.npz'),
dict(
type=dataset_type,
dataset_name='lsp',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
cache_data_path=cache_files['lsp'],
ann_file='spin_lsp_train.npz'),
dict(
type=dataset_type,
dataset_name='lspet',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
cache_data_path=cache_files['lspet'],
ann_file='spin_lspet_train.npz'),
dict(
type=dataset_type,
dataset_name='mpii',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
cache_data_path=cache_files['mpii'],
ann_file='spin_mpii_train.npz'),
dict(
type=dataset_type,
dataset_name='coco',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
cache_data_path=cache_files['coco'],
ann_file='spin_coco_2014_train.npz'),
dict(
type=dataset_type,
dataset_name='gta',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
cache_data_path=cache_files['gta'],
ann_file='gta_human_4x.npz'),
],
partition=[0.35, 0.15, 0.1, 0.10, 0.10, 0.2, 1],
),
test=dict(
type=dataset_type,
body_model=dict(
type='GenderedSMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
joints_regressor='data/body_models/J_regressor_h36m.npy'),
dataset_name='pw3d',
data_prefix='data',
pipeline=test_pipeline,
convention='h36m',
ann_file='pw3d_test.npz'),
val=dict(
type=dataset_type,
body_model=dict(
type='GenderedSMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
joints_regressor='data/body_models/J_regressor_h36m.npy'),
dataset_name='pw3d',
data_prefix='data',
pipeline=test_pipeline,
ann_file='pw3d_test.npz'))
42 changes: 42 additions & 0 deletions configs/gta_human/vibe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# GTA-Human (VIBE)

## Notes

- [SMPL](https://smpl.is.tue.mpg.de/) v1.0 is used in our experiments.
- [J_regressor_extra.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_extra.npy?versionId=CAEQHhiBgIDD6c3V6xciIGIwZDEzYWI5NTBlOTRkODU4OTE1M2Y4YTI0NTVlZGM1)
- [J_regressor_h36m.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_h36m.npy?versionId=CAEQHhiBgIDE6c3V6xciIDdjYzE3MzQ4MmU4MzQyNmRiZDA5YTg2YTI5YWFkNjRi)
- [smpl_mean_params.npz](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/smpl_mean_params.npz?versionId=CAEQHhiBgICN6M3V6xciIDU1MzUzNjZjZGNiOTQ3OWJiZTJmNThiZmY4NmMxMTM4)
- [gmm_08.pkl](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/gmm_08.pkl?versionId=CAEQHhiBgIDP6c3V6xciIGU4ZWFlYzlhNDJmODRmOGViYTMzOGRmODg2YjQ4NTg1)
- Pretrained SPIN model [spin_official.pth](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/data/pretrained_models/spin_official.pth?versionId=CAEQRBiBgMC3zJPvjhgiIDNjODIxODJjYzEyNzRmNDhhNzU3Nzg3N2FlY2Y0ZWMx) for extracting features. Rename it as spin.pth.

Download the above resources and arrange them in the following file structure:

```text
mmhuman3d
├── mmhuman3d
├── docs
├── tests
├── tools
├── configs
└── data
├── gmm_08.pkl
├── body_models
│ ├── J_regressor_extra.npy
│ ├── J_regressor_h36m.npy
│ ├── smpl_mean_params.npz
│ └── smpl
│ ├── SMPL_FEMALE.pkl
│ ├── SMPL_MALE.pkl
│ └── SMPL_NEUTRAL.pkl
├── pretrained
│ └── spin.pth
├── preprocessed_datasets
│ ├── vibe_insta_variety.npz
│ ├── vibe_mpi_inf_3dhp_train.npz
│ ├── vibe_pw3d_train.npz
│ ├── vibe_pw3d_test.npz
│ ├── vibe_gta_train.npz
│ └── vibe_gta_96k.npz


```
Loading