# A memory-balanced and communication-efficient model-parallel FullyConnected layer implementation in PyTorch

## Why do we need model parallelism? Why not just use DataParallel?

In face recognition and re-id (person re-identification), the number of labels in some private datasets can exceed 1 million, 10 million, or even 100 million. The parameters of the final fully connected layer alone can occupy an entire GPU's memory, so we can only use a small batch size, which results in slow training and poor evaluation performance.
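
To put that in numbers, here is a rough back-of-the-envelope estimate of the size of the FC weight matrix alone; the embedding dimension and label count below are illustrative assumptions, not values taken from this repository:

```python
# FP32 weight matrix of the final fully connected layer: in_dim * out_dim * 4 bytes
in_dim = 512           # assumed embedding dimension
out_dim = 10_000_000   # assumed number of labels (identities)
weight_bytes = in_dim * out_dim * 4
print(f"{weight_bytes / 1024 ** 3:.1f} GiB")  # ~19.1 GiB, before gradients and optimizer state
```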

## A fully connected layer with model parallelism? It's simple!

```python
import torch
import torch.nn as nn


class FullyConnected(nn.Module):
    def __init__(self, in_dim, out_dim, num_gpu, model_parallel=False):
        super(FullyConnected, self).__init__()
        self.num_gpu = num_gpu
        self.model_parallel = model_parallel
        if model_parallel:
            # split the out_dim output classes across num_gpu GPUs,
            # giving the remainder (out_dim % num_gpu) to the first GPUs
            self.fc_chunks = nn.ModuleList()
            for i in range(num_gpu):
                _class_num = out_dim // num_gpu
                if i < (out_dim % num_gpu):
                    _class_num += 1
                self.fc_chunks.append(
                    nn.Linear(in_dim, _class_num, bias=False).cuda(i)
                )
        else:
            self.classifier = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        if self.model_parallel:
            x_list = []
            for i in range(self.num_gpu):
                # copy the features to GPU i and compute that GPU's slice of the logits
                _x = self.fc_chunks[i](x.cuda(i))
                # bring every partial result back to GPU 0 so they can be concatenated
                x_list.append(_x.cuda(0))
            x = torch.cat(x_list, dim=1)
            return x
        else:
            return self.classifier(x)
```
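
A hypothetical call pattern for the class above, just to illustrate how it is meant to be used (the embedding dimension, label count, and GPU count are made-up values):

```python
fc = FullyConnected(in_dim=512, out_dim=1_000_000, num_gpu=4, model_parallel=True)
features = torch.randn(256, 512).cuda(0)  # a batch of 256 embeddings on GPU 0
logits = fc(features)                     # shape [256, 1_000_000], concatenated on GPU 0
```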
A similar implementation can also be found [here](https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/d5e31893f7e30c0f82262e701463fd83d9725381/head/metrics.py#L41)

However, this implementation only solves part of the problem and introduces a new one: imbalanced GPU memory usage. Because all the partial results are concatenated on GPU 0 and the loss is also computed on GPU 0, the memory usage and computation load on GPU 0 are much higher than on the other GPUs, so we still cannot use a large batch size.

## Did this repository solve the problem?

Yes, and it extends to more settings, such as margin losses, mixed precision training, and distributed training.

Some advantages:

- GPU memory usage and computation load are balanced across all GPUs, so we can use a large batch size; life becomes easier :-)
- supports most of the margin losses used in face recognition and re-id, such as `ArcFace`, `SphereFace`, `CosFace`, `AM-softmax`, and so on
- does not affect your evaluation results after training with model parallelism
- sometimes speeds up training, thanks to the lower communication cost of the optimized CrossEntropyLoss (see the sketch after this list)
- supports mixed precision training and distributed training
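
The communication saving mentioned above comes from computing softmax cross-entropy without ever gathering the full logit matrix on one GPU: only a few per-sample scalars move between devices. The sketch below illustrates the idea in plain PyTorch on a single process; it is not this repository's actual implementation, and the helper name, arguments, and layout are assumptions made for illustration:

```python
import torch


def model_parallel_cross_entropy(logit_chunks, labels, class_offsets):
    """Softmax cross-entropy over logits split column-wise across GPUs.

    logit_chunks  -- list of tensors, chunk i has shape [B, C_i] and lives on cuda:i
    labels        -- LongTensor of shape [B] with global class ids, on cuda:0
    class_offsets -- list of ints, global id of the first class in each chunk
    """
    device0 = logit_chunks[0].device

    # 1) per-sample global max (for numerical stability): only B scalars leave each GPU
    local_max = [chunk.max(dim=1).values.to(device0) for chunk in logit_chunks]
    global_max = torch.stack(local_max, dim=0).max(dim=0).values              # [B]

    # 2) per-sample global sum of exp: again only B scalars leave each GPU
    local_sum = [
        (chunk - global_max.to(chunk.device).unsqueeze(1)).exp().sum(dim=1).to(device0)
        for chunk in logit_chunks
    ]
    global_sum = torch.stack(local_sum, dim=0).sum(dim=0)                      # [B]

    # 3) fetch the target-class logit from whichever GPU owns that class
    target_logit = torch.zeros_like(global_max)
    for i, chunk in enumerate(logit_chunks):
        lo, hi = class_offsets[i], class_offsets[i] + chunk.size(1)
        owned = (labels >= lo) & (labels < hi)       # samples whose label lives on GPU i
        if owned.any():
            local_labels = (labels[owned] - lo).to(chunk.device)
            picked = chunk[owned.to(chunk.device)].gather(1, local_labels.unsqueeze(1))
            target_logit[owned] = picked.squeeze(1).to(device0)

    # 4) assemble the usual log-softmax cross-entropy from the pieces above
    return (global_sum.log() + global_max - target_logit).mean()
```

In a multi-process distributed setup, the same two per-sample reductions map naturally onto all-reduce operations.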

## How can I use this?

First, make sure you really need model parallelism:

- Does the number of labels in your dataset exceed 1 million?
- Is the last layer of your model a fully connected layer, and do you use CrossEntropyLoss?
- Do you have enough GPUs (at least 4 to 8)?

If the answer to all of the above questions is yes, you can consider using model parallelism. But since it requires hacking into the model and the optimizer, you will need to migrate the code into your own repository yourself:

- for normal training and mixed precision training, refer to the [master branch](https://github.com/bindog/pytorch-model-parallel/tree/master)
- for distributed training, refer to the [dist branch](https://github.com/bindog/pytorch-model-parallel/tree/dist)

## What about other deep learning frameworks?

The principle is the same; other frameworks such as MXNet have better built-in support (kvstore) for distributed training.

## Chinese blogs

- [http://bindog.github.io/blog/2019/09/05/gpu-memory-balanced-model-parallel/](http://bindog.github.io/blog/2019/09/05/gpu-memory-balanced-model-parallel/)
- [http://bindog.github.io/blog/2020/04/12/model-parallel-with-apex/](http://bindog.github.io/blog/2020/04/12/model-parallel-with-apex/)