Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BCDA-8680: Refactor to allow for ipv6 addresses #199

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/waf-sync-lambda-prod-deploy.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: WAFSync Lambda prod deploy
name: WAF Sync Lambda prod deploy

on:
workflow_dispatch:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/waf-sync-lambda-test-deploy.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: WAFSync Lambda test deploy
name: WAF Sync Lambda test deploy

on:
workflow_call:
Expand Down
1 change: 1 addition & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ services:
- DATABASE_URL=postgresql://postgres:toor@db:5432/bcda?sslmode=disable
- BCDA_SSAS_CLIENT_ID=fake-client-id
- BCDA_SSAS_SECRET=fake-secret
- ENV=local
- DEPLOYMENT_TARGET=local
- SSAS_ADMIN_SIGNING_KEY_PATH=../../../shared_files/ssas/admin_test_signing_key.pem
- SSAS_PUBLIC_SIGNING_KEY_PATH=../../../shared_files/ssas/public_test_signing_key.pem
Expand Down
27 changes: 1 addition & 26 deletions lambda/wafsync/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package main

import (
"fmt"
"os"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/wafv2"
"github.com/aws/aws-sdk-go/service/wafv2/wafv2iface"
log "github.com/sirupsen/logrus"
Expand All @@ -19,30 +17,7 @@ type Parameters struct {
Addresses []string
}

var createSession = func() (*session.Session, error) {
sess := session.Must(session.NewSession())

var err error
if isTesting {
sess, err = session.NewSessionWithOptions(session.Options{
Profile: "default",
Config: aws.Config{
Region: aws.String("us-east-1"),
S3ForcePathStyle: aws.Bool(true),
Endpoint: aws.String("http://localhost:4566"),
},
})
}
if err != nil {
return nil, err
}

return sess, nil
}

func fetchAndUpdateIpAddresses(waf wafv2iface.WAFV2API, ipAddresses []string) ([]string, error) {
ipSetName := fmt.Sprintf("bcda-%s-api-customers", os.Getenv("ENV"))

func fetchAndUpdateIpAddresses(waf wafv2iface.WAFV2API, ipSetName string, ipAddresses []string) ([]string, error) {
listParams := &wafv2.ListIPSetsInput{
Scope: aws.String("REGIONAL"),
}
Expand Down
2 changes: 1 addition & 1 deletion lambda/wafsync/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestNewLocalSession(t *testing.T) {
func TestFetchAndUpdateIpAddresses(t *testing.T) {
mock := &mockWAFV2Client{}

addresses, err := fetchAndUpdateIpAddresses(mock, []string{"127.0.0.1/32", "127.0.0.2/32"})
addresses, err := fetchAndUpdateIpAddresses(mock, "test-ip-set", []string{"127.0.0.1/32", "127.0.0.2/32"})

assert.Nil(t, err)
assert.Contains(t, addresses, "127.0.0.1/32")
Expand Down
25 changes: 16 additions & 9 deletions lambda/wafsync/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@ type PgxConnection interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
Ping(context.Context) error
Prepare(context.Context, string, string) (*pgconn.StatementDescription, error)
Close(context.Context) error
}

func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, error) {
defer conn.Close(ctx)

func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, []string, error) {
query := `
SELECT DISTINCT ips.address FROM ips
WHERE deleted_at IS NULL
Expand All @@ -44,12 +40,13 @@ func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, err
rows, err := conn.Query(ctx, query)
if err != nil {
log.Errorf("Error running query: %+v", err)
return nil, err
return nil, nil, err
}

// count seems to only be used to log num of rows for debugging
count := 0
ipAddresses := []string{}
ipv6Addresses := []string{}
defer rows.Close()

for rows.Next() {
Expand All @@ -58,25 +55,35 @@ func getValidIPAddresses(ctx context.Context, conn PgxConnection) ([]string, err
err = rows.Scan(&ip)
if err != nil {
log.Errorf("Scan error: %+v", err)
return nil, err
return nil, nil, err
}

count += 1
if count%10000 == 0 {
log.Infof("Read %d rows", count)
}

ipAddresses = append(ipAddresses, ip.String()+"/32")
// check if ip address is IPv4 or IPv6
if ip.To4() != nil {
ipAddresses = append(ipAddresses, ip.String()+"/32")
} else {
ipv6Addresses = append(ipv6Addresses, ip.String()+"/128")
}

}

log.WithField("num_rows_scanned", count).Info("Successfully retrieved valid IP addresses")

return ipAddresses, nil
return ipAddresses, ipv6Addresses, nil
}

func getDBURL() (string, error) {
env := conf.GetEnv("ENV")

if env == "local" {
return conf.GetEnv("DATABASE_URL"), nil
}

bcdaSession, err := bcdaaws.NewSession("", os.Getenv("LOCAL_STACK_ENDPOINT"))
if err != nil {
return "", err
Expand Down
161 changes: 88 additions & 73 deletions lambda/wafsync/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package main

import (
"context"
"os"
"slices"
"errors"
"testing"

"github.com/google/uuid"
Expand All @@ -20,82 +19,98 @@ func TestGetValidIPAddresses(t *testing.T) {
defer mock.Close(ctx)

// the column ips.address returns type net.IP which needs to be handled here like a byte array
rows := mock.NewRows([]string{"address"}).AddRow([]byte("127.0.0.1"))
// due to pgxmock being unhappy with any other approach :(
rows := mock.NewRows([]string{"address"}).AddRow([]byte("127.0.0.1")).AddRow([]byte("1:2:3:4:5:6:7:8"))
mock.ExpectQuery("^SELECT DISTINCT ips.address FROM ips WHERE (.+)$").WillReturnRows(rows)

addresses, err := getValidIPAddresses(ctx, mock)
_, ipv6Addresses, err := getValidIPAddresses(ctx, mock)
assert.Nil(t, err)
// verifying on length of return as byte array from above gets muddled
assert.Len(t, addresses, 1)
assert.Len(t, ipv6Addresses, 2)
}

func TestGetValidIPAddressesFailure(t *testing.T) {
ctx := context.Background()
mock, err := pgxmock.NewConn()
assert.Nil(t, err)
defer mock.Close(ctx)

mock.ExpectQuery("^SELECT DISTINCT ips.address FROM ips WHERE (.+)$").WillReturnError(errors.New("test error"))

_, _, err = getValidIPAddresses(ctx, mock)
assert.ErrorContains(t, err, "test error")
}

func TestGetValidIPAddresses_Integration(t *testing.T) {
// only run actual DB testing in lower envs
if slices.Contains([]string{"local", "dev", "test"}, os.Getenv("ENV")) {
// insert valid and invalid ip addresses into actual DB
dbURL, err := getDBURL()
assert.Nil(t, err)

ctx := context.Background()

conn, err := pgx.Connect(ctx, dbURL)
assert.Nil(t, err)
defer conn.Close(ctx)

var validGroupID, invalidGroupID, validSystemID, invalidSystemID, validSystemID_invalidGroup, validSystemID_invalidSecret, validSystemID_invalidSecret_PastUpdated, secret1, secret2, secret3, ips1, ips2, ips3, ips4, ips5, ips6 int
err = conn.QueryRow(ctx, `INSERT INTO groups (group_id) VALUES($1) RETURNING id;`, uuid.New().String()).Scan(&validGroupID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO groups (group_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, uuid.New().String()).Scan(&invalidGroupID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validGroupID).Scan(&invalidSystemID)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, invalidGroupID).Scan(&validSystemID_invalidGroup)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret_PastUpdated)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW()) RETURNING id;`, validSystemID).Scan(&secret1)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO secrets (system_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&secret2)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW() - INTERVAL '100 DAY') RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&secret3)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.1', $1) RETURNING id;`, validSystemID).Scan(&ips1)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.2', $1) RETURNING id;`, validSystemID).Scan(&ips2)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.3', $1, NOW()) RETURNING id;`, invalidSystemID).Scan(&ips3)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.4', $1, NOW()) RETURNING id;`, validSystemID_invalidGroup).Scan(&ips4)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.5', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&ips5)
assert.Nil(t, err)
err = conn.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.6', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&ips6)
assert.Nil(t, err)

// execute
addresses, err := getValidIPAddresses(ctx, conn)

// verify
assert.Nil(t, err)
assert.Contains(t, addresses, "127.0.0.1/32")
assert.Contains(t, addresses, "127.0.0.2/32")
assert.NotContains(t, addresses, "127.0.0.3/32")
assert.NotContains(t, addresses, "127.0.0.4/32")
assert.NotContains(t, addresses, "127.0.0.5/32")
assert.NotContains(t, addresses, "127.0.0.6/32")

// cleanup
_, err = conn.Exec(ctx, `DELETE FROM ips WHERE id IN($1, $2, $3, $4, $5, $6);`, ips1, ips2, ips3, ips4, ips5, ips6)
assert.Nil(t, err)
_, err = conn.Exec(ctx, `DELETE FROM secrets WHERE id IN($1, $2, $3);`, secret1, secret2, secret3)
assert.Nil(t, err)
_, err = conn.Exec(ctx, `DELETE FROM systems WHERE id IN($1, $2, $3, $4, $5);`, validSystemID, invalidSystemID, validSystemID_invalidGroup, validSystemID_invalidSecret, validSystemID_invalidSecret_PastUpdated)
assert.Nil(t, err)
_, err = conn.Exec(ctx, `DELETE FROM groups WHERE id IN($1, $2);`, validGroupID, invalidGroupID)
assert.Nil(t, err)
}
// insert valid and invalid ip addresses into actual DB
dbURL, err := getDBURL()
assert.Nil(t, err)

ctx := context.Background()

conn, err := pgx.Connect(ctx, dbURL)
assert.Nil(t, err)
defer conn.Close(ctx)

tx, err := conn.Begin(context.Background())
Copy link
Collaborator Author

@carlpartridge carlpartridge Jan 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slight refactoring of this test to use transaction to avoid data issues.

assert.Nil(t, err)

var validGroupID, invalidGroupID, validSystemID, invalidSystemID, validSystemID_invalidGroup, validSystemID_invalidSecret, validSystemID_invalidSecret_PastUpdated, secret1, secret2, secret3, ips1, ips2, ips3, ips4, ips5, ips6, ipv6Valid, ipv6Invalid int
err = tx.QueryRow(ctx, `INSERT INTO groups (group_id) VALUES($1) RETURNING id;`, uuid.New().String()).Scan(&validGroupID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO groups (group_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, uuid.New().String()).Scan(&invalidGroupID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validGroupID).Scan(&invalidSystemID)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, invalidGroupID).Scan(&validSystemID_invalidGroup)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO systems (g_id) VALUES($1) RETURNING id;`, validGroupID).Scan(&validSystemID_invalidSecret_PastUpdated)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW()) RETURNING id;`, validSystemID).Scan(&secret1)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO secrets (system_id, deleted_at) VALUES($1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&secret2)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO secrets (system_id, updated_at) VALUES($1, NOW() - INTERVAL '100 DAY') RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&secret3)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.1', $1) RETURNING id;`, validSystemID).Scan(&ips1)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES('127.0.0.2', $1) RETURNING id;`, validSystemID).Scan(&ips2)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.3', $1, NOW()) RETURNING id;`, invalidSystemID).Scan(&ips3)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.4', $1, NOW()) RETURNING id;`, validSystemID_invalidGroup).Scan(&ips4)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.5', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret).Scan(&ips5)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES('127.0.0.6', $1, NOW()) RETURNING id;`, validSystemID_invalidSecret_PastUpdated).Scan(&ips6)
assert.Nil(t, err)

testipv6Valid := "ecc3:92b5:56a4:84af:d086:4671:b091:1681"
testipv6Invalid := "b0b7:ed8f:348f:13d2:92b0:b018:9c57:4dc9"
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id) VALUES($1, $2) RETURNING id;`, testipv6Valid, validSystemID).Scan(&ipv6Valid)
assert.Nil(t, err)
err = tx.QueryRow(ctx, `INSERT INTO ips (address, system_id, deleted_at) VALUES($1, $2, NOW()) RETURNING id;`, testipv6Invalid, invalidSystemID).Scan(&ipv6Invalid)
assert.Nil(t, err)

// execute
addresses, ipv6Addresses, err := getValidIPAddresses(ctx, tx)

// verify
assert.Nil(t, err)
assert.Contains(t, addresses, "127.0.0.1/32")
assert.Contains(t, addresses, "127.0.0.2/32")
assert.NotContains(t, addresses, "127.0.0.3/32")
assert.NotContains(t, addresses, "127.0.0.4/32")
assert.NotContains(t, addresses, "127.0.0.5/32")
assert.NotContains(t, addresses, "127.0.0.6/32")
assert.Contains(t, ipv6Addresses, testipv6Valid+"/128")
assert.NotContains(t, ipv6Addresses, testipv6Invalid+"/128")

// cleanup
err = tx.Rollback(ctx)
assert.Nil(t, err)
}
Loading
Loading