improved prefix query validation

This commit is contained in:
Annika Hannig 2022-01-13 15:59:57 +01:00
parent c0546be4cc
commit 2cf3ad6b59
7 changed files with 239 additions and 23 deletions

View File

@ -4,7 +4,7 @@
# @author : annika # @author : annika
# @file : init # @file : init
# @created : Tuesday Jan 11, 2022 15:35:20 CET # @created : Tuesday Jan 11, 2022 15:35:20 CET
# @description : Initialize the database # @description : Start `psql` as database shell
###################################################################### ######################################################################
if [ -z $PSQL ]; then if [ -z $PSQL ]; then
@ -55,5 +55,5 @@ if [ $OPT_TESTING -eq 1 ]; then
export PGDATABASE=$NAME export PGDATABASE=$NAME
fi fi
psql exec psql

View File

@ -25,11 +25,6 @@ func (s *Server) apiLookupPrefixGlobal(
return nil, err return nil, err
} }
q, err = validatePrefixQuery(q)
if err != nil {
return nil, err
}
// Check what we want to query // Check what we want to query
// Prefix -> fetch prefix // Prefix -> fetch prefix
// _ -> fetch neighbors and routes // _ -> fetch neighbors and routes
@ -47,6 +42,10 @@ func (s *Server) apiLookupPrefixGlobal(
// Perform query // Perform query
var routes api.LookupRoutes var routes api.LookupRoutes
if lookupPrefix { if lookupPrefix {
q, err = validatePrefixQuery(q)
if err != nil {
return nil, err
}
routes, err = s.routesStore.LookupPrefix(ctx, q) routes, err = s.routesStore.LookupPrefix(ctx, q)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -1,11 +1,24 @@
package http package http
import ( import (
"errors"
"fmt" "fmt"
"strings"
"net/http" "net/http"
) )
var (
// ErrQueryTooShort will be returned when the query
// is less than 2 characters.
ErrQueryTooShort = errors.New("query too short")
// ErrQueryIncomplete will be returned when the
// prefix query lacks a : or .
ErrQueryIncomplete = errors.New(
"prefix query must contain at least on '.' or ':'")
)
// Helper: Validate source Id // Helper: Validate source Id
func validateSourceID(id string) (string, error) { func validateSourceID(id string) (string, error) {
if len(id) > 42 { if len(id) > 42 {
@ -34,11 +47,15 @@ func validateQueryString(req *http.Request, key string) (string, error) {
return value, nil return value, nil
} }
// Helper: Validate prefix query // Helper: Validate prefix query. It should contain
// at least one dot or :
func validatePrefixQuery(value string) (string, error) { func validatePrefixQuery(value string) (string, error) {
// We should at least provide 2 chars // We should at least provide 2 chars
if len(value) < 2 { if len(value) < 2 {
return "", fmt.Errorf("query too short") return "", ErrQueryTooShort
}
if !strings.Contains(value, ":") && !strings.Contains(value, ".") {
return "", ErrQueryIncomplete
} }
return value, nil return value, nil
} }

View File

@ -32,21 +32,27 @@ func (b *NeighborsBackend) SetNeighbors(
sourceID string, sourceID string,
neighbors api.Neighbors, neighbors api.Neighbors,
) error { ) error {
// Clear current neighbors
now := time.Now().UTC()
for _, n := range neighbors { for _, n := range neighbors {
if err := b.persistNeighbor(ctx, sourceID, n); err != nil { if err := b.persist(ctx, sourceID, n, now); err != nil {
return err return err
} }
} }
// Remove old neighbors
if err := b.deleteStale(ctx, sourceID, now); err != nil {
return err
}
return nil return nil
} }
// Private persistNeighbor saves a neighbor to the database // Private persist saves a neighbor to the database
func (b *NeighborsBackend) persistNeighbor( func (b *NeighborsBackend) persist(
ctx context.Context, ctx context.Context,
sourceID string, sourceID string,
neighbor *api.Neighbor, neighbor *api.Neighbor,
now time.Time,
) error { ) error {
now := time.Now().UTC()
qry := ` qry := `
INSERT INTO neighbors ( INSERT INTO neighbors (
id, rs_id, neighbor, updated_at id, rs_id, neighbor, updated_at
@ -59,6 +65,22 @@ func (b *NeighborsBackend) persistNeighbor(
return err return err
} }
// Private deleteStale removes all neighbors not inserted or
// updated at a specific time.
func (b *NeighborsBackend) deleteStale(
ctx context.Context,
sourceID string,
t time.Time,
) error {
qry := `
DELETE FROM neighbors
WHERE rs_id = $1
AND updated_at <> $2
`
_, err := b.pool.Exec(ctx, qry, sourceID, t)
return err
}
// Private queryNeighborsAt selects all neighbors // Private queryNeighborsAt selects all neighbors
// for a given sourceID // for a given sourceID
func (b *NeighborsBackend) queryNeighborsAt( func (b *NeighborsBackend) queryNeighborsAt(

View File

@ -3,6 +3,7 @@ package postgres
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/alice-lg/alice-lg/pkg/api" "github.com/alice-lg/alice-lg/pkg/api"
) )
@ -14,24 +15,25 @@ func TestPersistNeighborLookup(t *testing.T) {
ID: "n2342", ID: "n2342",
Address: "test123", Address: "test123",
} }
if err := b.persistNeighbor(context.Background(), "rs1", n); err != nil { now := time.Now().UTC()
if err := b.persist(context.Background(), "rs1", n, now); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// make an update // make an update
n.Address = "test234" n.Address = "test234"
if err := b.persistNeighbor(context.Background(), "rs1", n); err != nil { if err := b.persist(context.Background(), "rs1", n, now); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Add a second // Add a second
n.ID = "foo23" n.ID = "foo23"
if err := b.persistNeighbor(context.Background(), "rs1", n); err != nil { if err := b.persist(context.Background(), "rs1", n, now); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Add to different rs // Add to different rs
if err := b.persistNeighbor(context.Background(), "rs2", n); err != nil { if err := b.persist(context.Background(), "rs2", n, now); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -0,0 +1,173 @@
package postgres
import (
"context"
"fmt"
"strings"
"time"
"github.com/alice-lg/alice-lg/pkg/api"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
// RoutesBackend implements a postgres store for routes.
type RoutesBackend struct {
pool *pgxpool.Pool
}
// NewRoutesBackend creates a new instance with a postgres
// connection pool.
func NewRoutesBackend(pool *pgxpool.Pool) *RoutesBackend {
return &RoutesBackend{
pool: pool,
}
}
// SetRoutes implements the RoutesStoreBackend interface
// function for setting all routes of a source identified
// by ID.
func (b *RoutesBackend) SetRoutes(
ctx context.Context,
sourceID string,
routes api.LookupRoutes,
) error {
now := time.Now().UTC()
for _, r := range routes {
if err := b.persist(ctx, sourceID, r, now); err != nil {
return err
}
}
if err := b.deleteStale(ctx, sourceID, now); err != nil {
return err
}
return nil
}
// Private persist route in database
func (b *RoutesBackend) persist(
ctx context.Context,
sourceID string,
route *api.LookupRoute,
now time.Time,
) error {
qry := `
INSERT INTO routes (
id,
rs_id,
neighbor_id,
network,
route,
updated_at
) VALUES (
$1, $2, $3, $4, $5, $6
)
ON CONFLICT ON CONSTRAINT routes_pkey DO UPDATE
SET route = EXCLUDED.route,
network = EXCLUDED.network,
neighbor_id = EXCLUDED.neighbor_id,
updated_at = EXCLUDED.updated_at
`
_, err := b.pool.Exec(
ctx,
qry,
route.Route.ID,
sourceID,
route.Neighbor.ID,
route.Route.Network,
route,
now)
return err
}
// Private deleteStale removes all routes not inserted or
// updated at a specific time.
func (b *RoutesBackend) deleteStale(
ctx context.Context,
sourceID string,
t time.Time,
) error {
qry := `
DELETE FROM routes
WHERE rs_id = $1
AND updated_at <> $2
`
_, err := b.pool.Exec(ctx, qry, sourceID, t)
return err
}
// Private queryCountByState will query routes and filter
// by state
func (b *RoutesBackend) queryCountByState(
ctx context.Context,
sourceID string,
state string,
) pgx.Row {
qry := `SELECT COUNT(1) FROM routes
WHERE rs_id = $1 AND route->'state' = $2`
return b.pool.QueryRow(ctx, qry, sourceID, state)
}
// CountRoutesAt returns the number of filtered and imported
// routes and implements the RoutesStoreBackend interface.
func (b *RoutesBackend) CountRoutesAt(
ctx context.Context,
sourceID string,
) (uint, uint, error) {
var (
imported uint
filtered uint
)
err := b.queryCountByState(ctx, sourceID, api.RouteStateFiltered).
Scan(&filtered)
if err != nil {
return 0, 0, err
}
err = b.queryCountByState(ctx, sourceID, api.RouteStateImported).
Scan(&imported)
if err != nil {
return 0, 0, err
}
return imported, filtered, nil
}
// FindByNeighbors will return the prefixes for a
// list of neighbors identified by ID.
func (b *RoutesBackend) FindByNeighbors(
ctx context.Context,
neighborIDs []interface{},
) (api.LookupRoutes, error) {
vars := make([]string, 0, len(neighborIDs))
for n := range neighborIDs {
vars = append(vars, fmt.Sprintf("$%d", n+1))
}
listQry := strings.Join(vars, ",")
qry := `
SELECT route
FROM routes
WHERE neighbor_id IN (` + listQry + `)`
rows, err := b.pool.Query(ctx, qry, neighborIDs...)
if err != nil {
return nil, err
}
cmd := rows.CommandTag()
results := make(api.LookupRoutes, 0, cmd.RowsAffected())
for rows.Next() {
route := &api.LookupRoute{}
if err := rows.Scan(&route); err != nil {
return nil, err
}
results = append(results, route)
}
return results, nil
}
// FindByPrefix will return the prefixes matching a pattern
func (b *RoutesBackend) FindByPrefix(
ctx context.Context,
prefix string,
) (api.LookupRoutes, error) {
// We are searching route.Network
return nil, nil
}

View File

@ -32,6 +32,8 @@ CREATE TABLE neighbors (
CREATE INDEX idx_neighbors_rs_id CREATE INDEX idx_neighbors_rs_id
ON neighbors USING HASH (rs_id); ON neighbors USING HASH (rs_id);
CREATE INDEX idx_neighbors_updated_at
ON neighbors ( updated_at );
-- Routes -- Routes
CREATE TABLE routes ( CREATE TABLE routes (
@ -40,7 +42,7 @@ CREATE TABLE routes (
neighbor_id VARCHAR(255) NOT NULL, neighbor_id VARCHAR(255) NOT NULL,
-- Indexed attributes -- Indexed attributes
network cidr NOT NULL, network VARCHAR(50) NOT NULL,
-- JSON serialized route -- JSON serialized route
route jsonb NOT NULL, route jsonb NOT NULL,
@ -55,7 +57,8 @@ CREATE TABLE routes (
); );
CREATE INDEX idx_routes_network ON routes ( network ); CREATE INDEX idx_routes_network ON routes ( network );
CREATE INDEX idx_neighbor_id ON routes ( neighbor_id );
CREATE INDEX idx_routes_updated_at ON routes ( updated_at );
-- The meta table stores information about the schema -- The meta table stores information about the schema
-- like when it was migrated and the current revision. -- like when it was migrated and the current revision.