keep replacement within a tx

This commit is contained in:
Annika Hannig 2022-03-14 14:03:43 +01:00
parent bd79b4dafc
commit 3aa76fc45f

View File

@ -32,15 +32,26 @@ func (b *RoutesBackend) SetRoutes(
sourceID string,
routes api.LookupRoutes,
) error {
// Acquire connection
now := time.Now().UTC()
tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.ReadCommitted,
})
if err != nil {
return err
}
defer tx.Rollback(ctx)
if err := b.clear(ctx, tx, sourceID); err != nil {
return err
}
// persist all routes
for _, r := range routes {
if err := b.persist(ctx, sourceID, r, now); err != nil {
if err := b.persist(ctx, tx, sourceID, r, now); err != nil {
return err
}
}
if err := b.deleteStale(ctx, sourceID, now); err != nil {
if err := tx.Commit(ctx); err != nil {
return err
}
return nil
@ -49,6 +60,7 @@ func (b *RoutesBackend) SetRoutes(
// Private persist route in database
func (b *RoutesBackend) persist(
ctx context.Context,
tx pgx.Tx,
sourceID string,
route *api.LookupRoute,
now time.Time,
@ -64,13 +76,8 @@ func (b *RoutesBackend) persist(
) VALUES (
$1, $2, $3, $4, $5, $6
)
ON CONFLICT ON CONSTRAINT routes_pkey DO UPDATE
SET route = EXCLUDED.route,
network = EXCLUDED.network,
neighbor_id = EXCLUDED.neighbor_id,
updated_at = EXCLUDED.updated_at
`
_, err := b.pool.Exec(
_, err := tx.Exec(
ctx,
qry,
route.Route.ID,
@ -82,19 +89,16 @@ func (b *RoutesBackend) persist(
return err
}
// Private deleteStale removes all routes not inserted or
// updated at a specific time.
func (b *RoutesBackend) deleteStale(
// Private clear removes all routes.
func (b *RoutesBackend) clear(
ctx context.Context,
tx pgx.Tx,
sourceID string,
t time.Time,
) error {
qry := `
DELETE FROM routes
WHERE rs_id = $1
AND updated_at <> $2
DELETE FROM routes WHERE rs_id = $1
`
_, err := b.pool.Exec(ctx, qry, sourceID, t)
_, err := tx.Exec(ctx, qry, sourceID)
return err
}
@ -102,12 +106,14 @@ func (b *RoutesBackend) deleteStale(
// by state
func (b *RoutesBackend) queryCountByState(
ctx context.Context,
tx pgx.Tx,
sourceID string,
state string,
) pgx.Row {
qry := `SELECT COUNT(1) FROM routes
WHERE rs_id = $1 AND route -> 'state' = $2`
return b.pool.QueryRow(ctx, qry, sourceID, "\""+state+"\"")
return tx.QueryRow(ctx, qry, sourceID, "\""+state+"\"")
}
// CountRoutesAt returns the number of filtered and imported
@ -116,16 +122,24 @@ func (b *RoutesBackend) CountRoutesAt(
ctx context.Context,
sourceID string,
) (uint, uint, error) {
tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.ReadCommitted,
})
if err != nil {
return 0, 0, err
}
defer tx.Rollback(ctx)
var (
imported uint
filtered uint
)
err := b.queryCountByState(ctx, sourceID, api.RouteStateFiltered).
err = b.queryCountByState(ctx, tx, sourceID, api.RouteStateFiltered).
Scan(&filtered)
if err != nil {
return 0, 0, err
}
err = b.queryCountByState(ctx, sourceID, api.RouteStateImported).
err = b.queryCountByState(ctx, tx, sourceID, api.RouteStateImported).
Scan(&imported)
if err != nil {
return 0, 0, err
@ -139,6 +153,14 @@ func (b *RoutesBackend) FindByNeighbors(
ctx context.Context,
neighborIDs []string,
) (api.LookupRoutes, error) {
tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.ReadCommitted,
})
if err != nil {
return nil, err
}
defer tx.Rollback(ctx)
vals := make([]interface{}, len(neighborIDs))
for i := range neighborIDs {
vals[i] = neighborIDs[i]
@ -152,10 +174,11 @@ func (b *RoutesBackend) FindByNeighbors(
SELECT route FROM routes
WHERE neighbor_id IN (` + listQry + `)`
rows, err := b.pool.Query(ctx, qry, vals...)
rows, err := tx.Query(ctx, qry, vals...)
if err != nil {
return nil, err
}
return fetchRoutes(rows)
}
@ -164,12 +187,19 @@ func (b *RoutesBackend) FindByPrefix(
ctx context.Context,
prefix string,
) (api.LookupRoutes, error) {
tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.ReadCommitted,
})
if err != nil {
return nil, err
}
defer tx.Rollback(ctx)
// We are searching route.Network
qry := `
SELECT route FROM routes
WHERE network ILIKE $1
`
rows, err := b.pool.Query(ctx, qry, prefix+"%")
rows, err := tx.Query(ctx, qry, prefix+"%")
if err != nil {
return nil, err
}