Commit b0e6286f authored by Matt Joiner's avatar Matt Joiner
Browse files

Abstract out the traversal core logic

We want to reuse it for bootstrapping.
parent aa99bbb5
......@@ -7,7 +7,6 @@ import (
"net"
"github.com/anacrolix/missinggo/v2/conntrack"
"github.com/anacrolix/missinggo/v2/iter"
"github.com/anacrolix/stm"
"github.com/anacrolix/stm/stmutil"
"github.com/benbjohnson/immutable"
......@@ -28,8 +27,6 @@ type Announce struct {
doneVar *stm.Var
cancel func()
triedAddrs *stm.Var // Settish of krpc.NodeAddr.String
pending *stm.Var // How many transactions are still ongoing (int).
server *Server
infoHash int160 // Target
......@@ -41,8 +38,9 @@ type Announce struct {
// being NATed.
announcePortImplied bool
nodesPendingContact *stm.Var // Settish of addrMaybeId sorted by distance from the target
pendingAnnouncePeers *stm.Var // List of pendingAnnouncePeer
traversal traversal
}
type pendingAnnouncePeer struct {
......@@ -71,15 +69,14 @@ func (s *Server) Announce(infoHash [20]byte, port int, impliedPort bool) (*Annou
a := &Announce{
Peers: make(chan PeersValues, 100),
values: make(chan PeersValues),
triedAddrs: stm.NewVar(stmutil.NewSet()),
server: s,
infoHash: infoHashInt160,
announcePort: port,
announcePortImplied: impliedPort,
nodesPendingContact: stm.NewVar(nodesByDistance(infoHashInt160)),
pending: stm.NewVar(0),
numContacted: stm.NewVar(0),
pendingAnnouncePeers: stm.NewVar(immutable.NewList()),
traversal: newTraversal(infoHashInt160),
}
var ctx context.Context
ctx, a.cancel = context.WithCancel(context.Background())
......@@ -102,9 +99,7 @@ func (s *Server) Announce(infoHash [20]byte, port int, impliedPort bool) (*Annou
}
}()
for _, n := range startAddrs {
stm.Atomically(func(tx *stm.Tx) {
a.pendContact(n, tx)
})
stm.Atomically(a.pendContact(n))
}
go a.closer()
go a.nodeContactor()
......@@ -118,7 +113,7 @@ func (a *Announce) closer() {
return
}
tx.Assert(tx.Get(a.pending).(int) == 0)
tx.Assert(tx.Get(a.nodesPendingContact).(stmutil.Lenner).Len() == 0)
a.traversal.finished(tx)
tx.Assert(tx.Get(a.pendingAnnouncePeers).(stmutil.Lenner).Len() == 0)
})
}
......@@ -140,9 +135,6 @@ func (a *Announce) shouldContact(addr krpc.NodeAddr, tx *stm.Tx) bool {
if !validNodeAddr(addr.UDP()) {
return false
}
if tx.Get(a.triedAddrs).(stmutil.Settish).Contains(addr.String()) {
return false
}
if a.server.ipBlocked(addr.IP) {
return false
}
......@@ -157,9 +149,7 @@ func (a *Announce) completeContact() {
func (a *Announce) responseNode(node krpc.NodeInfo) {
i := int160FromByteArray(node.ID)
stm.Atomically(func(tx *stm.Tx) {
a.pendContact(addrMaybeId{node.Addr, &i}, tx)
})
stm.Atomically(a.pendContact(addrMaybeId{node.Addr, &i}))
}
// Announce to a peer, if appropriate.
......@@ -201,9 +191,10 @@ func (a *Announce) beginAnnouncePeer(tx *stm.Tx) {
}})
}
func finalizeCteh(cteh *conntrack.EntryHandle, writes int) {
func finalizeCteh(cteh *conntrack.EntryHandle, writes numWrites) {
if writes == 0 {
cteh.Forget()
panic("how to reverse rate limit?")
} else {
cteh.Done()
}
......@@ -213,7 +204,7 @@ func (a *Announce) getPeers(addr Addr, cteh *conntrack.EntryHandle) {
// log.Printf("sending get_peers to %v", node)
m, writes, err := a.server.getPeers(context.TODO(), addr, a.infoHash)
finalizeCteh(cteh, writes)
a.server.logger().Printf("Announce.server.getPeers result: m.Y=%v, writes=%v, err=%v", m.Y, writes, err)
a.server.logger().Printf("Announce.server.getPeers result: m.Y=%v, numWrites=%v, err=%v", m.Y, writes, err)
// log.Printf("get_peers response error from %v: %v", node, err)
// Register suggested nodes closer to the target info-hash.
if m.R != nil && m.SenderID() != nil {
......@@ -252,20 +243,19 @@ func (a *Announce) close() {
a.cancel()
}
func (a *Announce) pendContact(node addrMaybeId, tx *stm.Tx) {
if !a.shouldContact(node.Addr, tx) {
// log.Printf("shouldn't contact (pend): %v", node)
return
func (a *Announce) pendContact(node addrMaybeId) func(tx *stm.Tx) {
return func(tx *stm.Tx) {
if !a.shouldContact(node.Addr, tx) {
// log.Printf("shouldn't contact (pend): %v", node)
return
}
a.traversal.pendContact(node)(tx)
}
tx.Set(a.nodesPendingContact, tx.Get(a.nodesPendingContact).(stmutil.Settish).Add(node))
}
type txResT struct {
done bool
run func()
//contact bool
//addr Addr
//cteh *conntrack.EntryHandle
}
func (a *Announce) nodeContactor() {
......@@ -286,19 +276,11 @@ func (a *Announce) nodeContactor() {
}
func (a *Announce) beginGetPeers(tx *stm.Tx) {
npc := tx.Get(a.nodesPendingContact).(stmutil.Settish)
first, ok := iter.First(npc.Iter)
tx.Assert(ok)
addr := first.(addrMaybeId).Addr
tx.Set(a.nodesPendingContact, npc.Delete(first))
if !a.shouldContact(addr, tx) {
tx.Return(txResT{run: func() {}})
}
addr := a.traversal.nextAddr(tx)
cteh := a.server.config.ConnectionTracking.Allow(tx, a.server.connTrackEntryForAddr(NewAddr(addr.UDP())), "announce get_peers", -1)
tx.Assert(cteh != nil)
tx.Assert(a.server.sendLimit.AllowStm(tx))
tx.Set(a.numContacted, tx.Get(a.numContacted).(int)+1)
tx.Set(a.pending, tx.Get(a.pending).(int)+1)
tx.Set(a.triedAddrs, tx.Get(a.triedAddrs).(stmutil.Settish).Add(addr.String()))
tx.Return(txResT{run: func() { a.getPeers(NewAddr(addr.UDP()), cteh) }})
}
......@@ -9,25 +9,27 @@ import (
"github.com/anacrolix/dht/v2/krpc"
)
func int160WithBitSet(bit int) *int160 {
var i int160
i.bits[bit] = 1
return &i
}
var sampleAddrMaybeIds = []addrMaybeId{
addrMaybeId{},
addrMaybeId{Id: new(int160)},
addrMaybeId{Id: int160WithBitSet(13)},
addrMaybeId{Id: int160WithBitSet(12)},
addrMaybeId{Addr: krpc.NodeAddr{Port: 1}},
addrMaybeId{
Id: int160WithBitSet(14),
Addr: krpc.NodeAddr{Port: 1}},
}
func TestNodesByDistance(t *testing.T) {
a := nodesByDistance(int160{})
amis := []addrMaybeId{
addrMaybeId{},
addrMaybeId{Id: new(int160)},
addrMaybeId{Id: func() *int160 {
var i int160
i.bits[13] = 1
return &i
}()},
addrMaybeId{Id: func() *int160 {
var i int160
i.bits[12] = 1
return &i
}()},
addrMaybeId{Addr: krpc.NodeAddr{Port: 1}},
}
push := func(i int) {
a = a.Add(amis[i])
a = a.Add(sampleAddrMaybeIds[i])
}
push(4)
push(2)
......@@ -40,7 +42,7 @@ func TestNodesByDistance(t *testing.T) {
assert.True(t, ok)
assert.Contains(t, func() (ret []addrMaybeId) {
for _, i := range is {
ret = append(ret, amis[i])
ret = append(ret, sampleAddrMaybeIds[i])
}
return
}(), first)
......
......@@ -645,20 +645,34 @@ func (s *Server) connTrackEntryForAddr(a Addr) conntrack.Entry {
}
}
type numWrites int
func (s *Server) beginQuery(addr Addr, reason string, f func() numWrites) func(tx *stm.Tx) {
return func(tx *stm.Tx) {
tx.Assert(s.sendLimit.AllowStm(tx))
cteh := s.config.ConnectionTracking.Allow(tx, s.connTrackEntryForAddr(addr), reason, -1)
tx.Assert(cteh != nil)
tx.Return(func() {
writes := f()
finalizeCteh(cteh, writes)
})
}
}
func (s *Server) query(addr Addr, q string, a *krpc.MsgArgs, callback func(krpc.Msg, error)) error {
if callback == nil {
callback = func(krpc.Msg, error) {}
}
go func() {
cteh := s.config.ConnectionTracking.Wait(context.TODO(), s.connTrackEntryForAddr(addr), fmt.Sprintf("send dht query %q", q), -1)
s.sendLimit.Wait(context.TODO())
m, writes, err := s.queryContext(context.Background(), addr, q, a)
if writes > 0 {
cteh.Done()
} else {
cteh.Forget()
}
callback(m, err)
stm.Atomically(
s.beginQuery(addr, fmt.Sprintf("send dht query %q", q),
func() numWrites {
m, writes, err := s.queryContext(context.Background(), addr, q, a)
callback(m, err)
return writes
},
),
).(func())()
}()
return nil
}
......@@ -686,7 +700,7 @@ func (s *Server) makeQueryBytes(q string, a *krpc.MsgArgs, t string) []byte {
return b
}
func (s *Server) queryContext(ctx context.Context, addr Addr, q string, a *krpc.MsgArgs) (reply krpc.Msg, writes int, err error) {
func (s *Server) queryContext(ctx context.Context, addr Addr, q string, a *krpc.MsgArgs) (reply krpc.Msg, writes numWrites, err error) {
defer func(started time.Time) {
s.logger().WithValues(log.Debug).Printf("queryContext returned after %v (err=%v, reply.Y=%v, reply.E=%v)", time.Since(started), err, reply.Y, reply.E)
}(time.Now())
......@@ -727,7 +741,7 @@ func (s *Server) queryContext(ctx context.Context, addr Addr, q string, a *krpc.
return
}
func (s *Server) transactionQuerySender(sendCtx context.Context, sendErr chan<- error, b []byte, writes *int, addr Addr) {
func (s *Server) transactionQuerySender(sendCtx context.Context, sendErr chan<- error, b []byte, writes *numWrites, addr Addr) {
defer close(sendErr)
err := transactionSender(
sendCtx,
......@@ -765,7 +779,7 @@ func (s *Server) ping(node *net.UDPAddr, callback func(krpc.Msg, error)) error {
return s.query(NewAddr(node), "ping", nil, callback)
}
func (s *Server) announcePeer(node Addr, infoHash int160, port int, token string, impliedPort bool) (m krpc.Msg, writes int, err error) {
func (s *Server) announcePeer(node Addr, infoHash int160, port int, token string, impliedPort bool) (m krpc.Msg, writes numWrites, err error) {
if port == 0 && !impliedPort {
err = errors.New("no port specified")
return
......@@ -846,7 +860,7 @@ func (s *Server) Close() {
s.socket.Close()
}
func (s *Server) getPeers(ctx context.Context, addr Addr, infoHash int160) (krpc.Msg, int, error) {
func (s *Server) getPeers(ctx context.Context, addr Addr, infoHash int160) (krpc.Msg, numWrites, error) {
m, writes, err := s.queryContext(ctx, addr, "get_peers", &krpc.MsgArgs{
InfoHash: infoHash.AsByteArray(),
// TODO: Maybe IPv4-only Servers won't want IPv6 nodes?
......
package dht
import "fmt"
import (
"fmt"
"github.com/anacrolix/missinggo/v2/iter"
"github.com/anacrolix/stm"
"github.com/anacrolix/stm/stmutil"
"github.com/anacrolix/dht/v2/krpc"
)
type TraversalStats struct {
NumAddrsTried int
......@@ -10,3 +18,63 @@ type TraversalStats struct {
func (me TraversalStats) String() string {
return fmt.Sprintf("%#v", me)
}
// Prioritizes addrs to try by distance from target, disallowing repeat contacts.
type traversal struct {
targetInfohash int160
triedAddrs *stm.Var // Settish of krpc.NodeAddr.String
nodesPendingContact *stm.Var // Settish of addrMaybeId sorted by distance from the target
addrBestIds *stm.Var // Mappish Addr to best
}
func newTraversal(targetInfohash int160) traversal {
return traversal{
targetInfohash: targetInfohash,
triedAddrs: stm.NewVar(stmutil.NewSet()),
nodesPendingContact: stm.NewVar(nodesByDistance(targetInfohash)),
addrBestIds: stm.NewVar(stmutil.NewMap()),
}
}
func (t traversal) finished(tx *stm.Tx) {
tx.Assert(tx.Get(t.nodesPendingContact).(stmutil.Lenner).Len() == 0)
}
func (t traversal) pendContact(node addrMaybeId) func(*stm.Tx) {
return func(tx *stm.Tx) {
nodeAddrString := node.Addr.String()
if tx.Get(t.triedAddrs).(stmutil.Settish).Contains(nodeAddrString) {
return
}
addrBestIds := tx.Get(t.addrBestIds).(stmutil.Mappish)
nodesPendingContact := tx.Get(t.nodesPendingContact).(stmutil.Settish)
if _best, ok := addrBestIds.Get(nodeAddrString); ok {
if node.Id == nil {
return
}
best := _best.(*int160)
if best != nil && distance(*best, t.targetInfohash).Cmp(distance(*node.Id, t.targetInfohash)) <= 0 {
return
}
nodesPendingContact = nodesPendingContact.Delete(addrMaybeId{
Addr: node.Addr,
Id: best,
})
}
tx.Set(t.addrBestIds, addrBestIds.Set(nodeAddrString, node.Id))
nodesPendingContact = nodesPendingContact.Add(node)
tx.Set(t.nodesPendingContact, nodesPendingContact)
}
}
func (a traversal) nextAddr(tx *stm.Tx) krpc.NodeAddr {
npc := tx.Get(a.nodesPendingContact).(stmutil.Settish)
first, ok := iter.First(npc.Iter)
tx.Assert(ok)
addr := first.(addrMaybeId).Addr
addrString := addr.String()
tx.Set(a.nodesPendingContact, npc.Delete(first))
tx.Set(a.addrBestIds, tx.Get(a.addrBestIds).(stmutil.Mappish).Delete(addrString))
tx.Set(a.triedAddrs, tx.Get(a.triedAddrs).(stmutil.Settish).Add(addrString))
return addr
}
package dht
import (
"testing"
"github.com/anacrolix/stm"
"github.com/stretchr/testify/assert"
"github.com/anacrolix/dht/v2/krpc"
)
func TestTraversal(t *testing.T) {
var target int160
traversal := newTraversal(target)
assert.True(t, stm.WouldBlock(func(tx *stm.Tx) { traversal.nextAddr(tx) }))
assert.False(t, stm.WouldBlock(traversal.finished))
stm.Atomically(stm.Compose(func() (ret []func(tx *stm.Tx)) {
for _, v := range sampleAddrMaybeIds[2:6] {
ret = append(ret, traversal.pendContact(v))
}
return
}()...))
assert.False(t, stm.WouldBlock(func(tx *stm.Tx) { traversal.nextAddr(tx) }))
assert.True(t, stm.WouldBlock(traversal.finished))
pop := func(tx *stm.Tx) { tx.Return(traversal.nextAddr(tx)) }
var addrs []krpc.NodeAddr
for !stm.WouldBlock(pop) {
addrs = append(addrs, stm.Atomically(pop).(krpc.NodeAddr))
}
assert.False(t, stm.WouldBlock(traversal.finished))
t.Log(addrs)
assert.EqualValues(t, []krpc.NodeAddr{{Port: 1}, {}}, addrs)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment