use pointers as identifier.

This commit is contained in:
Annika Hannig 2024-01-16 17:07:09 +01:00
parent 852d2d7a6d
commit 98847ba0cb
5 changed files with 75 additions and 165 deletions

View File

@ -2,7 +2,9 @@ package pools
import ( import (
"math" "math"
"reflect"
"sync" "sync"
"unsafe"
"github.com/alice-lg/alice-lg/pkg/api" "github.com/alice-lg/alice-lg/pkg/api"
) )
@ -12,7 +14,6 @@ import (
// communities, use the ExtCommunityPool. // communities, use the ExtCommunityPool.
type CommunitiesPool struct { type CommunitiesPool struct {
root *Node[int, api.Community] root *Node[int, api.Community]
counter uint64
sync.RWMutex sync.RWMutex
} }
@ -23,32 +24,22 @@ func NewCommunitiesPool() *CommunitiesPool {
} }
} }
// AcquireGid acquires a single bgp community with gid // Acquire a single bgp community
func (p *CommunitiesPool) AcquireGid(c api.Community) (api.Community, uint64) { func (p *CommunitiesPool) Acquire(c api.Community) api.Community {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if len(c) == 0 { if len(c) == 0 {
return p.root.value, p.root.gid return p.root.value
} }
v, gid := p.root.traverse(p.counter+1, c, c) return p.root.traverse(c, c)
if gid > p.counter {
p.counter = gid
}
return v, gid
}
// Acquire a single bgp community without gid
func (p *CommunitiesPool) Acquire(c api.Community) api.Community {
v, _ := p.AcquireGid(c)
return v
} }
// Read a single bgp community // Read a single bgp community
func (p *CommunitiesPool) Read(c api.Community) (api.Community, uint64) { func (p *CommunitiesPool) Read(c api.Community) api.Community {
p.RLock() p.RLock()
defer p.RUnlock() defer p.RUnlock()
if len(c) == 0 { if len(c) == 0 {
return p.root.value, p.root.gid return p.root.value
} }
return p.root.read(c) return p.root.read(c)
} }
@ -57,8 +48,7 @@ func (p *CommunitiesPool) Read(c api.Community) (api.Community, uint64) {
// (Large and default. The ext communities representation right now // (Large and default. The ext communities representation right now
// makes problems and need to be fixed. TODO.) // makes problems and need to be fixed. TODO.)
type CommunitiesSetPool struct { type CommunitiesSetPool struct {
root *Node[uint64, []api.Community] root *Node[unsafe.Pointer, []api.Community]
counter uint64
sync.Mutex sync.Mutex
} }
@ -66,47 +56,34 @@ type CommunitiesSetPool struct {
// of BGP communities. // of BGP communities.
func NewCommunitiesSetPool() *CommunitiesSetPool { func NewCommunitiesSetPool() *CommunitiesSetPool {
return &CommunitiesSetPool{ return &CommunitiesSetPool{
root: NewNode[uint64, []api.Community]([]api.Community{}), root: NewNode[unsafe.Pointer, []api.Community]([]api.Community{}),
} }
} }
// AcquireGid acquires a list of bgp communities and returns a gid
func (p *CommunitiesSetPool) AcquireGid(
communities []api.Community,
) ([]api.Community, uint64) {
p.Lock()
defer p.Unlock()
// Make identification list by using the pointer address
// of the deduplicated community as ID
ids := make([]uint64, len(communities))
set := make([]api.Community, len(communities))
for i, comm := range communities {
ptr, gid := Communities.AcquireGid(comm)
ids[i] = gid
set[i] = ptr
}
if len(ids) == 0 {
return p.root.value, p.root.gid
}
v, id := p.root.traverse(p.counter+1, set, ids)
if id > p.counter {
p.counter = id
}
return v, id
}
// Acquire a list of bgp communities // Acquire a list of bgp communities
func (p *CommunitiesSetPool) Acquire( func (p *CommunitiesSetPool) Acquire(
communities []api.Community, communities []api.Community,
) []api.Community { ) []api.Community {
v, _ := p.AcquireGid(communities) p.Lock()
return v defer p.Unlock()
// Make identification list by using the pointer address
// of the deduplicated community as ID
ids := make([]unsafe.Pointer, len(communities))
set := make([]api.Community, len(communities))
for i, comm := range communities {
ptr := Communities.Acquire(comm)
ids[i] = reflect.ValueOf(ptr).UnsafePointer()
set[i] = ptr
}
if len(ids) == 0 {
return p.root.value
}
return p.root.traverse(set, ids)
} }
// ExtCommunitiesSetPool is for deduplicating a list of ext. BGP communities // ExtCommunitiesSetPool is for deduplicating a list of ext. BGP communities
type ExtCommunitiesSetPool struct { type ExtCommunitiesSetPool struct {
root *Node[uint64, []api.ExtCommunity] root *Node[unsafe.Pointer, []api.ExtCommunity]
counter uint64
sync.Mutex sync.Mutex
} }
@ -114,7 +91,7 @@ type ExtCommunitiesSetPool struct {
// of BGP communities. // of BGP communities.
func NewExtCommunitiesSetPool() *ExtCommunitiesSetPool { func NewExtCommunitiesSetPool() *ExtCommunitiesSetPool {
return &ExtCommunitiesSetPool{ return &ExtCommunitiesSetPool{
root: NewNode[uint64, []api.ExtCommunity]([]api.ExtCommunity{}), root: NewNode[unsafe.Pointer, []api.ExtCommunity]([]api.ExtCommunity{}),
} }
} }
@ -126,37 +103,25 @@ func extPrefixToInt(s string) int {
return v return v
} }
// AcquireGid acquires a list of ext bgp communities // Acquire a list of ext bgp communities
func (p *ExtCommunitiesSetPool) AcquireGid( func (p *ExtCommunitiesSetPool) Acquire(
communities []api.ExtCommunity, communities []api.ExtCommunity,
) ([]api.ExtCommunity, uint64) { ) []api.ExtCommunity {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
// Make identification list // Make identification list
ids := make([]uint64, len(communities)) ids := make([]unsafe.Pointer, len(communities))
for i, comm := range communities { for i, comm := range communities {
r := extPrefixToInt(comm[0].(string)) r := extPrefixToInt(comm[0].(string))
icomm := []int{r, comm[1].(int), comm[2].(int)} icomm := []int{r, comm[1].(int), comm[2].(int)}
// get community identifier // get community identifier
_, gid := ExtCommunities.AcquireGid(icomm) ptr := ExtCommunities.Acquire(icomm)
ids[i] = gid ids[i] = reflect.ValueOf(ptr).UnsafePointer()
} }
if len(ids) == 0 { if len(ids) == 0 {
return p.root.value, p.root.gid return p.root.value
} }
v, id := p.root.traverse(p.counter+1, communities, ids) return p.root.traverse(communities, ids)
if id > p.counter {
p.counter = id
}
return v, id
}
// Acquire a list of ext bgp communities
func (p *ExtCommunitiesSetPool) Acquire(
communities []api.ExtCommunity,
) []api.ExtCommunity {
v, _ := p.AcquireGid(communities)
return v
} }

View File

@ -9,16 +9,16 @@ import (
"github.com/alice-lg/alice-lg/pkg/api" "github.com/alice-lg/alice-lg/pkg/api"
) )
func TestAcquireGidCommunity(t *testing.T) { func TestAcquireCommunity(t *testing.T) {
c1 := api.Community{2342, 5, 1} c1 := api.Community{2342, 5, 1}
c2 := api.Community{2342, 5, 1} c2 := api.Community{2342, 5, 1}
c3 := api.Community{2342, 5} c3 := api.Community{2342, 5}
p := NewCommunitiesPool() p := NewCommunitiesPool()
pc1, gid1 := p.AcquireGid(c1) pc1 := p.Acquire(c1)
pc2, gid2 := p.AcquireGid(c2) pc2 := p.Acquire(c2)
pc3, gid3 := p.AcquireGid(c3) pc3 := p.Acquire(c3)
if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) { if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) {
t.Error("expected c1 !== c2") t.Error("expected c1 !== c2")
@ -28,19 +28,10 @@ func TestAcquireGidCommunity(t *testing.T) {
t.Error("expected pc1 == pc2") t.Error("expected pc1 == pc2")
} }
if gid1 != gid2 {
t.Error("expected gid1 == gid2")
}
if gid1 == gid3 {
t.Error("expected gid1 != gid3")
}
fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3) fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3)
fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3)
log.Println(c3, pc3) log.Println(c3, pc3)
log.Println(gid1, gid2, gid3)
} }
func TestCommunityRead(t *testing.T) { func TestCommunityRead(t *testing.T) {
@ -50,29 +41,22 @@ func TestCommunityRead(t *testing.T) {
p := NewCommunitiesPool() p := NewCommunitiesPool()
pc1, gid1 := p.AcquireGid(c1) pc1 := p.Acquire(c1)
pc2, gid2 := p.Read(c2) pc2 := p.Read(c2)
pc3, gid3 := p.Read(c3) pc3 := p.Read(c3)
fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3)
fmt.Printf("gid1: %d, gid2: %d, gid3: %d\n", gid1, gid2, gid3)
if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) { if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) {
t.Error("expected pc1 == pc2") t.Error("expected pc1 == pc2")
} }
if gid1 != gid2 {
t.Error("expected gid1 == gid2")
}
if pc3 != nil { if pc3 != nil {
t.Error("expected pc3 == nil, got", pc3) t.Error("expected pc3 == nil, got", pc3)
} }
if gid3 != 0 {
t.Error("expected gid3 == 0, got", gid3)
}
} }
func TestAcquireGidCommunitiesSets(t *testing.T) { func TestAcquireCommunitiesSets(t *testing.T) {
c1 := []api.Community{ c1 := []api.Community{
{2342, 5, 1}, {2342, 5, 1},
{2342, 5, 2}, {2342, 5, 2},
@ -91,9 +75,9 @@ func TestAcquireGidCommunitiesSets(t *testing.T) {
p := NewCommunitiesSetPool() p := NewCommunitiesSetPool()
pc1, gid1 := p.AcquireGid(c1) pc1 := p.Acquire(c1)
pc2, gid2 := p.AcquireGid(c2) pc2 := p.Acquire(c2)
pc3, gid3 := p.AcquireGid(c3) pc3 := p.Acquire(c3)
if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) { if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) {
t.Error("expected c1 !== c2") t.Error("expected c1 !== c2")
@ -102,14 +86,9 @@ func TestAcquireGidCommunitiesSets(t *testing.T) {
if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) { if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) {
t.Error("expected pc1 == pc2") t.Error("expected pc1 == pc2")
} }
if gid1 != gid2 {
t.Error("expected gid1 == gid2")
}
fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3) fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3)
fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3)
t.Logf("gid1: %d, gid2: %d, gid3: %d\n", gid1, gid2, gid3)
} }
func TestSetCommunityIdentity(t *testing.T) { func TestSetCommunityIdentity(t *testing.T) {
@ -119,11 +98,10 @@ func TestSetCommunityIdentity(t *testing.T) {
{2341, 1, 1}, {2341, 1, 1},
} }
pset, gid1 := CommunitiesSets.AcquireGid(set) pset := CommunitiesSets.Acquire(set)
pval, gid2 := Communities.AcquireGid(api.Community{2341, 6, 2}) pval := Communities.Acquire(api.Community{2341, 6, 2})
fmt.Printf("set: %p, pset[1]: %p, pval: %p\n", set, pset[1], pval) fmt.Printf("set: %p, pset[1]: %p, pval: %p\n", set, pset[1], pval)
fmt.Printf("gid1: %d, gid2: %d\n", gid1, gid2)
p1 := reflect.ValueOf(pset[1]).UnsafePointer() p1 := reflect.ValueOf(pset[1]).UnsafePointer()
p2 := reflect.ValueOf(pval).UnsafePointer() p2 := reflect.ValueOf(pval).UnsafePointer()
@ -133,7 +111,7 @@ func TestSetCommunityIdentity(t *testing.T) {
} }
} }
func TestAcquireGidExtCommunitiesSets(t *testing.T) { func TestAcquireExtCommunitiesSets(t *testing.T) {
c1 := []api.ExtCommunity{ c1 := []api.ExtCommunity{
{"ro", 5, 1}, {"ro", 5, 1},
{"ro", 5, 2}, {"ro", 5, 2},
@ -152,9 +130,9 @@ func TestAcquireGidExtCommunitiesSets(t *testing.T) {
p := NewExtCommunitiesSetPool() p := NewExtCommunitiesSetPool()
pc1, gid1 := p.AcquireGid(c1) pc1 := p.Acquire(c1)
pc2, gid2 := p.AcquireGid(c2) pc2 := p.Acquire(c2)
pc3, gid3 := p.AcquireGid(c3) pc3 := p.Acquire(c3)
if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) { if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) {
t.Error("expected c1 !== c2") t.Error("expected c1 !== c2")
@ -163,11 +141,7 @@ func TestAcquireGidExtCommunitiesSets(t *testing.T) {
if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) { if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) {
t.Error("expected pc1 == pc2") t.Error("expected pc1 == pc2")
} }
if gid1 != gid2 {
t.Error("expected gid1 == gid2")
}
fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3) fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3)
fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3)
fmt.Printf("gid1: %d, gid2: %d, gid3: %d\n", gid1, gid2, gid3)
} }

View File

@ -21,25 +21,15 @@ func NewIntListPool() *IntListPool {
} }
} }
// AcquireGid int list from pool and return with gid // Acquire int list from pool
func (p *IntListPool) AcquireGid(list []int) ([]int, uint64) { func (p *IntListPool) Acquire(list []int) []int {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if len(list) == 0 { if len(list) == 0 {
return p.root.value, p.root.gid // root return p.root.value // root
} }
v, c := p.root.traverse(p.counter+1, list, list) return p.root.traverse(list, list)
if c > p.counter {
p.counter = c
}
return v, c
}
// Acquire int list from pool without gid
func (p *IntListPool) Acquire(list []int) []int {
v, _ := p.AcquireGid(list)
return v
} }
// A StringListPool can be used for deduplicating lists // A StringListPool can be used for deduplicating lists
@ -61,11 +51,10 @@ func NewStringListPool() *StringListPool {
} }
} }
// AcquireGid aquires the string list pointer from the pool // Acquire the string list pointer from the pool.
// and also returns the gid. func (p *StringListPool) Acquire(list []string) []string {
func (p *StringListPool) AcquireGid(list []string) ([]string, uint64) {
if len(list) == 0 { if len(list) == 0 {
return p.root.value, p.root.gid return p.root.value
} }
// Make idenfier list // Make idenfier list
@ -81,11 +70,5 @@ func (p *StringListPool) AcquireGid(list []string) ([]string, uint64) {
id[i] = v id[i] = v
} }
return p.root.traverse(uint64(p.head), list, id) return p.root.traverse(list, id)
}
// Acquire aquires the string list pointer from the pool
func (p *StringListPool) Acquire(list []string) []string {
v, _ := p.AcquireGid(list)
return v
} }

View File

@ -13,12 +13,12 @@ func TestAcquireIntList(t *testing.T) {
p := NewIntListPool() p := NewIntListPool()
r1, gid1 := p.AcquireGid(a) r1 := p.Acquire(a)
p.Acquire(c) p.Acquire(c)
r2, gid2 := p.AcquireGid(b) r2 := p.Acquire(b)
log.Println("r1", r1, "gid1", gid1) log.Println("r1", r1)
log.Println("r2", r2, "gid2", gid2) log.Println("r2", r2)
if fmt.Sprintf("%p", a) == fmt.Sprintf("%p", b) { if fmt.Sprintf("%p", a) == fmt.Sprintf("%p", b) {
t.Error("lists should not be same pointer", fmt.Sprintf("%p %p", a, b)) t.Error("lists should not be same pointer", fmt.Sprintf("%p %p", a, b))
@ -26,17 +26,8 @@ func TestAcquireIntList(t *testing.T) {
if fmt.Sprintf("%p", r1) != fmt.Sprintf("%p", r2) { if fmt.Sprintf("%p", r1) != fmt.Sprintf("%p", r2) {
t.Error("lists should be same pointer", fmt.Sprintf("%p %p", r1, r2)) t.Error("lists should be same pointer", fmt.Sprintf("%p %p", r1, r2))
} }
if gid1 != gid2 {
t.Error("gid should be same, got:", gid1, gid2)
}
t.Log(fmt.Sprintf("Ptr: %p %p => %p %p", a, b, r1, r2)) t.Log(fmt.Sprintf("Ptr: %p %p => %p %p", a, b, r1, r2))
_, gid3 := p.AcquireGid(c)
if gid3 == gid1 {
t.Error("gid should not be same, got:", gid3, gid1)
}
t.Log("gid3", gid3, "gid1", gid1)
} }
func TestAcquireStringList(t *testing.T) { func TestAcquireStringList(t *testing.T) {
@ -45,9 +36,8 @@ func TestAcquireStringList(t *testing.T) {
e := []string{"foo", "bpf"} e := []string{"foo", "bpf"}
p2 := NewStringListPool() p2 := NewStringListPool()
x1, g1 := p2.AcquireGid(q) x1 := p2.Acquire(q)
x2, g2 := p2.AcquireGid(w) p2.Acquire(e)
x3, g3 := p2.AcquireGid(e) x2 := p2.Acquire(w)
fmt.Printf("Ptr: %p %p => %p %d %p %d \n", q, w, x1, g1, x2, g2) fmt.Printf("Ptr: %p %p => %p %p \n", q, w, x1, x2)
fmt.Printf("Ptr: %p => %p %d\n", e, x3, g3)
} }

View File

@ -5,7 +5,6 @@ type Node[T comparable, V any] struct {
children map[T]*Node[T, V] // map of children children map[T]*Node[T, V] // map of children
value V value V
final bool final bool
gid uint64
} }
// NewNode creates a new tree node // NewNode creates a new tree node
@ -19,7 +18,7 @@ func NewNode[T comparable, V any](value V) *Node[T, V] {
// traverse inserts a new node into the three if required // traverse inserts a new node into the three if required
// or returns the object if it already exists. // or returns the object if it already exists.
func (n *Node[T, V]) traverse(gid uint64, value V, tail []T) (V, uint64) { func (n *Node[T, V]) traverse(value V, tail []T) V {
id := tail[0] id := tail[0]
tail = tail[1:] tail = tail[1:]
@ -36,16 +35,15 @@ func (n *Node[T, V]) traverse(gid uint64, value V, tail []T) (V, uint64) {
if !child.final { if !child.final {
child.value = value child.value = value
child.final = true child.final = true
child.gid = gid
} }
return child.value, child.gid return child.value
} }
return child.traverse(gid, value, tail) return child.traverse(value, tail)
} }
// read returns the object if it exists or nil if not. // read returns the object if it exists or nil if not.
func (n *Node[T, V]) read(tail []T) (V, uint64) { func (n *Node[T, V]) read(tail []T) V {
id := tail[0] id := tail[0]
tail = tail[1:] tail = tail[1:]
@ -53,12 +51,12 @@ func (n *Node[T, V]) read(tail []T) (V, uint64) {
child, ok := n.children[id] child, ok := n.children[id]
if !ok { if !ok {
var zero V var zero V
return zero, 0 return zero
} }
// Set obj if required // Set obj if required
if len(tail) == 0 { if len(tail) == 0 {
return child.value, child.gid return child.value
} }
return child.read(tail) return child.read(tail)