Skip to content

Commit

Permalink
Refactor to allow for ipv6 addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
carlpartridge committed Jan 22, 2025
1 parent 6e47c19 commit 3a0d368
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 141 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/waf-sync-lambda-prod-deploy.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: WAFSync Lambda prod deploy
name: WAF Sync Lambda prod deploy

on:
workflow_dispatch:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/waf-sync-lambda-test-deploy.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: WAFSync Lambda test deploy
name: WAF Sync Lambda test deploy

on:
workflow_call:
Expand Down
1 change: 1 addition & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ services:
- DATABASE_URL=postgresql://postgres:toor@db:5432/bcda?sslmode=disable
- BCDA_SSAS_CLIENT_ID=fake-client-id
- BCDA_SSAS_SECRET=fake-secret
- ENV=local
- DEPLOYMENT_TARGET=local
- SSAS_ADMIN_SIGNING_KEY_PATH=../../../shared_files/ssas/admin_test_signing_key.pem
- SSAS_PUBLIC_SIGNING_KEY_PATH=../../../shared_files/ssas/public_test_signing_key.pem
Expand Down
27 changes: 1 addition & 26 deletions lambda/wafsync/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package main

import (
"fmt"
"os"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/wafv2"
"github.com/aws/aws-sdk-go/service/wafv2/wafv2iface"
log "github.com/sirupsen/logrus"
Expand All @@ -19,30 +17,7 @@ type Parameters struct {
Addresses []string
}

var createSession = func() (*session.Session, error) {
sess := session.Must(session.NewSession())

var err error
if isTesting {
sess, err = session.NewSessionWithOptions(session.Options{
Profile: "default",
Config: aws.Config{
Region: aws.String("us-east-1"),
S3ForcePathStyle: aws.Bool(true),
Endpoint: aws.String("http://localhost:4566"),
},
})
}
if err != nil {
return nil, err
}

return sess, nil
}

func fetchAndUpdateIpAddresses(waf wafv2iface.WAFV2API, ipAddresses []string) ([]string, error) {
ipSetName := fmt.Sprintf("bcda-%s-api-customers", os.Getenv("ENV"))

func fetchAndUpdateIpAddresses(waf wafv2iface.WAFV2API, ipSetName string, ipAddresses []string) ([]string, error) {
listParams := &wafv2.ListIPSetsInput{
Scope: aws.String("REGIONAL"),
}
Expand Down
2 changes: 1 addition & 1 deletion lambda/wafsync/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestNewLocalSession(t *testing.T) {
func TestFetchAndUpdateIpAddresses(t *testing.T) {
mock := &mockWAFV2Client{}

addresses, err := fetchAndUpdateIpAddresses(mock, []string{"127.0.0.1/32", "127.0.0.2/32"})
addresses, err := fetchAndUpdateIpAddresses(mock, "test-ip-set", []string{"127.0.0.1/32", "127.0.0.2/32"})

assert.Nil(t, err)
assert.Contains(t, addresses, "127.0.0.1/32")
Expand Down
25 changes: 16 additions & 9 deletions lambda/wafsync/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@ type PgxConnection interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
Ping(context.Context) error
Prepare(context.Context, string, string) (*pgconn.StatementDescription, error)
Close(context.Context) error
}

func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, error) {
defer conn.Close(ctx)

func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, []string, error) {
query := `
SELECT DISTINCT ips.address FROM ips
WHERE deleted_at IS NULL
Expand All @@ -44,12 +40,13 @@ func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, err
rows, err := conn.Query(ctx, query)
if err != nil {
log.Errorf("Error running query: %+v", err)
return nil, err
return nil, nil, err
}

// count seems to only be used to log num of rows for debugging
count := 0
ipAddresses := []string{}
ipv6Addresses := []string{}
defer rows.Close()

for rows.Next() {
Expand All @@ -58,25 +55,35 @@ func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, err
err = rows.Scan(&ip)
if err != nil {
log.Errorf("Scan error: %+v", err)
return nil, err
return nil, nil, err
}

count += 1
if count%10000 == 0 {
log.Infof("Read %d rows", count)
}

ipAddresses = append(ipAddresses, ip.String()+"/32")
// check if ip address is IPv4 or IPv6
if ip.To4() != nil {
ipAddresses = append(ipAddresses, ip.String()+"/32")
} else {
ipv6Addresses = append(ipv6Addresses, ip.String()+"/128")
}

}

log.WithField("num_rows_scanned", count).Info("Successfully retrieved valid IP addresses")

return ipAddresses, nil
return ipAddresses, ipv6Addresses, nil
}

func getDBURL() (string, error) {
env := conf.GetEnv("ENV")

if env == "local" {
return conf.GetEnv("DATABASE_URL"), nil
}

bcdaSession, err := bcdaaws.NewSession("", os.Getenv("LOCAL_STACK_ENDPOINT"))
if err != nil {
return "", err
Expand Down
161 changes: 88 additions & 73 deletions lambda/wafsync/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package main

import (
"context"
"os"
"slices"
"errors"
"testing"

"github.com/google/uuid"
Expand All @@ -20,82 +19,98 @@ func TestGetValidIPAddresses(t *testing.T) {
defer mock.Close(ctx)

// the column ips.address returns type net.IP which needs to be handled here like a byte array
rows := mock.NewRows([]string{"address"}).AddRow([]byte("127.0.0.1"))
// due to pgxmock being unhappy with any other approach :(
rows := mock.NewRows([]string{"address"}).AddRow([]byte("127.0.0.1")).AddRow([]byte("1:2:3:4:5:6:7:8"))
mock.ExpectQuery("^SELECT DISTINCT ips.address FROM ips WHERE (.+)$").WillReturnRows(rows)

addresses, err := getValidIPAddresses(ctx, mock)
_, ipv6Addresses, err := getValidIPAddresses(ctx, mock)
assert.Nil(t, err)
// verifying on length of return as byte array from above gets muddled
assert.Len(t, addresses, 1)
assert.Len(t, ipv6Addresses, 2)
}

func TestGetValidIPAddressesFailure(t *testing.T) {
ctx := context.Background()
mock, err := pgxmock.NewConn()
assert.Nil(t, err)
defer mock.Close(ctx)

mock.ExpectQuery("^SELECT DISTINCT ips.address FROM ips WHERE (.+)$").WillReturnError(errors.New("test error"))

_, _, err = getValidIPAddresses(ctx, mock)
assert.ErrorContains(t, err, "test error")
}

func TestGetValidIPAddresses_Integration(t *testing.T) {
// only run actual DB testing in lower envs
if slices.Contains([]string{"local", "dev", "test"}, os.Getenv("ENV")) {
// insert valid and invalid ip addresses into actual DB
dbURL, err := getDBURL()
assert.Nil(t, err)

ctx := context.Background()

conn, err := pgx.Connect(ctx, dbURL)
assert.Nil(t, err)
defer conn.Close(ctx)

var validGroupID, invalidGroupID, validSystemID, invalidSystemID, validSystemID_invalidGroup, validSystemID_invalidSecret, validSystemID_invalidSecret_PastUpdated, secret1, secret2, secret3, ips1, ips2, ips3, ips4, ips5, ips6 int
err = conn.QueryRow(ctx, `INSERT INTO groups (group_id) VALUES($1) RETURNING id;`, uuid.New().String()).Scan(&validGroupID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO groups (group_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, uuid.New().String()).Scan(&invalidGroupID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validGroupID).Scan(&invalidSystemID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, invalidGroupID).Scan(&validSystemID_invalidGroup)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret_PastUpdated)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW()) RETURNING id;`, validSystemID).Scan(&secret1)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO secrets (system_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&secret2)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW() - INTERVAL '100 DAY') RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&secret3)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.1', $1) RETURNING id;`, validSystemID).Scan(&ips1)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.2', $1) RETURNING id;`, validSystemID).Scan(&ips2)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.3', $1, NOW()) RETURNING id;`, invalidSystemID).Scan(&ips3)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.4', $1, NOW()) RETURNING id;`, validSystemID_invalidGroup).Scan(&ips4)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.5', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&ips5)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.6', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&ips6)
assert.Nil(t, err)

// execute
addresses, err := getValidIPAddresses(ctx, conn)

// verify
assert.Nil(t, err)
assert.Contains(t, addresses, "127.0.0.1/32")
assert.Contains(t, addresses, "127.0.0.2/32")
assert.NotContains(t, addresses, "127.0.0.3/32")
assert.NotContains(t, addresses, "127.0.0.4/32")
assert.NotContains(t, addresses, "127.0.0.5/32")
assert.NotContains(t, addresses, "127.0.0.6/32")

// cleanup
_, err = conn.Exec(ctx, `DELETE FROM ips WHERE id IN($1, $2, $3, $4, $5, $6);`, ips1, ips2, ips3, ips4, ips5, ips6)
assert.Nil(t, err)
_, err = conn.Exec(ctx, `DELETE FROM secrets WHERE id IN($1, $2, $3);`, secret1, secret2, secret3)
assert.Nil(t, err)
_, err = conn.Exec(ctx, `DELETE FROM systems WHERE id IN($1, $2, $3, $4, $5);`, validSystemID, invalidSystemID, validSystemID_invalidGroup, validSystemID_invalidSecret, validSystemID_invalidSecret_PastUpdated)
assert.Nil(t, err)
_, err = conn.Exec(ctx, `DELETE FROM groups WHERE id IN($1, $2);`, validGroupID, invalidGroupID)
assert.Nil(t, err)
}
// insert valid and invalid ip addresses into actual DB
dbURL, err := getDBURL()
assert.Nil(t, err)

ctx := context.Background()

conn, err := pgx.Connect(ctx, dbURL)
assert.Nil(t, err)
defer conn.Close(ctx)

tx, err := conn.Begin(context.Background())
assert.Nil(t, err)

var validGroupID, invalidGroupID, validSystemID, invalidSystemID, validSystemID_invalidGroup, validSystemID_invalidSecret, validSystemID_invalidSecret_PastUpdated, secret1, secret2, secret3, ips1, ips2, ips3, ips4, ips5, ips6, ipv6Valid, ipv6Invalid int
err = tx.QueryRow(ctx, `INSERT INTO groups (group_id) VALUES($1) RETURNING id;`, uuid.New().String()).Scan(&validGroupID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO groups (group_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, uuid.New().String()).Scan(&invalidGroupID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validGroupID).Scan(&invalidSystemID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, invalidGroupID).Scan(&validSystemID_invalidGroup)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret_PastUpdated)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW()) RETURNING id;`, validSystemID).Scan(&secret1)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO secrets (system_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&secret2)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW() - INTERVAL '100 DAY') RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&secret3)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.1', $1) RETURNING id;`, validSystemID).Scan(&ips1)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.2', $1) RETURNING id;`, validSystemID).Scan(&ips2)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.3', $1, NOW()) RETURNING id;`, invalidSystemID).Scan(&ips3)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.4', $1, NOW()) RETURNING id;`, validSystemID_invalidGroup).Scan(&ips4)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.5', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&ips5)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.6', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&ips6)
assert.Nil(t, err)

testipv6Valid := "ecc3:92b5:56a4:84af:d086:4671:b091:1681"
testipv6Invalid := "b0b7:ed8f:348f:13d2:92b0:b018:9c57:4dc9"
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES($1, $2) RETURNING id;`, testipv6Valid, validSystemID).Scan(&ipv6Valid)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES($1, $2, NOW()) RETURNING id;`, testipv6Invalid, invalidSystemID).Scan(&ipv6Invalid)
assert.Nil(t, err)

// execute
addresses, ipv6Addresses, err := getValidIPAddresses(ctx, tx)

// verify
assert.Nil(t, err)
assert.Contains(t, addresses, "127.0.0.1/32")
assert.Contains(t, addresses, "127.0.0.2/32")
assert.NotContains(t, addresses, "127.0.0.3/32")
assert.NotContains(t, addresses, "127.0.0.4/32")
assert.NotContains(t, addresses, "127.0.0.5/32")
assert.NotContains(t, addresses, "127.0.0.6/32")
assert.Contains(t, ipv6Addresses, testipv6Valid+"/128")
assert.NotContains(t, ipv6Addresses, testipv6Invalid+"/128")

// cleanup
err = tx.Rollback(ctx)
assert.Nil(t, err)
}
Loading

0 comments on commit 3a0d368

Please sign in to comment.