Skip to content

Commit e363d73

Browse files
authored
Add utilities for conv1d support (#78)
* add flattened conv1d Signed-off-by: Hao Wu <[email protected]>
1 parent 7d604f1 commit e363d73

File tree

4 files changed

+216
-0
lines changed

4 files changed

+216
-0
lines changed

docs/apidocs/utils.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,10 @@ emerging_optimizers.utils.eig
1414
=============================
1515
.. automodule:: emerging_optimizers.utils.eig
1616
:members:
17+
18+
19+
emerging_optimizers.utils.modules
20+
=================================
21+
.. automodule:: emerging_optimizers.utils.modules
22+
:members:
1723
```
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import math
17+
from typing import Any, Self
18+
19+
import torch
20+
import torch.nn as nn
21+
import torch.nn.functional as F
22+
23+
24+
class Conv1dFlatWeights(nn.Conv1d):
25+
"""Conv1d with weights+bias stored in a single 2D tensor
26+
27+
There are conv1d used in some LLM, in mamba mixer for example. Because the weight is not 2d, we cannot apply
28+
many of the emerging optimizers originally introduced for 2d weights of Linear layers without bias. Since
29+
convolution can be viewed as a matrix multiplication with im2col (either implicit or explicit), we can flatten
30+
the weight into a single 2D tensor and then apply the emerging optimizers to it.
31+
32+
Bias is not commonly used in most LLM's anymore, but they are often included in this type of conv1d.
33+
Since bias is mathematically the 0 order term of the polynomial, we can combine weight and bias into a
34+
single 2D tensor.
35+
36+
Arguments are the same as ::class:`torch.nn.Conv1d`.
37+
38+
Note:
39+
This implementation potentially introduces a small overhead because of split weights can combining gradients
40+
of it. This should be trivial compared to computational cost of LLM training. If it becomes a concern, a
41+
kernel can be developed to eliminate the overhead.
42+
43+
Note:
44+
Similar flattening logic can be applied to N-D convolution. But since we don't have use cases of them in LLM
45+
yet, they are not supported despite the __init__() function is generalized enough to support N-D convolution.
46+
47+
"""
48+
49+
def __init__(self, *args: Any, **kwargs: Any) -> None:
50+
super().__init__(*args, **kwargs)
51+
52+
assert self.padding_mode == "zeros", "Only zeros padding is supported"
53+
54+
self.weight: nn.Parameter[torch.Tensor]
55+
self.bias: nn.Parameter[torch.Tensor] | None | str
56+
57+
flat_weight_shape = [self.out_channels, math.prod(self.weight.shape[1:])]
58+
if self.bias is not None:
59+
flat_weight_shape[1] += 1
60+
flat_weight_buffer = torch.empty(flat_weight_shape, device=self.weight.device, dtype=self.weight.dtype)
61+
if self.bias is not None:
62+
flat_weight_buffer[..., :-1].copy_(self.weight.view(self.out_channels, -1))
63+
flat_weight_buffer[..., -1].copy_(self.bias)
64+
del self.bias
65+
self.has_bias = True
66+
self.bias = "dummy" # Trick con1d.extra_repr() to not print bias=False
67+
else:
68+
flat_weight_buffer.copy_(self.weight.view(self.out_channels, -1))
69+
self.has_bias = False
70+
del self.weight
71+
72+
self.weight = nn.Parameter(flat_weight_buffer)
73+
74+
@classmethod
75+
def from_conv1d(cls, conv1d: nn.Conv1d) -> Self:
76+
conv1d_flat = cls(
77+
in_channels=conv1d.in_channels,
78+
out_channels=conv1d.out_channels,
79+
kernel_size=conv1d.kernel_size,
80+
bias=conv1d.bias is not None,
81+
stride=conv1d.stride,
82+
padding=conv1d.padding,
83+
dilation=conv1d.dilation,
84+
groups=conv1d.groups,
85+
padding_mode=conv1d.padding_mode,
86+
device=conv1d.weight.device,
87+
dtype=conv1d.weight.dtype,
88+
)
89+
90+
if conv1d.bias is not None:
91+
conv1d_flat.weight.data[..., :-1].copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
92+
conv1d_flat.weight.data[..., -1].copy_(conv1d.bias.data)
93+
else:
94+
conv1d_flat.weight.data.copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
95+
return conv1d_flat
96+
97+
@property
98+
def weight_shape(self) -> tuple[int, int, int]:
99+
return (self.out_channels, self.in_channels // self.groups, self.kernel_size[0])
100+
101+
def forward(self, x: torch.Tensor) -> torch.Tensor:
102+
if self.has_bias:
103+
weight = self.weight[..., :-1].view(self.weight_shape)
104+
bias = self.weight[..., -1]
105+
else:
106+
weight = self.weight.view(self.weight_shape)
107+
bias = None
108+
109+
return F.conv1d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
110+
111+
def extra_repr(self) -> str:
112+
base_repr = super().extra_repr()
113+
return f"{base_repr}, flattened_param_shape={tuple(self.weight.shape)}"

tests/ci/L0_Tests_GPU.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ coverage run -p --source=emerging_optimizers tests/normalized_optimizer_converge
3030
coverage run -p --source=emerging_optimizers tests/test_psgd_contractions.py --device=cuda -v -2 || error=1
3131
coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda -v -2 || error=1
3232
coverage run -p --source=emerging_optimizers tests/test_psgd_convergence.py --device=cuda -v -2 || error=1
33+
coverage run -p --source=emerging_optimizers tests/test_utils_modules.py -v -2 || error=1
3334

3435
exit "${error}"

tests/test_utils_modules.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import torch
17+
import torch.nn as nn
18+
from absl.testing import absltest, parameterized
19+
20+
from emerging_optimizers.utils.modules import Conv1dFlatWeights
21+
22+
23+
class TestConv1dFlatWeights(parameterized.TestCase):
    """Checks that ``Conv1dFlatWeights`` matches ``nn.Conv1d`` exactly (outputs and gradients)."""

    def _assert_matches_reference(self, conv, conv_flat, x, has_bias):
        """Compare forward outputs and parameter gradients of ``conv_flat`` against ``conv`` bit-for-bit."""
        y_ref = conv(x)
        y_test = conv_flat(x)

        torch.testing.assert_close(y_ref, y_test, atol=0, rtol=0)

        y_ref.sum().backward()
        y_test.sum().backward()
        if has_bias:
            # With a bias, the last column of the flat gradient is the bias gradient.
            torch.testing.assert_close(
                conv.weight.grad.view(-1), conv_flat.weight.grad[:, :-1].reshape(-1), atol=0, rtol=0
            )
            torch.testing.assert_close(conv.bias.grad, conv_flat.weight.grad[:, -1], atol=0, rtol=0)
        else:
            torch.testing.assert_close(conv.weight.grad.view(-1), conv_flat.weight.grad.reshape(-1), atol=0, rtol=0)

    @parameterized.product(
        in_channels=[3, 5, 7],
        out_channels=[4, 6, 8],
        kernel_size=[2, 3, 4],
        bias=[False, True],
        batch_size=[4, 5, 6],
    )
    def test_matches_conv1d(self, in_channels, out_channels, kernel_size, bias, batch_size):
        kwargs = dict(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias, device="cuda"
        )
        # Seeding identically before each constructor makes both layers start from the same random values.
        torch.manual_seed(42)
        conv = nn.Conv1d(**kwargs)
        torch.manual_seed(42)
        conv_flat = Conv1dFlatWeights(**kwargs)

        self.assertEqual(conv_flat.weight.dim(), 2)

        x = torch.randn(batch_size, in_channels, kernel_size, device="cuda")
        self._assert_matches_reference(conv, conv_flat, x, bias)

    @parameterized.product(
        bias=[False, True],
    )
    def test_extra_repr(self, bias):
        conv_flat = Conv1dFlatWeights(in_channels=3, out_channels=4, kernel_size=2, bias=bias)
        text = repr(conv_flat)

        # 3 input channels * kernel size 2, plus one extra column when a bias is folded in.
        flat_cols = 3 * 2 + int(bias)
        self.assertIn(f"flattened_param_shape=(4, {flat_cols})", text)
        if bias:
            self.assertNotIn("bias=False", text)
        else:
            self.assertIn("bias=False", text)

    @parameterized.product(
        in_channels=[3, 5, 7],
        out_channels=[4, 6, 8],
        kernel_size=[2, 3, 4],
        bias=[False, True],
        batch_size=[4, 5, 6],
    )
    def test_from_conv1d(self, in_channels, out_channels, kernel_size, bias, batch_size):
        kwargs = dict(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias, device="cuda"
        )
        torch.manual_seed(42)
        conv = nn.Conv1d(**kwargs)
        # Deliberately NOT re-seeding here: with the same seed the layer built inside from_conv1d() would be
        # initialized to the same values as ``conv``, and the test would pass even if the copy were broken.
        conv_flat = Conv1dFlatWeights.from_conv1d(conv)

        x = torch.randn(batch_size, in_channels, kernel_size, device="cuda")
        self._assert_matches_reference(conv, conv_flat, x, bias)


if __name__ == "__main__":
    absltest.main()

0 commit comments

Comments
 (0)