diff --git a/drivers/google/compute_util.go b/drivers/google/compute_util.go index be3b4da3116d6dd7d3578c171268872bfa226bd8..56c860e6494d8de9af4137b308a09fbeff43a490 100644 --- a/drivers/google/compute_util.go +++ b/drivers/google/compute_util.go @@ -44,6 +44,7 @@ type ComputeUtil struct { minCPUPlatform string accelerator string maintenancePolicy string + skipFirewall bool operationBackoffFactory *backoffFactory } @@ -89,6 +90,7 @@ func newComputeUtil(driver *Driver) (*ComputeUtil, error) { minCPUPlatform: driver.MinCPUPlatform, accelerator: driver.Accelerator, maintenancePolicy: driver.MaintenancePolicy, + skipFirewall: driver.SkipFirewall, }, nil } @@ -230,6 +232,11 @@ func (c *ComputeUtil) portsUsed() ([]string, error) { // openFirewallPorts configures the firewall to open docker and swarm ports. func (c *ComputeUtil) openFirewallPorts(d *Driver) error { + if c.skipFirewall { + log.Infof("Skipping opening firewall ports") + return nil + } + log.Infof("Opening firewall ports") create := false diff --git a/drivers/google/compute_util_test.go b/drivers/google/compute_util_test.go index b396e312bcbb3f88153105196d614dfc7434b21d..915901723dbe4c0fe9a56ef270c3e004fd19f844 100644 --- a/drivers/google/compute_util_test.go +++ b/drivers/google/compute_util_test.go @@ -1,15 +1,21 @@ package google import ( + "context" "errors" "fmt" + "io" "io/ioutil" + "net/http" + "net/http/httptest" "os" "testing" "time" "github.com/stretchr/testify/assert" raw "google.golang.org/api/compute/v1" + "google.golang.org/api/googleapi" + "google.golang.org/api/option" ) func TestDefaultTag(t *testing.T) { @@ -75,6 +81,61 @@ func TestLabels(t *testing.T) { } } +func TestOpenFirewallPorts(t *testing.T) { + tests := map[string]struct { + skipFirewall bool + mockResponse http.HandlerFunc + }{ + "skip firewall": { + skipFirewall: true, + mockResponse: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + }), + }, + "firewall rules exists": { + skipFirewall: false, + mockResponse: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + firewall := raw.Firewall{ + + Allowed: []*raw.FirewallAllowed{ + { + IPProtocol: "tcp", + Ports: []string{"22", "2376"}, + }, + }, + } + var body io.Reader = nil + body, err := googleapi.WithoutDataWrapper.JSONReader(firewall) + if err != nil { + t.Fatal(err) + } + fmt.Fprint(w, body) + }), + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + srv := httptest.NewServer(tt.mockResponse) + defer srv.Close() + + svc, err := raw.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(srv.URL)) + if err != nil { + t.Fatal(err) + } + + computeUtil := ComputeUtil{ + skipFirewall: tt.skipFirewall, + service: svc, + } + + driver := &Driver{} + + err = computeUtil.openFirewallPorts(driver) + assert.NoError(t, err) + }) + } +} + func TestPortsUsed(t *testing.T) { var tests = []struct { description string @@ -90,7 +151,6 @@ func TestPortsUsed(t *testing.T) { for _, test := range tests { ports, err := test.computeUtil.portsUsed() - assert.Equal(t, test.expectedPorts, ports) assert.Equal(t, test.expectedError, err) } diff --git a/drivers/google/google.go b/drivers/google/google.go index 0cb86e626f4255153cdd71fbc0d53c612fc3cb2a..4e18c8b4d4cd93ce15afa386aa2b50ef34af78ba 100644 --- a/drivers/google/google.go +++ b/drivers/google/google.go @@ -63,6 +63,7 @@ type Driver struct { MetadataFromFile metadataMap Accelerator string MaintenancePolicy string + SkipFirewall bool OperationBackoffFactory *backoffFactory } @@ -248,6 +249,11 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { EnvVar: "GOOGLE_MAINTENANCE_POLICY", Value: defaultMaintenancePolicy, }, + mcnflag.BoolFlag{ + Name: "google-skip-firewall-create", + Usage: "Skip firewall setup", + EnvVar: "GOOGLE_SKIP_FIREWALL_CREATE", + }, } } @@ -321,6 +327,7 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error { d.MetadataFromFile = metadataMapFromStringSlice(flags.StringSlice("google-metadata-from-file")) d.Accelerator = flags.String("google-accelerator") d.MaintenancePolicy = flags.String("google-maintenance-policy") + d.SkipFirewall = flags.Bool("google-skip-firewall-create") } d.SSHUser = flags.String("google-username") d.SSHPort = 22