diff --git a/drivers/google/compute_util.go b/drivers/google/compute_util.go index d3bb691828ae840e01aed5ac747b2e0e7de839a9..07de0e21c210f312a66092e4864ff5af36158dda 100644 --- a/drivers/google/compute_util.go +++ b/drivers/google/compute_util.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "regexp" + "strconv" "strings" "time" @@ -41,6 +42,8 @@ type ComputeUtil struct { SwarmHost string openPorts []string minCPUPlatform string + accelerator string + maintenancePolicy string operationBackoffFactory *backoffFactory } @@ -84,9 +87,48 @@ func newComputeUtil(driver *Driver) (*ComputeUtil, error) { openPorts: driver.OpenPorts, operationBackoffFactory: driver.OperationBackoffFactory, minCPUPlatform: driver.MinCPUPlatform, + accelerator: driver.Accelerator, + maintenancePolicy: driver.MaintenancePolicy, }, nil } +func (c *ComputeUtil) acceleratorCountAndType() (int, string) { + if c.accelerator == "" { + return 0, "" + } + + split := strings.Split(strings.TrimSpace(c.accelerator), ",") + count := 1 + acceleratorType := "" + + for _, kvStr := range split { + kv := strings.Split(kvStr, "=") + + if len(kv) != 2 { + log.Infof("Invalid key/value parameter for accelerator: %s, ignoring", kvStr) + continue + } + + key, value := strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1]) + switch key { + case "count": + var err error + count, err = strconv.Atoi(value) + if err != nil { + log.Infof("Failed to parse %q as count, disabling accelerator", value) + return 0, "" + } + case "type": + acceleratorType = strings.TrimSpace(value) + default: + log.Infof("Invalid accelerator defined %q, should be count=N,type=type", c.accelerator) + return 0, "" + } + } + + return count, acceleratorType +} + func (c *ComputeUtil) diskName() string { return c.instanceName + "-disk" } @@ -292,6 +334,21 @@ func (c *ComputeUtil) createInstance(d *Driver) error { Metadata: metadata, } + if c.maintenancePolicy != "" { + instance.Scheduling.OnHostMaintenance = c.maintenancePolicy + } + + acceleratorCount, acceleratorType := c.acceleratorCountAndType() + + if acceleratorCount > 0 && len(acceleratorType) > 0 { + instance.GuestAccelerators = []*raw.AcceleratorConfig{ + { + AcceleratorCount: int64(acceleratorCount), + AcceleratorType: "https://www.googleapis.com/compute/v1/projects/" + c.project + "/zones/" + c.zone + "/acceleratorTypes/" + acceleratorType, + }, + } + } + if strings.Contains(c.subnetwork, "/subnetworks/") { instance.NetworkInterfaces[0].Subnetwork = c.subnetwork } else if c.subnetwork != "" { diff --git a/drivers/google/compute_util_test.go b/drivers/google/compute_util_test.go index 99360df39f9d6f8471e9e1d343cabf26b8b737bf..b396e312bcbb3f88153105196d614dfc7434b21d 100644 --- a/drivers/google/compute_util_test.go +++ b/drivers/google/compute_util_test.go @@ -407,3 +407,62 @@ func prepareMetadataFile(t *testing.T, key string, content string) *os.File { return file } + +func TestAccelerator(t *testing.T) { + tests := map[string]struct { + description string + computeUtil *ComputeUtil + expectedCount int + expectedType string + }{ + "unspecified": { + computeUtil: &ComputeUtil{}, + expectedCount: 0, + expectedType: "", + }, + "GPU type": { + computeUtil: &ComputeUtil{accelerator: "type=nvidia-tesla-p100"}, + expectedCount: 1, + expectedType: "nvidia-tesla-p100", + }, + "count and GPU type": { + computeUtil: &ComputeUtil{accelerator: "count=2,type=nvidia-tesla-p100"}, + expectedCount: 2, + expectedType: "nvidia-tesla-p100", + }, + "count and GPU type with whitespace": { + computeUtil: &ComputeUtil{accelerator: " count=2, type=nvidia-tesla-p100 "}, + expectedCount: 2, + expectedType: "nvidia-tesla-p100", + }, + "unknown key=value pair": { + computeUtil: &ComputeUtil{accelerator: "hello=world"}, + expectedCount: 0, + expectedType: "", + }, + "extraneous key=value pair": { + computeUtil: &ComputeUtil{accelerator: "count=2,type=nvidia-tesla-p100,5"}, + expectedCount: 2, + expectedType: "nvidia-tesla-p100", + }, + "invalid count": { + computeUtil: &ComputeUtil{accelerator: "count=ten,type=nvidia-tesla-p100"}, + expectedCount: 0, + expectedType: "", + }, + "blank GPU type": { + computeUtil: &ComputeUtil{accelerator: "count=10,"}, + expectedCount: 10, + expectedType: "", + }, + } + + for tn, tt := range tests { + t.Run(tn, func(t *testing.T) { + count, acceleratorType := tt.computeUtil.acceleratorCountAndType() + + assert.Equal(t, tt.expectedCount, count) + assert.Equal(t, tt.expectedType, acceleratorType) + }) + } +} diff --git a/drivers/google/google.go b/drivers/google/google.go index 28a628ba65ec4b9f59ef59861c703dd9336e19ac..0cb86e626f4255153cdd71fbc0d53c612fc3cb2a 100644 --- a/drivers/google/google.go +++ b/drivers/google/google.go @@ -61,22 +61,26 @@ type Driver struct { Labels []string Metadata metadataMap MetadataFromFile metadataMap + Accelerator string + MaintenancePolicy string OperationBackoffFactory *backoffFactory } const ( - defaultZone = "us-central1-a" - defaultUser = "docker-user" - defaultMachineType = "n1-standard-1" - defaultImageName = "ubuntu-os-cloud/global/images/ubuntu-1604-xenial-v20170721" - defaultServiceAccount = "default" - defaultScopes = "https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write" - defaultDiskType = "pd-standard" - defaultDiskSize = 10 - defaultNetwork = "default" - defaultSubnetwork = "" - defaultMinCPUPlatform = "" + defaultZone = "us-central1-a" + defaultUser = "docker-user" + defaultMachineType = "n1-standard-1" + defaultImageName = "ubuntu-os-cloud/global/images/ubuntu-1604-xenial-v20170721" + defaultServiceAccount = "default" + defaultScopes = "https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write" + defaultDiskType = "pd-standard" + defaultDiskSize = 10 + defaultNetwork = "default" + defaultSubnetwork = "" + defaultMinCPUPlatform = "" + defaultAccelerator = "" + defaultMaintenancePolicy = "" defaultGoogleOperationBackoffInitialInterval = 1 defaultGoogleOperationBackoffRandomizationFactor = "0.5" @@ -232,6 +236,18 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { Name: "google-metadata-from-file", Usage: "Path to a file containing the metadata value inform of key=path/to/file. Use multiple times for multiple settings", }, + mcnflag.StringFlag{ + Name: "google-accelerator", + Usage: "Count and specific type of GPU accelerators (format: count=N,type=type) to attach to the instance, e.g. count=1,type=nvidia-tesla-p100", + EnvVar: "GOOGLE_ACCELERATOR", + Value: defaultAccelerator, + }, + mcnflag.StringFlag{ + Name: "google-maintenance-policy", + Usage: "Defines the maintenance behavior for this instance, e.g, MIGRATE or TERMINATE", + EnvVar: "GOOGLE_MAINTENANCE_POLICY", + Value: defaultMaintenancePolicy, + }, } } @@ -303,6 +319,8 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error { d.Labels = flags.StringSlice("google-label") d.Metadata = metadataMapFromStringSlice(flags.StringSlice("google-metadata")) d.MetadataFromFile = metadataMapFromStringSlice(flags.StringSlice("google-metadata-from-file")) + d.Accelerator = flags.String("google-accelerator") + d.MaintenancePolicy = flags.String("google-maintenance-policy") } d.SSHUser = flags.String("google-username") d.SSHPort = 22