Commit e40f4bee authored by Matt Joiner's avatar Matt Joiner
Browse files

Extract Announce.beginQuery and use stm explicit returns

parent 45341231
......@@ -103,14 +103,14 @@ func (s *Server) Announce(infoHash [20]byte, port int, impliedPort bool) (*Annou
func (a *Announce) closer() {
defer a.cancel()
stm.Atomically(func(tx *stm.Tx) {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
if tx.Get(a.doneVar).(bool) {
return
}
tx.Assert(tx.Get(a.pending).(int) == 0)
a.traversal.finished(tx)
tx.Assert(tx.Get(a.pendingAnnouncePeers).(stmutil.Lenner).Len() == 0)
})
}))
}
func validNodeAddr(addr net.Addr) bool {
......@@ -136,12 +136,6 @@ func (a *Announce) shouldContact(addr krpc.NodeAddr, tx *stm.Tx) bool {
return true
}
func (a *Announce) completeContact() {
stm.Atomically(func(tx *stm.Tx) {
tx.Set(a.pending, tx.Get(a.pending).(int)-1)
})
}
func (a *Announce) responseNode(node krpc.NodeInfo) {
i := int160FromByteArray(node.ID)
stm.Atomically(a.pendContact(addrMaybeId{node.Addr, &i}))
......@@ -155,50 +149,42 @@ func (a *Announce) maybeAnnouncePeer(to Addr, token *string, peerId *krpc.ID) {
if !a.server.config.NoSecurity && (peerId == nil || !NodeIdSecure(*peerId, to.IP())) {
return
}
stm.Atomically(func(tx *stm.Tx) {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
tx.Set(a.pendingAnnouncePeers, tx.Get(a.pendingAnnouncePeers).(stmutil.List).Append(pendingAnnouncePeer{
Addr: to,
token: *token,
}))
})
}))
//a.server.announcePeer(to, a.infoHash, a.announcePort, *token, a.announcePortImplied, nil)
}
func (a *Announce) announcePeer(peer pendingAnnouncePeer, cteh *conntrack.EntryHandle) {
func (a *Announce) announcePeer(peer pendingAnnouncePeer) numWrites {
_, writes, _ := a.server.announcePeer(peer.Addr, a.infoHash, a.announcePort, peer.token, a.announcePortImplied)
finalizeCteh(cteh, writes)
return writes
}
func (a *Announce) beginAnnouncePeer(tx *stm.Tx) {
func (a *Announce) beginAnnouncePeer(tx *stm.Tx) interface{} {
l := tx.Get(a.pendingAnnouncePeers).(stmutil.List)
tx.Assert(l.Len() != 0)
x := l.Get(0).(pendingAnnouncePeer)
tx.Assert(a.server.sendLimit.AllowStm(tx))
cteh := a.server.config.ConnectionTracking.Allow(tx, a.server.connTrackEntryForAddr(x.Addr), "dht announce announce_peer", -1)
tx.Assert(cteh != nil)
tx.Set(a.pending, tx.Get(a.pending).(int)+1)
tx.Set(a.pendingAnnouncePeers, l.Slice(1, l.Len()))
tx.Return(txResT{run: func() {
a.announcePeer(x, cteh)
stm.Atomically(func(tx *stm.Tx) {
tx.Set(a.pending, tx.Get(a.pending).(int)-1)
})
}})
return a.beginQuery(x.Addr, "dht announce announce_peer", func() numWrites {
return a.announcePeer(x)
})(tx).(func())
}
func finalizeCteh(cteh *conntrack.EntryHandle, writes numWrites) {
if writes == 0 {
cteh.Forget()
panic("how to reverse rate limit?")
// TODO: panic("how to reverse rate limit?")
} else {
cteh.Done()
}
}
func (a *Announce) getPeers(addr Addr, cteh *conntrack.EntryHandle) {
func (a *Announce) getPeers(addr Addr) numWrites {
// log.Printf("sending get_peers to %v", node)
m, writes, err := a.server.getPeers(context.TODO(), addr, a.infoHash)
finalizeCteh(cteh, writes)
a.server.logger().Printf("Announce.server.getPeers result: m.Y=%v, numWrites=%v, err=%v", m.Y, writes, err)
// log.Printf("get_peers response error from %v: %v", node, err)
// Register suggested nodes closer to the target info-hash.
......@@ -218,7 +204,7 @@ func (a *Announce) getPeers(addr Addr, cteh *conntrack.EntryHandle) {
}
a.maybeAnnouncePeer(addr, m.R.Token, m.SenderID())
}
a.completeContact()
return writes
}
// Corresponds to the "values" key in a get_peers KRPC response. A list of
......@@ -238,14 +224,14 @@ func (a *Announce) close() {
a.cancel()
}
func (a *Announce) pendContact(node addrMaybeId) func(tx *stm.Tx) {
return func(tx *stm.Tx) {
func (a *Announce) pendContact(node addrMaybeId) stm.Operation {
return stm.VoidOperation(func(tx *stm.Tx) {
if !a.shouldContact(node.Addr, tx) {
// log.Printf("shouldn't contact (pend): %v", node)
return
}
a.traversal.pendContact(node)(tx)
}
})
}
type txResT struct {
......@@ -255,14 +241,15 @@ type txResT struct {
func (a *Announce) nodeContactor() {
for {
txRes := stm.Atomically(stm.Select(
func(tx *stm.Tx) {
tx.Assert(tx.Get(a.doneVar).(bool))
tx.Return(txResT{done: true})
},
a.beginGetPeers,
a.beginAnnouncePeer,
)).(txResT)
txRes := stm.Atomically(func(tx *stm.Tx) interface{} {
if tx.Get(a.doneVar).(bool) {
return txResT{done: true}
}
return txResT{run: stm.Select(
a.beginGetPeers,
a.beginAnnouncePeer,
)(tx).(func())}
}).(txResT)
if txRes.done {
break
}
......@@ -270,12 +257,21 @@ func (a *Announce) nodeContactor() {
}
}
func (a *Announce) beginGetPeers(tx *stm.Tx) {
func (a *Announce) beginGetPeers(tx *stm.Tx) interface{} {
addr := a.traversal.nextAddr(tx)
cteh := a.server.config.ConnectionTracking.Allow(tx, a.server.connTrackEntryForAddr(NewAddr(addr.UDP())), "announce get_peers", -1)
tx.Assert(cteh != nil)
tx.Assert(a.server.sendLimit.AllowStm(tx))
dhtAddr := NewAddr(addr.UDP())
tx.Set(a.numContacted, tx.Get(a.numContacted).(int)+1)
tx.Set(a.pending, tx.Get(a.pending).(int)+1)
tx.Return(txResT{run: func() { a.getPeers(NewAddr(addr.UDP()), cteh) }})
return a.beginQuery(dhtAddr, "dht announce get_peers", func() numWrites {
return a.getPeers(dhtAddr)
})(tx)
}
func (a *Announce) beginQuery(addr Addr, reason string, f func() numWrites) stm.Operation {
return func(tx *stm.Tx) interface{} {
tx.Set(a.pending, tx.Get(a.pending).(int)+1)
return a.server.beginQuery(addr, reason, func() numWrites {
defer stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { tx.Set(a.pending, tx.Get(a.pending).(int)-1) }))
return f()
})(tx)
}
}
......@@ -647,15 +647,15 @@ func (s *Server) connTrackEntryForAddr(a Addr) conntrack.Entry {
type numWrites int
func (s *Server) beginQuery(addr Addr, reason string, f func() numWrites) func(tx *stm.Tx) {
return func(tx *stm.Tx) {
func (s *Server) beginQuery(addr Addr, reason string, f func() numWrites) stm.Operation {
return func(tx *stm.Tx) interface{} {
tx.Assert(s.sendLimit.AllowStm(tx))
cteh := s.config.ConnectionTracking.Allow(tx, s.connTrackEntryForAddr(addr), reason, -1)
tx.Assert(cteh != nil)
tx.Return(func() {
return func() {
writes := f()
finalizeCteh(cteh, writes)
})
}
}
}
......
......@@ -40,8 +40,8 @@ func (t traversal) finished(tx *stm.Tx) {
tx.Assert(tx.Get(t.nodesPendingContact).(stmutil.Lenner).Len() == 0)
}
func (t traversal) pendContact(node addrMaybeId) func(*stm.Tx) {
return func(tx *stm.Tx) {
func (t traversal) pendContact(node addrMaybeId) stm.Operation {
return stm.VoidOperation(func(tx *stm.Tx) {
nodeAddrString := node.Addr.String()
if tx.Get(t.triedAddrs).(stmutil.Settish).Contains(nodeAddrString) {
return
......@@ -64,7 +64,7 @@ func (t traversal) pendContact(node addrMaybeId) func(*stm.Tx) {
tx.Set(t.addrBestIds, addrBestIds.Set(nodeAddrString, node.Id))
nodesPendingContact = nodesPendingContact.Add(node)
tx.Set(t.nodesPendingContact, nodesPendingContact)
}
})
}
func (a traversal) nextAddr(tx *stm.Tx) krpc.NodeAddr {
......
......@@ -12,22 +12,22 @@ import (
func TestTraversal(t *testing.T) {
var target int160
traversal := newTraversal(target)
assert.True(t, stm.WouldBlock(func(tx *stm.Tx) { traversal.nextAddr(tx) }))
assert.False(t, stm.WouldBlock(traversal.finished))
stm.Atomically(stm.Compose(func() (ret []func(tx *stm.Tx)) {
assert.True(t, stm.WouldBlock(stm.VoidOperation(func(tx *stm.Tx) { traversal.nextAddr(tx) })))
assert.False(t, stm.WouldBlock(stm.VoidOperation(traversal.finished)))
stm.Atomically(stm.Compose(func() (ret []stm.Operation) {
for _, v := range sampleAddrMaybeIds[2:6] {
ret = append(ret, traversal.pendContact(v))
}
return
}()...))
assert.False(t, stm.WouldBlock(func(tx *stm.Tx) { traversal.nextAddr(tx) }))
assert.True(t, stm.WouldBlock(traversal.finished))
pop := func(tx *stm.Tx) { tx.Return(traversal.nextAddr(tx)) }
assert.False(t, stm.WouldBlock(stm.VoidOperation(func(tx *stm.Tx) { traversal.nextAddr(tx) })))
assert.True(t, stm.WouldBlock(stm.VoidOperation(traversal.finished)))
pop := func(tx *stm.Tx) interface{} { return traversal.nextAddr(tx) }
var addrs []krpc.NodeAddr
for !stm.WouldBlock(pop) {
addrs = append(addrs, stm.Atomically(pop).(krpc.NodeAddr))
}
assert.False(t, stm.WouldBlock(traversal.finished))
assert.False(t, stm.WouldBlock(stm.VoidOperation(traversal.finished)))
t.Log(addrs)
assert.EqualValues(t, []krpc.NodeAddr{{Port: 1}, {}}, addrs)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment