From 98847ba0cbe71f5e1028eac9eff8fe8bc121d7de Mon Sep 17 00:00:00 2001 From: Annika Hannig Date: Tue, 16 Jan 2024 17:07:09 +0100 Subject: [PATCH] use pointers as identifier. --- pkg/pools/communities.go | 107 ++++++++++++---------------------- pkg/pools/communities_test.go | 60 ++++++------------- pkg/pools/lists.go | 33 +++-------- pkg/pools/lists_test.go | 26 +++------ pkg/pools/node.go | 14 ++--- 5 files changed, 75 insertions(+), 165 deletions(-) diff --git a/pkg/pools/communities.go b/pkg/pools/communities.go index 418b8a9..44bd8dd 100644 --- a/pkg/pools/communities.go +++ b/pkg/pools/communities.go @@ -2,7 +2,9 @@ package pools import ( "math" + "reflect" "sync" + "unsafe" "github.com/alice-lg/alice-lg/pkg/api" ) @@ -11,8 +13,7 @@ import ( // This works with large and standard communities. For extended // communities, use the ExtCommunityPool. type CommunitiesPool struct { - root *Node[int, api.Community] - counter uint64 + root *Node[int, api.Community] sync.RWMutex } @@ -23,32 +24,22 @@ func NewCommunitiesPool() *CommunitiesPool { } } -// AcquireGid acquires a single bgp community with gid -func (p *CommunitiesPool) AcquireGid(c api.Community) (api.Community, uint64) { +// Acquire a single bgp community +func (p *CommunitiesPool) Acquire(c api.Community) api.Community { p.Lock() defer p.Unlock() if len(c) == 0 { - return p.root.value, p.root.gid + return p.root.value } - v, gid := p.root.traverse(p.counter+1, c, c) - if gid > p.counter { - p.counter = gid - } - return v, gid -} - -// Acquire a single bgp community without gid -func (p *CommunitiesPool) Acquire(c api.Community) api.Community { - v, _ := p.AcquireGid(c) - return v + return p.root.traverse(c, c) } // Read a single bgp community -func (p *CommunitiesPool) Read(c api.Community) (api.Community, uint64) { +func (p *CommunitiesPool) Read(c api.Community) api.Community { p.RLock() defer p.RUnlock() if len(c) == 0 { - return p.root.value, p.root.gid + return p.root.value } return p.root.read(c) } @@ -57,8 +48,7 @@ func (p *CommunitiesPool) Read(c api.Community) (api.Community, uint64) { // (Large and default. The ext communities representation right now // makes problems and need to be fixed. TODO.) type CommunitiesSetPool struct { - root *Node[uint64, []api.Community] - counter uint64 + root *Node[unsafe.Pointer, []api.Community] sync.Mutex } @@ -66,47 +56,34 @@ type CommunitiesSetPool struct { // of BGP communities. func NewCommunitiesSetPool() *CommunitiesSetPool { return &CommunitiesSetPool{ - root: NewNode[uint64, []api.Community]([]api.Community{}), + root: NewNode[unsafe.Pointer, []api.Community]([]api.Community{}), } } -// AcquireGid acquires a list of bgp communities and returns a gid -func (p *CommunitiesSetPool) AcquireGid( - communities []api.Community, -) ([]api.Community, uint64) { - p.Lock() - defer p.Unlock() - // Make identification list by using the pointer address - // of the deduplicated community as ID - ids := make([]uint64, len(communities)) - set := make([]api.Community, len(communities)) - for i, comm := range communities { - ptr, gid := Communities.AcquireGid(comm) - ids[i] = gid - set[i] = ptr - } - if len(ids) == 0 { - return p.root.value, p.root.gid - } - v, id := p.root.traverse(p.counter+1, set, ids) - if id > p.counter { - p.counter = id - } - return v, id -} - // Acquire a list of bgp communities func (p *CommunitiesSetPool) Acquire( communities []api.Community, ) []api.Community { - v, _ := p.AcquireGid(communities) - return v + p.Lock() + defer p.Unlock() + // Make identification list by using the pointer address + // of the deduplicated community as ID + ids := make([]unsafe.Pointer, len(communities)) + set := make([]api.Community, len(communities)) + for i, comm := range communities { + ptr := Communities.Acquire(comm) + ids[i] = reflect.ValueOf(ptr).UnsafePointer() + set[i] = ptr + } + if len(ids) == 0 { + return p.root.value + } + return p.root.traverse(set, ids) } // ExtCommunitiesSetPool is for deduplicating a list of ext. BGP communities type ExtCommunitiesSetPool struct { - root *Node[uint64, []api.ExtCommunity] - counter uint64 + root *Node[unsafe.Pointer, []api.ExtCommunity] sync.Mutex } @@ -114,7 +91,7 @@ type ExtCommunitiesSetPool struct { // of BGP communities. func NewExtCommunitiesSetPool() *ExtCommunitiesSetPool { return &ExtCommunitiesSetPool{ - root: NewNode[uint64, []api.ExtCommunity]([]api.ExtCommunity{}), + root: NewNode[unsafe.Pointer, []api.ExtCommunity]([]api.ExtCommunity{}), } } @@ -126,37 +103,25 @@ func extPrefixToInt(s string) int { return v } -// AcquireGid acquires a list of ext bgp communities -func (p *ExtCommunitiesSetPool) AcquireGid( +// Acquire a list of ext bgp communities +func (p *ExtCommunitiesSetPool) Acquire( communities []api.ExtCommunity, -) ([]api.ExtCommunity, uint64) { +) []api.ExtCommunity { p.Lock() defer p.Unlock() // Make identification list - ids := make([]uint64, len(communities)) + ids := make([]unsafe.Pointer, len(communities)) for i, comm := range communities { r := extPrefixToInt(comm[0].(string)) icomm := []int{r, comm[1].(int), comm[2].(int)} // get community identifier - _, gid := ExtCommunities.AcquireGid(icomm) - ids[i] = gid + ptr := ExtCommunities.Acquire(icomm) + ids[i] = reflect.ValueOf(ptr).UnsafePointer() } if len(ids) == 0 { - return p.root.value, p.root.gid + return p.root.value } - v, id := p.root.traverse(p.counter+1, communities, ids) - if id > p.counter { - p.counter = id - } - return v, id -} - -// Acquire a list of ext bgp communities -func (p *ExtCommunitiesSetPool) Acquire( - communities []api.ExtCommunity, -) []api.ExtCommunity { - v, _ := p.AcquireGid(communities) - return v + return p.root.traverse(communities, ids) } diff --git a/pkg/pools/communities_test.go b/pkg/pools/communities_test.go index 3003b10..49777b6 100644 --- a/pkg/pools/communities_test.go +++ b/pkg/pools/communities_test.go @@ -9,16 +9,16 @@ import ( "github.com/alice-lg/alice-lg/pkg/api" ) -func TestAcquireGidCommunity(t *testing.T) { +func TestAcquireCommunity(t *testing.T) { c1 := api.Community{2342, 5, 1} c2 := api.Community{2342, 5, 1} c3 := api.Community{2342, 5} p := NewCommunitiesPool() - pc1, gid1 := p.AcquireGid(c1) - pc2, gid2 := p.AcquireGid(c2) - pc3, gid3 := p.AcquireGid(c3) + pc1 := p.Acquire(c1) + pc2 := p.Acquire(c2) + pc3 := p.Acquire(c3) if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) { t.Error("expected c1 !== c2") @@ -28,19 +28,10 @@ func TestAcquireGidCommunity(t *testing.T) { t.Error("expected pc1 == pc2") } - if gid1 != gid2 { - t.Error("expected gid1 == gid2") - } - - if gid1 == gid3 { - t.Error("expected gid1 != gid3") - } - fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) log.Println(c3, pc3) - log.Println(gid1, gid2, gid3) } func TestCommunityRead(t *testing.T) { @@ -50,29 +41,22 @@ func TestCommunityRead(t *testing.T) { p := NewCommunitiesPool() - pc1, gid1 := p.AcquireGid(c1) - pc2, gid2 := p.Read(c2) - pc3, gid3 := p.Read(c3) + pc1 := p.Acquire(c1) + pc2 := p.Read(c2) + pc3 := p.Read(c3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) - fmt.Printf("gid1: %d, gid2: %d, gid3: %d\n", gid1, gid2, gid3) if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) { t.Error("expected pc1 == pc2") } - if gid1 != gid2 { - t.Error("expected gid1 == gid2") - } if pc3 != nil { t.Error("expected pc3 == nil, got", pc3) } - if gid3 != 0 { - t.Error("expected gid3 == 0, got", gid3) - } } -func TestAcquireGidCommunitiesSets(t *testing.T) { +func TestAcquireCommunitiesSets(t *testing.T) { c1 := []api.Community{ {2342, 5, 1}, {2342, 5, 2}, @@ -91,9 +75,9 @@ func TestAcquireGidCommunitiesSets(t *testing.T) { p := NewCommunitiesSetPool() - pc1, gid1 := p.AcquireGid(c1) - pc2, gid2 := p.AcquireGid(c2) - pc3, gid3 := p.AcquireGid(c3) + pc1 := p.Acquire(c1) + pc2 := p.Acquire(c2) + pc3 := p.Acquire(c3) if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) { t.Error("expected c1 !== c2") @@ -102,14 +86,9 @@ func TestAcquireGidCommunitiesSets(t *testing.T) { if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) { t.Error("expected pc1 == pc2") } - if gid1 != gid2 { - t.Error("expected gid1 == gid2") - } fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) - - t.Logf("gid1: %d, gid2: %d, gid3: %d\n", gid1, gid2, gid3) } func TestSetCommunityIdentity(t *testing.T) { @@ -119,11 +98,10 @@ func TestSetCommunityIdentity(t *testing.T) { {2341, 1, 1}, } - pset, gid1 := CommunitiesSets.AcquireGid(set) - pval, gid2 := Communities.AcquireGid(api.Community{2341, 6, 2}) + pset := CommunitiesSets.Acquire(set) + pval := Communities.Acquire(api.Community{2341, 6, 2}) fmt.Printf("set: %p, pset[1]: %p, pval: %p\n", set, pset[1], pval) - fmt.Printf("gid1: %d, gid2: %d\n", gid1, gid2) p1 := reflect.ValueOf(pset[1]).UnsafePointer() p2 := reflect.ValueOf(pval).UnsafePointer() @@ -133,7 +111,7 @@ func TestSetCommunityIdentity(t *testing.T) { } } -func TestAcquireGidExtCommunitiesSets(t *testing.T) { +func TestAcquireExtCommunitiesSets(t *testing.T) { c1 := []api.ExtCommunity{ {"ro", 5, 1}, {"ro", 5, 2}, @@ -152,9 +130,9 @@ func TestAcquireGidExtCommunitiesSets(t *testing.T) { p := NewExtCommunitiesSetPool() - pc1, gid1 := p.AcquireGid(c1) - pc2, gid2 := p.AcquireGid(c2) - pc3, gid3 := p.AcquireGid(c3) + pc1 := p.Acquire(c1) + pc2 := p.Acquire(c2) + pc3 := p.Acquire(c3) if fmt.Sprintf("%p", c1) == fmt.Sprintf("%p", c2) { t.Error("expected c1 !== c2") @@ -163,11 +141,7 @@ func TestAcquireGidExtCommunitiesSets(t *testing.T) { if fmt.Sprintf("%p", pc1) != fmt.Sprintf("%p", pc2) { t.Error("expected pc1 == pc2") } - if gid1 != gid2 { - t.Error("expected gid1 == gid2") - } fmt.Printf("c1: %p, c2: %p, c3: %p\n", c1, c2, c3) fmt.Printf("pc1: %p, pc2: %p, pc3: %p\n", pc1, pc2, pc3) - fmt.Printf("gid1: %d, gid2: %d, gid3: %d\n", gid1, gid2, gid3) } diff --git a/pkg/pools/lists.go b/pkg/pools/lists.go index c775b61..cbc002e 100644 --- a/pkg/pools/lists.go +++ b/pkg/pools/lists.go @@ -21,25 +21,15 @@ func NewIntListPool() *IntListPool { } } -// AcquireGid int list from pool and return with gid -func (p *IntListPool) AcquireGid(list []int) ([]int, uint64) { +// Acquire int list from pool +func (p *IntListPool) Acquire(list []int) []int { p.Lock() defer p.Unlock() if len(list) == 0 { - return p.root.value, p.root.gid // root + return p.root.value // root } - v, c := p.root.traverse(p.counter+1, list, list) - if c > p.counter { - p.counter = c - } - return v, c -} - -// Acquire int list from pool without gid -func (p *IntListPool) Acquire(list []int) []int { - v, _ := p.AcquireGid(list) - return v + return p.root.traverse(list, list) } // A StringListPool can be used for deduplicating lists @@ -61,11 +51,10 @@ func NewStringListPool() *StringListPool { } } -// AcquireGid aquires the string list pointer from the pool -// and also returns the gid. -func (p *StringListPool) AcquireGid(list []string) ([]string, uint64) { +// Acquire the string list pointer from the pool. +func (p *StringListPool) Acquire(list []string) []string { if len(list) == 0 { - return p.root.value, p.root.gid + return p.root.value } // Make idenfier list @@ -81,11 +70,5 @@ func (p *StringListPool) AcquireGid(list []string) ([]string, uint64) { id[i] = v } - return p.root.traverse(uint64(p.head), list, id) -} - -// Acquire aquires the string list pointer from the pool -func (p *StringListPool) Acquire(list []string) []string { - v, _ := p.AcquireGid(list) - return v + return p.root.traverse(list, id) } diff --git a/pkg/pools/lists_test.go b/pkg/pools/lists_test.go index 664771f..5a5b80f 100644 --- a/pkg/pools/lists_test.go +++ b/pkg/pools/lists_test.go @@ -13,12 +13,12 @@ func TestAcquireIntList(t *testing.T) { p := NewIntListPool() - r1, gid1 := p.AcquireGid(a) + r1 := p.Acquire(a) p.Acquire(c) - r2, gid2 := p.AcquireGid(b) + r2 := p.Acquire(b) - log.Println("r1", r1, "gid1", gid1) - log.Println("r2", r2, "gid2", gid2) + log.Println("r1", r1) + log.Println("r2", r2) if fmt.Sprintf("%p", a) == fmt.Sprintf("%p", b) { t.Error("lists should not be same pointer", fmt.Sprintf("%p %p", a, b)) @@ -26,17 +26,8 @@ func TestAcquireIntList(t *testing.T) { if fmt.Sprintf("%p", r1) != fmt.Sprintf("%p", r2) { t.Error("lists should be same pointer", fmt.Sprintf("%p %p", r1, r2)) } - if gid1 != gid2 { - t.Error("gid should be same, got:", gid1, gid2) - } t.Log(fmt.Sprintf("Ptr: %p %p => %p %p", a, b, r1, r2)) - - _, gid3 := p.AcquireGid(c) - if gid3 == gid1 { - t.Error("gid should not be same, got:", gid3, gid1) - } - t.Log("gid3", gid3, "gid1", gid1) } func TestAcquireStringList(t *testing.T) { @@ -45,9 +36,8 @@ func TestAcquireStringList(t *testing.T) { e := []string{"foo", "bpf"} p2 := NewStringListPool() - x1, g1 := p2.AcquireGid(q) - x2, g2 := p2.AcquireGid(w) - x3, g3 := p2.AcquireGid(e) - fmt.Printf("Ptr: %p %p => %p %d %p %d \n", q, w, x1, g1, x2, g2) - fmt.Printf("Ptr: %p => %p %d\n", e, x3, g3) + x1 := p2.Acquire(q) + p2.Acquire(e) + x2 := p2.Acquire(w) + fmt.Printf("Ptr: %p %p => %p %p \n", q, w, x1, x2) } diff --git a/pkg/pools/node.go b/pkg/pools/node.go index 5ed67ae..63fd7f8 100644 --- a/pkg/pools/node.go +++ b/pkg/pools/node.go @@ -5,7 +5,6 @@ type Node[T comparable, V any] struct { children map[T]*Node[T, V] // map of children value V final bool - gid uint64 } // NewNode creates a new tree node @@ -19,7 +18,7 @@ func NewNode[T comparable, V any](value V) *Node[T, V] { // traverse inserts a new node into the three if required // or returns the object if it already exists. -func (n *Node[T, V]) traverse(gid uint64, value V, tail []T) (V, uint64) { +func (n *Node[T, V]) traverse(value V, tail []T) V { id := tail[0] tail = tail[1:] @@ -36,16 +35,15 @@ func (n *Node[T, V]) traverse(gid uint64, value V, tail []T) (V, uint64) { if !child.final { child.value = value child.final = true - child.gid = gid } - return child.value, child.gid + return child.value } - return child.traverse(gid, value, tail) + return child.traverse(value, tail) } // read returns the object if it exists or nil if not. -func (n *Node[T, V]) read(tail []T) (V, uint64) { +func (n *Node[T, V]) read(tail []T) V { id := tail[0] tail = tail[1:] @@ -53,12 +51,12 @@ func (n *Node[T, V]) read(tail []T) (V, uint64) { child, ok := n.children[id] if !ok { var zero V - return zero, 0 + return zero } // Set obj if required if len(tail) == 0 { - return child.value, child.gid + return child.value } return child.read(tail)