Commit de77e1a1 authored by Matt Joiner's avatar Matt Joiner
Browse files

Do announce announce_peer followup using STM too

parent 1dba080f
......@@ -10,6 +10,7 @@ import (
"github.com/anacrolix/missinggo/v2/iter"
"github.com/anacrolix/stm"
"github.com/anacrolix/stm/stmutil"
"github.com/benbjohnson/immutable"
"github.com/willf/bloom"
"github.com/anacrolix/dht/v2/krpc"
......@@ -33,14 +34,20 @@ type Announce struct {
server *Server
infoHash int160 // Target
// Count of (probably) distinct addresses we've sent get_peers requests to.
numContacted *stm.Var
numContacted *stm.Var // int
// The torrent port that we're announcing.
announcePort int
// The torrent port should be determined by the receiver in case we're
// being NATed.
announcePortImplied bool
nodesPendingContact *stm.Var // Settish of addrMaybeId sorted by distance from the target
nodesPendingContact *stm.Var // Settish of addrMaybeId sorted by distance from the target
pendingAnnouncePeers *stm.Var // List of pendingAnnouncePeer
}
type pendingAnnouncePeer struct {
Addr
token string
}
// Returns the number of distinct remote addresses the announce has queried.
......@@ -62,16 +69,17 @@ func (s *Server) Announce(infoHash [20]byte, port int, impliedPort bool) (*Annou
}
infoHashInt160 := int160FromByteArray(infoHash)
a := &Announce{
Peers: make(chan PeersValues, 100),
values: make(chan PeersValues),
triedAddrs: stm.NewVar(stmutil.NewSet()),
server: s,
infoHash: infoHashInt160,
announcePort: port,
announcePortImplied: impliedPort,
nodesPendingContact: stm.NewVar(nodesByDistance(infoHashInt160)),
pending: stm.NewVar(0),
numContacted: stm.NewVar(0),
Peers: make(chan PeersValues, 100),
values: make(chan PeersValues),
triedAddrs: stm.NewVar(stmutil.NewSet()),
server: s,
infoHash: infoHashInt160,
announcePort: port,
announcePortImplied: impliedPort,
nodesPendingContact: stm.NewVar(nodesByDistance(infoHashInt160)),
pending: stm.NewVar(0),
numContacted: stm.NewVar(0),
pendingAnnouncePeers: stm.NewVar(immutable.NewList()),
}
var ctx context.Context
ctx, a.cancel = context.WithCancel(context.Background())
......@@ -111,6 +119,7 @@ func (a *Announce) closer() {
}
tx.Assert(tx.Get(a.pending).(int) == 0)
tx.Assert(tx.Get(a.nodesPendingContact).(stmutil.Lenner).Len() == 0)
tx.Assert(tx.Get(a.pendingAnnouncePeers).(stmutil.Lenner).Len() == 0)
})
}
......@@ -161,19 +170,49 @@ func (a *Announce) maybeAnnouncePeer(to Addr, token *string, peerId *krpc.ID) {
if !a.server.config.NoSecurity && (peerId == nil || !NodeIdSecure(*peerId, to.IP())) {
return
}
a.server.mu.Lock()
defer a.server.mu.Unlock()
a.server.announcePeer(to, a.infoHash, a.announcePort, *token, a.announcePortImplied, nil)
stm.Atomically(func(tx *stm.Tx) {
tx.Set(a.pendingAnnouncePeers, tx.Get(a.pendingAnnouncePeers).(stmutil.List).Append(pendingAnnouncePeer{
Addr: to,
token: *token,
}))
})
//a.server.announcePeer(to, a.infoHash, a.announcePort, *token, a.announcePortImplied, nil)
}
func (a *Announce) getPeers(addr Addr, cteh *conntrack.EntryHandle) {
// log.Printf("sending get_peers to %v", node)
m, writes, err := a.server.getPeers(context.TODO(), addr, a.infoHash)
func (a *Announce) announcePeer(peer pendingAnnouncePeer, cteh *conntrack.EntryHandle) {
_, writes, _ := a.server.announcePeer(peer.Addr, a.infoHash, a.announcePort, peer.token, a.announcePortImplied)
finalizeCteh(cteh, writes)
}
func (a *Announce) beginAnnouncePeer(tx *stm.Tx) {
l := tx.Get(a.pendingAnnouncePeers).(stmutil.List)
tx.Assert(l.Len() != 0)
x := l.Get(0).(pendingAnnouncePeer)
tx.Assert(a.server.sendLimit.AllowStm(tx))
cteh := a.server.config.ConnectionTracking.Allow(tx, a.server.connTrackEntryForAddr(x.Addr), "dht announce announce_peer", -1)
tx.Assert(cteh != nil)
tx.Set(a.pending, tx.Get(a.pending).(int)+1)
tx.Set(a.pendingAnnouncePeers, l.Slice(1, l.Len()))
tx.Return(txResT{run: func() {
a.announcePeer(x, cteh)
stm.Atomically(func(tx *stm.Tx) {
tx.Set(a.pending, tx.Get(a.pending).(int)-1)
})
}})
}
func finalizeCteh(cteh *conntrack.EntryHandle, writes int) {
if writes == 0 {
cteh.Forget()
} else {
cteh.Done()
}
}
func (a *Announce) getPeers(addr Addr, cteh *conntrack.EntryHandle) {
// log.Printf("sending get_peers to %v", node)
m, writes, err := a.server.getPeers(context.TODO(), addr, a.infoHash)
finalizeCteh(cteh, writes)
a.server.logger().Printf("Announce.server.getPeers result: m.Y=%v, writes=%v, err=%v", m.Y, writes, err)
// log.Printf("get_peers response error from %v: %v", node, err)
// Register suggested nodes closer to the target info-hash.
......@@ -221,39 +260,45 @@ func (a *Announce) pendContact(node addrMaybeId, tx *stm.Tx) {
tx.Set(a.nodesPendingContact, tx.Get(a.nodesPendingContact).(stmutil.Settish).Add(node))
}
type txResT struct {
done bool
run func()
//contact bool
//addr Addr
//cteh *conntrack.EntryHandle
}
func (a *Announce) nodeContactor() {
for {
type txResT struct {
done bool
contact bool
addr Addr
cteh *conntrack.EntryHandle
}
txRes := stm.Atomically(func(tx *stm.Tx) {
if tx.Get(a.doneVar).(bool) {
txRes := stm.Atomically(stm.Select(
func(tx *stm.Tx) {
tx.Assert(tx.Get(a.doneVar).(bool))
tx.Return(txResT{done: true})
}
npc := tx.Get(a.nodesPendingContact).(stmutil.Settish)
first, ok := iter.First(npc.Iter)
tx.Assert(ok)
addr := first.(addrMaybeId).Addr
tx.Set(a.nodesPendingContact, npc.Delete(first))
if !a.shouldContact(addr, tx) {
tx.Return(txResT{})
}
cteh := a.server.config.ConnectionTracking.Allow(tx, a.server.connTrackEntryForAddr(NewAddr(addr.UDP())), "announce get_peers", -1)
tx.Assert(cteh != nil)
tx.Assert(a.server.sendLimit.AllowStm(tx))
tx.Set(a.numContacted, tx.Get(a.numContacted).(int)+1)
tx.Set(a.pending, tx.Get(a.pending).(int)+1)
tx.Set(a.triedAddrs, tx.Get(a.triedAddrs).(stmutil.Settish).Add(addr.String()))
tx.Return(txResT{addr: NewAddr(addr.UDP()), cteh: cteh, contact: true})
}).(txResT)
},
a.beginGetPeers,
a.beginAnnouncePeer,
)).(txResT)
if txRes.done {
break
}
if txRes.contact {
go a.getPeers(txRes.addr, txRes.cteh)
}
go txRes.run()
}
}
func (a *Announce) beginGetPeers(tx *stm.Tx) {
npc := tx.Get(a.nodesPendingContact).(stmutil.Settish)
first, ok := iter.First(npc.Iter)
tx.Assert(ok)
addr := first.(addrMaybeId).Addr
tx.Set(a.nodesPendingContact, npc.Delete(first))
if !a.shouldContact(addr, tx) {
tx.Return(txResT{})
}
cteh := a.server.config.ConnectionTracking.Allow(tx, a.server.connTrackEntryForAddr(NewAddr(addr.UDP())), "announce get_peers", -1)
tx.Assert(cteh != nil)
tx.Assert(a.server.sendLimit.AllowStm(tx))
tx.Set(a.numContacted, tx.Get(a.numContacted).(int)+1)
tx.Set(a.pending, tx.Get(a.pending).(int)+1)
tx.Set(a.triedAddrs, tx.Get(a.triedAddrs).(stmutil.Settish).Add(addr.String()))
tx.Return(txResT{run: func() { a.getPeers(NewAddr(addr.UDP()), cteh) }})
}
......@@ -650,7 +650,7 @@ func (s *Server) query(addr Addr, q string, a *krpc.MsgArgs, callback func(krpc.
callback = func(krpc.Msg, error) {}
}
go func() {
cteh := s.config.ConnectionTracking.Wait(context.TODO(), s.connTrackEntryForAddr(addr), "send dht query", -1)
cteh := s.config.ConnectionTracking.Wait(context.TODO(), s.connTrackEntryForAddr(addr), fmt.Sprintf("send dht query %q", q), -1)
s.sendLimit.Wait(context.TODO())
m, writes, err := s.queryContext(context.Background(), addr, q, a)
if writes > 0 {
......@@ -765,27 +765,31 @@ func (s *Server) ping(node *net.UDPAddr, callback func(krpc.Msg, error)) error {
return s.query(NewAddr(node), "ping", nil, callback)
}
func (s *Server) announcePeer(node Addr, infoHash int160, port int, token string, impliedPort bool, callback func(krpc.Msg, error)) error {
func (s *Server) announcePeer(node Addr, infoHash int160, port int, token string, impliedPort bool) (m krpc.Msg, writes int, err error) {
if port == 0 && !impliedPort {
return errors.New("nothing to announce")
err = errors.New("no port specified")
return
}
return s.query(node, "announce_peer", &krpc.MsgArgs{
ImpliedPort: impliedPort,
InfoHash: infoHash.AsByteArray(),
Port: &port,
Token: token,
}, func(m krpc.Msg, err error) {
if callback != nil {
go callback(m, err)
}
if err := m.Error(); err != nil {
announceErrors.Add(1)
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.stats.SuccessfulOutboundAnnouncePeerQueries++
})
m, writes, err = s.queryContext(
context.TODO(), node, "announce_peer",
&krpc.MsgArgs{
ImpliedPort: impliedPort,
InfoHash: infoHash.AsByteArray(),
Port: &port,
Token: token,
},
)
if err != nil {
return
}
if err = m.Error(); err != nil {
announceErrors.Add(1)
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.stats.SuccessfulOutboundAnnouncePeerQueries++
return
}
// Add response nodes to node table.
......@@ -894,7 +898,8 @@ func (s *Server) Close() {
func (s *Server) getPeers(ctx context.Context, addr Addr, infoHash int160) (krpc.Msg, int, error) {
m, writes, err := s.queryContext(ctx, addr, "get_peers", &krpc.MsgArgs{
InfoHash: infoHash.AsByteArray(),
Want: []krpc.Want{krpc.WantNodes, krpc.WantNodes6},
// TODO: Maybe IPv4-only Servers won't want IPv6 nodes?
Want: []krpc.Want{krpc.WantNodes, krpc.WantNodes6},
})
s.mu.Lock()
defer s.mu.Unlock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment