Skip to content

Commit b4b1dfa

Browse files
committed
Add a hook to update nvidia params
If required, this hook creates a modified params file (with ModifyDeviceFiles: 0) in a tmpfs and mounts this over /proc/driver/nvidia/params. This prevents device node creation when running tools such as nvidia-smi. Signed-off-by: Evan Lezar <[email protected]>
1 parent 6463390 commit b4b1dfa

File tree

6 files changed

+355
-0
lines changed

6 files changed

+355
-0
lines changed

cmd/nvidia-cdi-hook/commands/commands.go

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
2323
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
2424
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
25+
nvidiaparams "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-nvidia-params"
2526
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2627
)
2728

@@ -32,5 +33,6 @@ func New(logger logger.Interface) []*cli.Command {
3233
ldcache.NewCommand(logger),
3334
symlinks.NewCommand(logger),
3435
chmod.NewCommand(logger),
36+
nvidiaparams.NewCommand(logger),
3537
}
3638
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//go:build linux
2+
// +build linux
3+
4+
/**
5+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
**/
19+
20+
package nvidiaparams
21+
22+
import (
23+
"fmt"
24+
25+
"golang.org/x/sys/unix"
26+
)
27+
28+
func createTmpFs(target string, size int) error {
29+
return unix.Mount("tmpfs", target, "tmpfs", 0, fmt.Sprintf("size=%d", size))
30+
}
31+
32+
func bindMountReadonly(source string, target string) error {
33+
return unix.Mount(source, target, "", unix.MS_BIND|unix.MS_RDONLY|unix.MS_NOSYMFOLLOW, "")
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//go:build !linux
2+
// +build !linux
3+
4+
/**
5+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
**/
19+
20+
package nvidiaparams
21+
22+
import (
23+
"fmt"
24+
)
25+
26+
// createTmpFs is the stub compiled on non-Linux platforms. It performs no
// mount and always reports that the operation is not supported.
func createTmpFs(target string, size int) error {
	err := fmt.Errorf("not supported")
	return err
}
29+
30+
// bindMountReadonly is the stub compiled on non-Linux platforms. It performs
// no mount and always reports that the operation is not supported.
func bindMountReadonly(source string, target string) error {
	err := fmt.Errorf("not supported")
	return err
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvidiaparams
18+
19+
import (
20+
"bufio"
21+
"bytes"
22+
"errors"
23+
"fmt"
24+
"io"
25+
"os"
26+
"path/filepath"
27+
"strings"
28+
29+
"github.com/urfave/cli/v2"
30+
31+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
32+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
33+
)
34+
35+
// nvidiaDriverParamsPath is the path of the NVIDIA driver params file that
// this hook reads and mounts over inside the container.
const nvidiaDriverParamsPath = "/proc/driver/nvidia/params"
38+
39+
type command struct {
40+
logger logger.Interface
41+
}
42+
43+
type options struct {
44+
containerSpec string
45+
}
46+
47+
// NewCommand constructs an update-nvidia-params command with the specified logger
48+
func NewCommand(logger logger.Interface) *cli.Command {
49+
c := command{
50+
logger: logger,
51+
}
52+
return c.build()
53+
}
54+
55+
// build the update-nvidia-params command
56+
func (m command) build() *cli.Command {
57+
cfg := options{}
58+
59+
// Create the 'update-nvidia-params' command
60+
c := cli.Command{
61+
Name: "update-nvidia-params",
62+
Usage: "Update the /proc/driver/nvidia/params file in the container to disable device node modification.",
63+
Before: func(c *cli.Context) error {
64+
return m.validateFlags(c, &cfg)
65+
},
66+
Action: func(c *cli.Context) error {
67+
return m.run(c, &cfg)
68+
},
69+
}
70+
71+
c.Flags = []cli.Flag{
72+
&cli.StringFlag{
73+
Name: "container-spec",
74+
Hidden: true,
75+
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
76+
Destination: &cfg.containerSpec,
77+
},
78+
}
79+
80+
return &c
81+
}
82+
83+
func (m command) validateFlags(c *cli.Context, cfg *options) error {
84+
return nil
85+
}
86+
87+
func (m command) run(c *cli.Context, cfg *options) error {
88+
s, err := oci.LoadContainerState(cfg.containerSpec)
89+
if err != nil {
90+
return fmt.Errorf("failed to load container state: %v", err)
91+
}
92+
93+
containerRoot, err := s.GetContainerRoot()
94+
if err != nil {
95+
return fmt.Errorf("failed to determined container root: %v", err)
96+
}
97+
98+
return m.updateNvidiaParams(containerRoot)
99+
}
100+
101+
func (m command) updateNvidiaParams(containerRoot string) error {
102+
// TODO: Do we need to prefix the driver root?
103+
currentParamsFile, err := os.Open(nvidiaDriverParamsPath)
104+
if errors.Is(err, os.ErrNotExist) {
105+
return nil
106+
}
107+
if err != nil {
108+
return fmt.Errorf("failed to load params file: %w", err)
109+
}
110+
defer currentParamsFile.Close()
111+
112+
return m.updateNvidiaParamsFromReader(currentParamsFile, containerRoot)
113+
}
114+
115+
func (m command) updateNvidiaParamsFromReader(r io.Reader, containerRoot string) error {
116+
modifiedContents, err := m.getModifiedParamsFileContentsFromReader(r)
117+
if err != nil {
118+
return fmt.Errorf("failed to generate modified contents: %w", err)
119+
}
120+
if len(modifiedContents) == 0 {
121+
m.logger.Debugf("No modification required")
122+
return nil
123+
}
124+
return createParamsFileInContainer(containerRoot, modifiedContents)
125+
}
126+
127+
// getModifiedParamsFileContentsFromReader returns the contents of a modified params file from the specified reader.
128+
func (m command) getModifiedParamsFileContentsFromReader(r io.Reader) ([]byte, error) {
129+
var modified bytes.Buffer
130+
scanner := bufio.NewScanner(r)
131+
132+
var requiresModification bool
133+
for scanner.Scan() {
134+
line := scanner.Text()
135+
if strings.HasPrefix(line, "ModifyDeviceFiles: ") {
136+
if line == "ModifyDeviceFiles: 0" {
137+
m.logger.Debugf("Device node modification is already disabled")
138+
return nil, nil
139+
}
140+
if line == "ModifyDeviceFiles: 1" {
141+
line = "ModifyDeviceFiles: 0"
142+
requiresModification = true
143+
}
144+
}
145+
if _, err := modified.WriteString(line + "\n"); err != nil {
146+
return nil, fmt.Errorf("failed to create output buffer: %w", err)
147+
}
148+
}
149+
if err := scanner.Err(); err != nil {
150+
return nil, fmt.Errorf("failed to read params file: %w", err)
151+
}
152+
153+
if !requiresModification {
154+
return nil, nil
155+
}
156+
157+
return modified.Bytes(), nil
158+
}
159+
160+
func createParamsFileInContainer(containerRoot string, contents []byte) error {
161+
if len(contents) == 0 {
162+
return nil
163+
}
164+
165+
tempParamsFileName, err := createFileInTempfs("nvct-params", contents, 0o444)
166+
if err != nil {
167+
return fmt.Errorf("failed to create temporary file: %w", err)
168+
}
169+
170+
if err := bindMountReadonly(tempParamsFileName, filepath.Join(containerRoot, nvidiaDriverParamsPath)); err != nil {
171+
return fmt.Errorf("failed to create temporary parms file mount: %w", err)
172+
}
173+
174+
return nil
175+
}
176+
177+
// createFileInTempfs creates a file with the specified name, contents, and mode in a tmpfs.
178+
// A tmpfs is created at /tmp/nvct-emtpy-dir* with a size sufficient for the specified contents.
179+
func createFileInTempfs(name string, contents []byte, mode os.FileMode) (string, error) {
180+
tmpRoot, err := os.MkdirTemp("", "nvct-empty-dir*")
181+
if err != nil {
182+
return "", fmt.Errorf("failed to create temporary folder: %w", err)
183+
}
184+
if err := createTmpFs(tmpRoot, len(contents)); err != nil {
185+
return "", fmt.Errorf("failed to create tmpfs mount for params file: %w", err)
186+
}
187+
188+
filename := filepath.Join(tmpRoot, name)
189+
fileInTempfs, err := os.Create(filename)
190+
if err != nil {
191+
return "", fmt.Errorf("failed to create temporary params file: %w", err)
192+
}
193+
defer fileInTempfs.Close()
194+
195+
if _, err := fileInTempfs.Write(contents); err != nil {
196+
return "", fmt.Errorf("failed to write temporary params file: %w", err)
197+
}
198+
199+
if err := fileInTempfs.Chmod(mode); err != nil {
200+
return "", fmt.Errorf("failed to set permissions on temporary params file: %w", err)
201+
}
202+
return filename, nil
203+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package nvidiaparams
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
testlog "github.com/sirupsen/logrus/hooks/test"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestGetModifiedParamsFileContentsFromReader(t *testing.T) {
12+
logger, _ := testlog.NewNullLogger()
13+
testCases := map[string]struct {
14+
contents []byte
15+
expectedError error
16+
expectedContents []byte
17+
}{
18+
"no contents": {
19+
contents: nil,
20+
expectedError: nil,
21+
expectedContents: nil,
22+
},
23+
"other contents are ignored": {
24+
contents: []byte(`# Some other content
25+
that we don't care about
26+
`),
27+
expectedError: nil,
28+
expectedContents: nil,
29+
},
30+
"already zero requires no modification": {
31+
contents: []byte("ModifyDeviceFiles: 0"),
32+
expectedError: nil,
33+
expectedContents: nil,
34+
},
35+
"leading spaces require no modification": {
36+
contents: []byte(" ModifyDeviceFiles: 1"),
37+
},
38+
"Trailing spaces require no modification": {
39+
contents: []byte("ModifyDeviceFiles: 1 "),
40+
},
41+
"Not 1 require no modification": {
42+
contents: []byte("ModifyDeviceFiles: 11"),
43+
},
44+
"single line requires modification": {
45+
contents: []byte("ModifyDeviceFiles: 1"),
46+
expectedError: nil,
47+
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
48+
},
49+
"single line with trailing newline requires modification": {
50+
contents: []byte("ModifyDeviceFiles: 1\n"),
51+
expectedError: nil,
52+
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
53+
},
54+
"other content is maintained": {
55+
contents: []byte(`ModifyDeviceFiles: 1
56+
other content
57+
that
58+
is maintained`),
59+
expectedError: nil,
60+
expectedContents: []byte(`ModifyDeviceFiles: 0
61+
other content
62+
that
63+
is maintained
64+
`),
65+
},
66+
}
67+
68+
for description, tc := range testCases {
69+
t.Run(description, func(t *testing.T) {
70+
c := command{
71+
logger: logger,
72+
}
73+
contents, err := c.getModifiedParamsFileContentsFromReader(bytes.NewReader(tc.contents))
74+
require.EqualValues(t, tc.expectedError, err)
75+
require.EqualValues(t, string(tc.expectedContents), string(contents))
76+
})
77+
}
78+
79+
}

cmd/nvidia-ctk-installer/container/toolkit/toolkit_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ containerEdits:
8787
- /lib/x86_64-linux-gnu
8888
hookName: createContainer
8989
path: {{ .toolkitRoot }}/nvidia-cdi-hook
90+
- args:
91+
- nvidia-cdi-hook
92+
- update-nvidia-params
93+
hookName: createContainer
94+
path: {{ .toolkitRoot }}/nvidia-cdi-hook
9095
mounts:
9196
- containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
9297
hostPath: /host/driver/root/lib/x86_64-linux-gnu/libcuda.so.999.88.77

0 commit comments

Comments
 (0)