Skip to content
Snippets Groups Projects
Commit 555176e7 authored by Stan Hu's avatar Stan Hu Committed by Steve Azzopardi
Browse files

Add support for using GPUs in Google Compute Engine

Add support for configuring Google maintenance policy

This is needed since GPU-accelerated instances cannot have a MIGRATE
maintenance policy:
https://cloud.google.com/compute/docs/instances/live-migration

Relates to https://gitlab.com/gitlab-org/ci-cd/docker-machine/-/issues/34
parent 9f6c4a9a
Branches
Tags
No related merge requests found
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time" "time"
...@@ -41,6 +42,8 @@ type ComputeUtil struct { ...@@ -41,6 +42,8 @@ type ComputeUtil struct {
SwarmHost string SwarmHost string
openPorts []string openPorts []string
minCPUPlatform string minCPUPlatform string
accelerator string
maintenancePolicy string
operationBackoffFactory *backoffFactory operationBackoffFactory *backoffFactory
} }
...@@ -84,9 +87,48 @@ func newComputeUtil(driver *Driver) (*ComputeUtil, error) { ...@@ -84,9 +87,48 @@ func newComputeUtil(driver *Driver) (*ComputeUtil, error) {
openPorts: driver.OpenPorts, openPorts: driver.OpenPorts,
operationBackoffFactory: driver.OperationBackoffFactory, operationBackoffFactory: driver.OperationBackoffFactory,
minCPUPlatform: driver.MinCPUPlatform, minCPUPlatform: driver.MinCPUPlatform,
accelerator: driver.Accelerator,
maintenancePolicy: driver.MaintenancePolicy,
}, nil }, nil
} }
// acceleratorCountAndType parses the accelerator specification stored in
// c.accelerator, a comma-separated list of key=value pairs such as
// "count=2,type=nvidia-tesla-p100".
//
// It returns the number of accelerators and the accelerator type. The count
// defaults to 1 when no "count" key is present. (0, "") is returned when the
// specification is empty or only whitespace, when the count is not a positive
// integer, or when an unknown key is present. Segments that are not exactly
// one key=value pair are logged and skipped. The returned type is empty when
// no "type" key was supplied, so callers must check both values before use.
func (c *ComputeUtil) acceleratorCountAndType() (int, string) {
	// Trim before the emptiness check so a whitespace-only value counts as
	// "not configured" instead of yielding the default count of 1.
	spec := strings.TrimSpace(c.accelerator)
	if spec == "" {
		return 0, ""
	}

	count := 1
	acceleratorType := ""

	for _, kvStr := range strings.Split(spec, ",") {
		kv := strings.Split(kvStr, "=")
		if len(kv) != 2 {
			log.Infof("Invalid key/value parameter for accelerator: %s, ignoring", kvStr)
			continue
		}

		key, value := strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1])

		switch key {
		case "count":
			parsed, err := strconv.Atoi(value)
			// Zero or negative counts cannot describe an attachable
			// accelerator, so they are rejected like unparsable input.
			if err != nil || parsed < 1 {
				log.Infof("Failed to parse %q as count, disabling accelerator", value)
				return 0, ""
			}
			count = parsed
		case "type":
			acceleratorType = value
		default:
			log.Infof("Invalid accelerator defined %q, should be count=N,type=type", c.accelerator)
			return 0, ""
		}
	}

	return count, acceleratorType
}
func (c *ComputeUtil) diskName() string { func (c *ComputeUtil) diskName() string {
return c.instanceName + "-disk" return c.instanceName + "-disk"
} }
...@@ -292,6 +334,21 @@ func (c *ComputeUtil) createInstance(d *Driver) error { ...@@ -292,6 +334,21 @@ func (c *ComputeUtil) createInstance(d *Driver) error {
Metadata: metadata, Metadata: metadata,
} }
if c.maintenancePolicy != "" {
instance.Scheduling.OnHostMaintenance = c.maintenancePolicy
}
acceleratorCount, acceleratorType := c.acceleratorCountAndType()
if acceleratorCount > 0 && len(acceleratorType) > 0 {
instance.GuestAccelerators = []*raw.AcceleratorConfig{
{
AcceleratorCount: int64(acceleratorCount),
AcceleratorType: "https://www.googleapis.com/compute/v1/projects/" + c.project + "/zones/" + c.zone + "/acceleratorTypes/" + acceleratorType,
},
}
}
if strings.Contains(c.subnetwork, "/subnetworks/") { if strings.Contains(c.subnetwork, "/subnetworks/") {
instance.NetworkInterfaces[0].Subnetwork = c.subnetwork instance.NetworkInterfaces[0].Subnetwork = c.subnetwork
} else if c.subnetwork != "" { } else if c.subnetwork != "" {
......
...@@ -407,3 +407,62 @@ func prepareMetadataFile(t *testing.T, key string, content string) *os.File { ...@@ -407,3 +407,62 @@ func prepareMetadataFile(t *testing.T, key string, content string) *os.File {
return file return file
} }
// TestAccelerator runs acceleratorCountAndType against a table of accelerator
// specification strings and checks the returned count and type for each one.
func TestAccelerator(t *testing.T) {
	type scenario struct {
		name          string
		accelerator   string
		expectedCount int
		expectedType  string
	}

	scenarios := []scenario{
		// An unset accelerator yields no accelerator at all.
		{name: "unspecified", accelerator: "", expectedCount: 0, expectedType: ""},
		// The count defaults to 1 when only the type is given.
		{name: "GPU type", accelerator: "type=nvidia-tesla-p100", expectedCount: 1, expectedType: "nvidia-tesla-p100"},
		{name: "count and GPU type", accelerator: "count=2,type=nvidia-tesla-p100", expectedCount: 2, expectedType: "nvidia-tesla-p100"},
		// Whitespace around the specification and its segments is tolerated.
		{name: "count and GPU type with whitespace", accelerator: " count=2, type=nvidia-tesla-p100 ", expectedCount: 2, expectedType: "nvidia-tesla-p100"},
		// An unrecognised key disables the accelerator.
		{name: "unknown key=value pair", accelerator: "hello=world", expectedCount: 0, expectedType: ""},
		// A segment that is not key=value is skipped; the rest still applies.
		{name: "extraneous key=value pair", accelerator: "count=2,type=nvidia-tesla-p100,5", expectedCount: 2, expectedType: "nvidia-tesla-p100"},
		// A non-numeric count disables the accelerator.
		{name: "invalid count", accelerator: "count=ten,type=nvidia-tesla-p100", expectedCount: 0, expectedType: ""},
		// A missing type leaves the type empty while the count is kept.
		{name: "blank GPU type", accelerator: "count=10,", expectedCount: 10, expectedType: ""},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.name, func(t *testing.T) {
			util := &ComputeUtil{accelerator: sc.accelerator}

			gotCount, gotType := util.acceleratorCountAndType()

			assert.Equal(t, sc.expectedCount, gotCount)
			assert.Equal(t, sc.expectedType, gotType)
		})
	}
}
...@@ -61,6 +61,8 @@ type Driver struct { ...@@ -61,6 +61,8 @@ type Driver struct {
Labels []string Labels []string
Metadata metadataMap Metadata metadataMap
MetadataFromFile metadataMap MetadataFromFile metadataMap
Accelerator string
MaintenancePolicy string
OperationBackoffFactory *backoffFactory OperationBackoffFactory *backoffFactory
} }
...@@ -77,6 +79,8 @@ const ( ...@@ -77,6 +79,8 @@ const (
defaultNetwork = "default" defaultNetwork = "default"
defaultSubnetwork = "" defaultSubnetwork = ""
defaultMinCPUPlatform = "" defaultMinCPUPlatform = ""
defaultAccelerator = ""
defaultMaintenancePolicy = ""
defaultGoogleOperationBackoffInitialInterval = 1 defaultGoogleOperationBackoffInitialInterval = 1
defaultGoogleOperationBackoffRandomizationFactor = "0.5" defaultGoogleOperationBackoffRandomizationFactor = "0.5"
...@@ -232,6 +236,18 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { ...@@ -232,6 +236,18 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag {
Name: "google-metadata-from-file", Name: "google-metadata-from-file",
Usage: "Path to a file containing the metadata value inform of key=path/to/file. Use multiple times for multiple settings", Usage: "Path to a file containing the metadata value inform of key=path/to/file. Use multiple times for multiple settings",
}, },
mcnflag.StringFlag{
Name: "google-accelerator",
Usage: "Count and specific type of GPU accelerators (format: count=N,type=type) to attach to the instance, e.g. count=1,type=nvidia-tesla-p100",
EnvVar: "GOOGLE_ACCELERATOR",
Value: defaultAccelerator,
},
mcnflag.StringFlag{
Name: "google-maintenance-policy",
Usage: "Defines the maintenance behavior for this instance, e.g, MIGRATE or TERMINATE",
EnvVar: "GOOGLE_MAINTENANCE_POLICY",
Value: defaultMaintenancePolicy,
},
} }
} }
...@@ -303,6 +319,8 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error { ...@@ -303,6 +319,8 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
d.Labels = flags.StringSlice("google-label") d.Labels = flags.StringSlice("google-label")
d.Metadata = metadataMapFromStringSlice(flags.StringSlice("google-metadata")) d.Metadata = metadataMapFromStringSlice(flags.StringSlice("google-metadata"))
d.MetadataFromFile = metadataMapFromStringSlice(flags.StringSlice("google-metadata-from-file")) d.MetadataFromFile = metadataMapFromStringSlice(flags.StringSlice("google-metadata-from-file"))
d.Accelerator = flags.String("google-accelerator")
d.MaintenancePolicy = flags.String("google-maintenance-policy")
} }
d.SSHUser = flags.String("google-username") d.SSHUser = flags.String("google-username")
d.SSHPort = 22 d.SSHPort = 22
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment