From 630eb0abce7fd0f2353dbaa1dbd240b11fd0a96c Mon Sep 17 00:00:00 2001 From: Sailaxman Kumar <sailaxman.kumar59@gcpsandpit.auspost> Date: Tue, 13 Jul 2021 10:09:05 +0000 Subject: [PATCH] add --google-skip-firewall-create --- drivers/google/compute_util.go | 7 ++++ drivers/google/compute_util_test.go | 62 ++++++++++++++++++++++++++++- drivers/google/google.go | 7 ++++ 3 files changed, 75 insertions(+), 1 deletion(-) diff --git a/drivers/google/compute_util.go b/drivers/google/compute_util.go index be3b4da3..56c860e6 100644 --- a/drivers/google/compute_util.go +++ b/drivers/google/compute_util.go @@ -44,6 +44,7 @@ type ComputeUtil struct { minCPUPlatform string accelerator string maintenancePolicy string + skipFirewall bool operationBackoffFactory *backoffFactory } @@ -89,6 +90,7 @@ func newComputeUtil(driver *Driver) (*ComputeUtil, error) { minCPUPlatform: driver.MinCPUPlatform, accelerator: driver.Accelerator, maintenancePolicy: driver.MaintenancePolicy, + skipFirewall: driver.SkipFirewall, }, nil } @@ -230,6 +232,11 @@ func (c *ComputeUtil) portsUsed() ([]string, error) { // openFirewallPorts configures the firewall to open docker and swarm ports. func (c *ComputeUtil) openFirewallPorts(d *Driver) error { + if c.skipFirewall { + log.Infof("Skipping opening firewall ports") + return nil + } + log.Infof("Opening firewall ports") create := false diff --git a/drivers/google/compute_util_test.go b/drivers/google/compute_util_test.go index b396e312..91590172 100644 --- a/drivers/google/compute_util_test.go +++ b/drivers/google/compute_util_test.go @@ -1,15 +1,21 @@ package google import ( + "context" "errors" "fmt" + "io" "io/ioutil" + "net/http" + "net/http/httptest" "os" "testing" "time" "github.com/stretchr/testify/assert" raw "google.golang.org/api/compute/v1" + "google.golang.org/api/googleapi" + "google.golang.org/api/option" ) func TestDefaultTag(t *testing.T) { @@ -75,6 +81,61 @@ func TestLabels(t *testing.T) { } } +func TestOpenFirewallPorts(t *testing.T) { + tests := map[string]struct { + skipFirewall bool + mockResponse http.HandlerFunc + }{ + "skip firewall": { + skipFirewall: true, + mockResponse: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + }), + }, + "firewall rules exists": { + skipFirewall: false, + mockResponse: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + firewall := raw.Firewall{ + + Allowed: []*raw.FirewallAllowed{ + { + IPProtocol: "tcp", + Ports: []string{"22", "2376"}, + }, + }, + } + var body io.Reader = nil + body, err := googleapi.WithoutDataWrapper.JSONReader(firewall) + if err != nil { + t.Fatal(err) + } + fmt.Fprint(w, body) + }), + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + srv := httptest.NewServer(tt.mockResponse) + defer srv.Close() + + svc, err := raw.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(srv.URL)) + if err != nil { + t.Fatal(err) + } + + computeUtil := ComputeUtil{ + skipFirewall: tt.skipFirewall, + service: svc, + } + + driver := &Driver{} + + err = computeUtil.openFirewallPorts(driver) + assert.NoError(t, err) + }) + } +} + func TestPortsUsed(t *testing.T) { var tests = []struct { description string @@ -90,7 +151,6 @@ func TestPortsUsed(t *testing.T) { for _, test := range tests { ports, err := test.computeUtil.portsUsed() - assert.Equal(t, test.expectedPorts, ports) assert.Equal(t, test.expectedError, err) } diff --git a/drivers/google/google.go b/drivers/google/google.go index 0cb86e62..4e18c8b4 100644 --- a/drivers/google/google.go +++ b/drivers/google/google.go @@ -63,6 +63,7 @@ type Driver struct { MetadataFromFile metadataMap Accelerator string MaintenancePolicy string + SkipFirewall bool OperationBackoffFactory *backoffFactory } @@ -248,6 +249,11 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { EnvVar: "GOOGLE_MAINTENANCE_POLICY", Value: defaultMaintenancePolicy, }, + mcnflag.BoolFlag{ + Name: "google-skip-firewall-create", + Usage: "Skip firewall setup", + EnvVar: "GOOGLE_SKIP_FIREWALL_CREATE", + }, } } @@ -321,6 +327,7 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error { d.MetadataFromFile = metadataMapFromStringSlice(flags.StringSlice("google-metadata-from-file")) d.Accelerator = flags.String("google-accelerator") d.MaintenancePolicy = flags.String("google-maintenance-policy") + d.SkipFirewall = flags.Bool("google-skip-firewall-create") } d.SSHUser = flags.String("google-username") d.SSHPort = 22 -- GitLab