
Commit 06d50b5

pritamdamania authored and facebook-github-bot committed
Pull in fairscale.nn.Pipe into PyTorch. (pytorch#44090)
Summary: Pull Request resolved: pytorch#44090

This is an initial commit pulling in the torchgpipe fork at https://github.com/facebookresearch/fairscale. The purpose of this commit is just to pull in the code and ensure all tests and builds work fine. We will gradually modify it to match the intended API described in https://fb.quip.com/txurAV3zIFox#RPZACAfAKMq; follow-up PRs will address the further changes needed on top of this initial commit.

The code lands in the `torch.distributed._pipeline.sync` package. The package is intentionally private, since a lot of work (e.g. docs, API changes) still has to go in before we can officially support it.

ghstack-source-id: 114864254

Test Plan:
1) waitforbuildbot
2) Ran all tests on my devgpu

Reviewed By: mrshenli

Differential Revision: D23493316

fbshipit-source-id: fe3c8b7dadeeb86abdc00e8a8652491b0b16743a
1 parent b63ddd6 commit 06d50b5
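
The tests added in this commit already exercise the vendored API. Below is a minimal usage sketch distilled from those tests (test_none_skip and test_1to3); it runs on CPU with toy layer sizes of my own choosing, and since `torch.distributed._pipeline.sync` is private, the exact API may still change in follow-up PRs.

```python
# Minimal sketch based on the tests in this commit; not an official example.
import torch
from torch import nn

from torch.distributed._pipeline.sync import Pipe

# A toy two-layer model; the sizes are illustrative only.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# One layer per partition, both partitions on CPU, and each mini-batch split
# into 2 micro-batches ("chunks") that flow through the pipeline.
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=2)

input = torch.rand(4, 8, device=model.devices[0], requires_grad=True)
output = model(input)      # forward pass through both partitions
output.mean().backward()   # backward runs through the pipeline as usual
```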


53 files changed: +6532, -7 lines

LICENSE (+8, -5)
@@ -16,23 +16,26 @@ Copyright (c) 2016-present, Facebook Inc. All rights reserved.
 
 All contributions by Facebook:
 Copyright (c) 2016 Facebook Inc.
-
+
 All contributions by Google:
 Copyright (c) 2015 Google Inc.
 All rights reserved.
-
+
 All contributions by Yangqing Jia:
 Copyright (c) 2015 Yangqing Jia
 All rights reserved.
-
+
+All contributions by Kakao Brain:
+Copyright 2019-2020 Kakao Brain
+
 All contributions from Caffe:
 Copyright(c) 2013, 2014, 2015, the respective contributors
 All rights reserved.
-
+
 All other contributions:
 Copyright(c) 2015, 2016 the respective contributors
 All rights reserved.
-
+
 Caffe2 uses a copyright model similar to Caffe: each contributor holds
 copyright over their contributions to Caffe2. The project versioning records
 all such contribution and copyright details. If a contributor wants to further

NOTICE (+3)
@@ -22,6 +22,9 @@ All contributions by Yangqing Jia:
 Copyright (c) 2015 Yangqing Jia
 All rights reserved.
 
+All contributions by Kakao Brain:
+Copyright 2019-2020 Kakao Brain
+
 All other contributions:
 Copyright(c) 2015, 2016 the respective contributors
 All rights reserved.

New file (+27 lines)
Copyright 2019-2020 Kakao Brain

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

 1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

 2. Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

 3. Neither the name of the copyright holder nor the names of its
    contributors may be used to endorse or promote products derived from this
    software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

New file (+8 lines)
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# tests/__init__.py lets pytest import the application without custom sys.path or PYTHONPATH.
# See also: https://docs.pytest.org/en/latest/goodpractices.html

New file (+37 lines)
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch


@pytest.fixture(autouse=True)
def manual_seed_zero():
    torch.manual_seed(0)


@pytest.fixture(scope="session")
def cuda_sleep():
    # Warm-up CUDA.
    torch.empty(1, device="cuda")

    # From test/test_cuda.py in PyTorch.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.cuda._sleep(1000000)
    end.record()
    end.synchronize()
    cycles_per_ms = 1000000 / start.elapsed_time(end)

    def cuda_sleep(seconds):
        torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))

    return cuda_sleep


def pytest_report_header():
    return f"torch: {torch.__version__}"
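
For context on how the `cuda_sleep` fixture above is meant to be consumed: it calibrates `cycles_per_ms` once per session and returns a closure that converts a duration in seconds into the corresponding `torch.cuda._sleep` cycle count. A hypothetical test using it could look like the sketch below; the test name and the 0.1-second duration are illustrative and not part of this commit.

```python
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_keeps_stream_busy(cuda_sleep):  # hypothetical test, for illustration only
    # Enqueue roughly 100 ms of busy-waiting on the default CUDA stream.
    # torch.cuda._sleep is asynchronous, so this call returns immediately.
    cuda_sleep(0.1)
    x = torch.ones(1, device="cuda")
    torch.cuda.synchronize()  # wait for the sleep (and the fill) to complete
    assert x.item() == 1.0
```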

New file (+6 lines)
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

New file (+45 lines)
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy

from torch import nn

from torch.distributed._pipeline.sync.skip import Namespace, skippable, stash


def test_namespace_difference():
    ns1 = Namespace()
    ns2 = Namespace()
    assert ns1 != ns2


def test_namespace_copy():
    ns = Namespace()
    assert copy.copy(ns) == ns
    assert copy.copy(ns) is not ns


def test_skippable_repr():
    @skippable(stash=["hello"])
    class Hello(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            yield stash("hello", x)
            return self.conv(x)  # noqa

    m = Hello()
    assert (
        repr(m)
        == """
@skippable(Hello(
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
))
""".strip()
    )

New file (+106 lines)
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch import nn

from torch.distributed._pipeline.sync import Pipe
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_1to3(balance, checkpoint):
    if torch.cuda.device_count() < len(balance):
        pytest.skip("at least %d cuda devices required" % len(balance))

    @skippable(stash=["1to3"])
    class Layer1(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            yield stash("1to3", input)
            output = self.conv(input)
            return output  # noqa

    class Layer2(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            output = self.conv(input)
            return output

    @skippable(pop=["1to3"])
    class Layer3(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            skip_1to3 = yield pop("1to3")
            output = self.conv(input) + skip_1to3
            return output

    model = nn.Sequential(Layer1(), Layer2(), Layer3())
    model = Pipe(model, balance, chunks=3, checkpoint=checkpoint)

    in_device = model.devices[0]
    out_device = model.devices[-1]

    input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
    output = model(input)
    loss = output.mean()
    loss.backward()

    assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
    assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))


def test_none_skip():
    @skippable(stash=["none"])
    class Stash(nn.Module):
        def forward(self, input):
            yield stash("none", None)
            return input  # noqa

    @skippable(pop=["none"])
    class Pop(nn.Module):
        def forward(self, input):
            none = yield pop("none")
            assert none is None
            return input

    model = nn.Sequential(Stash(), Pop())
    model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5)

    input = torch.rand(10, requires_grad=True)
    output = model(input)

    def assert_grad_fn_is_not_portal(grad_fn, visited=None):
        if visited is None:
            visited = set()
        if grad_fn in visited or grad_fn is None:
            return

        assert not isinstance(grad_fn, PortalBlue._backward_cls)
        assert not isinstance(grad_fn, PortalCopy._backward_cls)
        assert not isinstance(grad_fn, PortalOrange._backward_cls)

        visited.add(grad_fn)
        for next_grad_fn, _ in grad_fn.next_functions:
            assert_grad_fn_is_not_portal(next_grad_fn, visited)

    assert_grad_fn_is_not_portal(output.grad_fn)

    output.sum().backward()
    assert input.grad.mean().item() == 1

New file (+111 lines)
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn

from torch.distributed._pipeline.sync.skip import Namespace, pop, skippable, stash
from torch.distributed._pipeline.sync.skip.layout import inspect_skip_layout


class Pass(nn.Module):
    def forward(self, input):
        return input


@skippable(stash=["foo"])
class StashFoo(nn.Module):
    def forward(self, input):
        yield stash("foo", input)
        return input  # noqa


@skippable(pop=["foo"])
class PopFoo(nn.Module):
    def forward(self, input):
        foo = yield pop("foo")
        return input + foo


@skippable(stash=["bar"])
class StashBar(nn.Module):
    def forward(self, input):
        yield stash("bar", input)
        return input  # noqa


@skippable(pop=["bar"])
class PopBar(nn.Module):
    def forward(self, input):
        bar = yield pop("bar")
        return input + bar


def test_no_skippables():
    p1 = nn.Sequential(Pass())
    p2 = nn.Sequential(Pass())

    layout = inspect_skip_layout([p1, p2])
    policy = [list(layout.copy_policy(i)) for i in range(2)]

    assert policy == [[], []]


def test_inner_partition():
    p1 = nn.Sequential(StashFoo(), PopFoo())
    p2 = nn.Sequential(Pass())

    layout = inspect_skip_layout([p1, p2])
    policy = [list(layout.copy_policy(i)) for i in range(2)]

    assert policy == [[], []]


def test_adjoining_partitions():
    p1 = nn.Sequential(StashFoo())
    p2 = nn.Sequential(PopFoo())

    layout = inspect_skip_layout([p1, p2])
    policy = [list(layout.copy_policy(i)) for i in range(2)]

    assert policy == [[], [(0, None, "foo")]]


def test_far_partitions():
    p1 = nn.Sequential(StashFoo())
    p2 = nn.Sequential(Pass())
    p3 = nn.Sequential(PopFoo())

    layout = inspect_skip_layout([p1, p2, p3])
    policy = [list(layout.copy_policy(i)) for i in range(3)]

    assert policy == [[], [], [(0, None, "foo")]]


def test_pop_2_from_different_partitions():
    p1 = nn.Sequential(StashFoo())
    p2 = nn.Sequential(StashBar())
    p3 = nn.Sequential(PopBar(), PopFoo())

    layout = inspect_skip_layout([p1, p2, p3])
    policy = [list(layout.copy_policy(i)) for i in range(3)]

    # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index.
    assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]]


def test_namespace():
    ns1 = Namespace()
    ns2 = Namespace()

    p1 = nn.Sequential(StashFoo().isolate(ns1))
    p2 = nn.Sequential(StashFoo().isolate(ns2))
    p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))

    layout = inspect_skip_layout([p1, p2, p3])
    policy = [list(layout.copy_policy(i)) for i in range(3)]

    # p3 pops ns2's 'foo' before ns1's 'foo', but the plan is sorted by source partition index.
    assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]]
