diff --git a/cmd/cloud_orchestrator/main.go b/cmd/cloud_orchestrator/main.go index ed5c7f26..7f6e208b 100644 --- a/cmd/cloud_orchestrator/main.go +++ b/cmd/cloud_orchestrator/main.go @@ -68,7 +68,10 @@ func LoadInstanceManager(config *config.Config) instances.Manager { if err != nil { log.Fatal("Failed to get docker client: ", err) } - im = instances.NewDockerInstanceManager(config.InstanceManager, cli) + im, err = instances.NewDockerInstanceManager(config.InstanceManager, cli) + if err != nil { + log.Fatal("Failed to create Docker Instance Manager: ", err) + } default: log.Fatal("Unknown Instance Manager type: ", config.InstanceManager.Type) } diff --git a/pkg/app/instances/docker.go b/pkg/app/instances/docker.go index 849b8737..0708e31f 100644 --- a/pkg/app/instances/docker.go +++ b/pkg/app/instances/docker.go @@ -41,9 +41,12 @@ const DockerIMType IMType = "docker" type DockerIMConfig struct { DockerImageName string + GpuManufacturer string HostOrchestratorPort int } +const gpuManufacturerNvidia = "nvidia" + const ( dockerLabelCreatedBy = "created_by" dockerLabelKeyManagedBy = "managed_by" @@ -66,11 +69,14 @@ const ( DeleteHostOPType OPType = "deletehost" ) -func NewDockerInstanceManager(cfg Config, cli *client.Client) *DockerInstanceManager { +func NewDockerInstanceManager(cfg Config, cli *client.Client) (*DockerInstanceManager, error) { + if cfg.Docker.GpuManufacturer != "" && cfg.Docker.GpuManufacturer != gpuManufacturerNvidia { + return nil, fmt.Errorf("unsupported GPU manufacturer: %q", cfg.Docker.GpuManufacturer) + } return &DockerInstanceManager{ Config: cfg, Client: cli, - } + }, nil } func (m *DockerInstanceManager) ListZones() (*apiv1.ListZonesResponse, error) { @@ -371,6 +377,9 @@ func (m *DockerInstanceManager) createDockerContainer(ctx context.Context, user Tty: true, Labels: dockerLabelsDict(user), } + if m.Config.Docker.GpuManufacturer == gpuManufacturerNvidia { + config.Env = []string{"NVIDIA_DRIVER_CAPABILITIES=all"} + } hostConfig := &container.HostConfig{ Mounts: []mount.Mount{ { @@ -381,6 +390,17 @@ func (m *DockerInstanceManager) createDockerContainer(ctx context.Context, user }, Privileged: true, } + if m.Config.Docker.GpuManufacturer == gpuManufacturerNvidia { + hostConfig.Resources = container.Resources{ + DeviceRequests: []container.DeviceRequest{ + { + Count: -1, + Capabilities: [][]string{{"gpu"}}, + }, + }, + } + hostConfig.Runtime = "nvidia" + } createRes, err := m.Client.ContainerCreate(ctx, config, hostConfig, nil, nil, "") if err != nil { return "", fmt.Errorf("failed to create docker container: %w", err)