Commit 908761b8 authored by Matt Joiner's avatar Matt Joiner
Browse files

Implement server bootstrap with STM and helpers

parent d40f9ecc
package dht
import (
"sync"
"log"
"sync/atomic"
"github.com/anacrolix/stm"
"github.com/anacrolix/dht/v2/krpc"
)
......@@ -12,36 +15,55 @@ func (s *Server) Bootstrap() (ts TraversalStats, err error) {
if err != nil {
return
}
var outstanding sync.WaitGroup
triedAddrs := newBloomFilterForTraversal()
var onAddr func(addr Addr)
onAddr = func(addr Addr) {
if triedAddrs.Test([]byte(addr.String())) {
return
}
ts.NumAddrsTried++
outstanding.Add(1)
triedAddrs.AddString(addr.String())
s.findNode(addr, s.id, func(m krpc.Msg, err error) {
defer outstanding.Done()
s.mu.Lock()
defer s.mu.Unlock()
if err != nil {
return
}
ts.NumResponses++
if r := m.R; r != nil {
r.ForAllNodes(func(ni krpc.NodeInfo) {
onAddr(NewAddr(ni.Addr.UDP()))
})
}
})
}
s.mu.Lock()
traversal := newTraversal(s.id)
for _, addr := range initialAddrs {
onAddr(NewAddr(addr.Addr.UDP()))
log.Println("pending", addr)
stm.Atomically(traversal.pendContact(addr))
}
outstanding := stm.NewVar(0)
for {
type txResT struct {
done bool
io func()
}
txRes := stm.Atomically(stm.Select(
func(tx *stm.Tx) interface{} {
addr := traversal.nextAddr(tx)
dhtAddr := NewAddr(addr.UDP())
tx.Set(outstanding, tx.Get(outstanding).(int)+1)
return txResT{
io: s.beginQuery(dhtAddr, "dht bootstrap find_node", func() numWrites {
atomic.AddInt64(&ts.NumAddrsTried, 1)
m, writes, err := s.findNode(dhtAddr, s.id)
if err == nil {
ts.NumResponses++
}
if r := m.R; r != nil {
r.ForAllNodes(func(ni krpc.NodeInfo) {
id := int160FromByteArray(ni.ID)
stm.Atomically(traversal.pendContact(addrMaybeId{
Addr: ni.Addr,
Id: &id,
}))
})
}
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
tx.Set(outstanding, tx.Get(outstanding).(int)-1)
log.Println("setting outstanding", tx.Get(outstanding))
}))
return writes
})(tx).(func()),
}
},
func(tx *stm.Tx) interface{} {
tx.Assert(tx.Get(outstanding).(int) == 0)
return txResT{done: true}
},
)).(txResT)
if txRes.done {
break
}
go txRes.io()
}
s.mu.Unlock()
outstanding.Wait()
return
}
......@@ -279,6 +279,7 @@ func TestBootstrapRace(t *testing.T) {
Conn: &serverPc,
StartingNodes: addrResolver(remotePc.LocalAddr().String()),
QueryResendDelay: func() time.Duration { return 0 },
Logger: log.Default,
})
require.NoError(t, err)
defer s.Close()
......
......@@ -5,7 +5,7 @@ require (
github.com/anacrolix/log v0.3.1-0.20190913000754-831e4ffe0174
github.com/anacrolix/missinggo v1.2.1
github.com/anacrolix/missinggo/v2 v2.2.1-0.20191103010835-12360f38ced0
github.com/anacrolix/stm v0.1.1-0.20191105075537-443c0b33d649
github.com/anacrolix/stm v0.1.1-0.20191106051447-e749ba3531cf
github.com/anacrolix/sync v0.2.0
github.com/anacrolix/tagflag v1.0.1
github.com/anacrolix/torrent v1.7.1
......
......@@ -52,6 +52,8 @@ github.com/anacrolix/stm v0.1.0 h1:B/Kt3e4+0uqJoLcNZFW69cCBASok6WxX9CEhz9PqIPM=
github.com/anacrolix/stm v0.1.0/go.mod h1:ZKz7e7ERWvP0KgL7WXfRjBXHNRhlVRlbBQecqFtPq+A=
github.com/anacrolix/stm v0.1.1-0.20191105075537-443c0b33d649 h1:ZMMjQrpZH1cpbd7PWRsBWCoJNxPZRvK5VZNCFv6dtr8=
github.com/anacrolix/stm v0.1.1-0.20191105075537-443c0b33d649/go.mod h1:zoVQRvSiGjGoTmbM0vSLIiaKjWtNPeTvXUSdJQA4hsg=
github.com/anacrolix/stm v0.1.1-0.20191106051447-e749ba3531cf h1:xFI/MP4FRmRzlqliIoefvViq1jg/Ud0Pdx8q0IwYh4k=
github.com/anacrolix/stm v0.1.1-0.20191106051447-e749ba3531cf/go.mod h1:zoVQRvSiGjGoTmbM0vSLIiaKjWtNPeTvXUSdJQA4hsg=
github.com/anacrolix/sync v0.0.0-20171108081538-eee974e4f8c1/go.mod h1:+u91KiUuf0lyILI6x3n/XrW7iFROCZCG+TjgK8nW52w=
github.com/anacrolix/sync v0.0.0-20180611022320-3c4cb11f5a01 h1:14t4kCoWXaUXrHErRD0bLMNolOE50nyPA0gO8+J3hP8=
github.com/anacrolix/sync v0.0.0-20180611022320-3c4cb11f5a01/go.mod h1:+u91KiUuf0lyILI6x3n/XrW7iFROCZCG+TjgK8nW52w=
......
......@@ -819,18 +819,17 @@ func (s *Server) addResponseNodes(d krpc.Msg) {
}
// Sends a find_node query to addr. targetID is the node we're looking for.
func (s *Server) findNode(addr Addr, targetID int160, callback func(krpc.Msg, error)) (err error) {
return s.query(addr, "find_node", &krpc.MsgArgs{
func (s *Server) findNode(addr Addr, targetID int160) (krpc.Msg, numWrites, error) {
m, writes, err := s.queryContext(context.TODO(), addr, "find_node", &krpc.MsgArgs{
Target: targetID.AsByteArray(),
Want: []krpc.Want{krpc.WantNodes, krpc.WantNodes6},
}, func(m krpc.Msg, err error) {
// Scrape peers from the response to put in the server's table before
// handing the response back to the caller.
s.mu.Lock()
s.addResponseNodes(m)
s.mu.Unlock()
callback(m, err)
})
// Scrape peers from the response to put in the server's table before
// handing the response back to the caller.
s.mu.Lock()
s.addResponseNodes(m)
s.mu.Unlock()
return m, writes, err
}
// Returns how many nodes are in the node table.
......
......@@ -11,8 +11,8 @@ import (
)
type TraversalStats struct {
NumAddrsTried int
NumResponses int
NumAddrsTried int64
NumResponses int64
}
func (me TraversalStats) String() string {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment