server.go 22.8 KB
Newer Older
1 2 3
package dht

import (
Matt Joiner's avatar
Matt Joiner committed
4
	"context"
5
	"crypto/rand"
6 7 8 9
	"encoding/binary"
	"fmt"
	"io"
	"net"
Matt Joiner's avatar
Matt Joiner committed
10
	"text/tabwriter"
11 12
	"time"

Matt Joiner's avatar
Matt Joiner committed
13
	"github.com/anacrolix/log"
14
	"github.com/anacrolix/missinggo"
Matt Joiner's avatar
Matt Joiner committed
15
	"github.com/anacrolix/missinggo/v2/conntrack"
16
	"github.com/anacrolix/sync"
17 18 19
	"github.com/anacrolix/torrent/bencode"
	"github.com/anacrolix/torrent/iplist"
	"github.com/anacrolix/torrent/logonce"
20
	"github.com/anacrolix/torrent/metainfo"
21
	"github.com/pkg/errors"
Matt Joiner's avatar
Matt Joiner committed
22

Matt Joiner's avatar
Matt Joiner committed
23
	"github.com/anacrolix/stm"
Matt Joiner's avatar
Matt Joiner committed
24

Matt Joiner's avatar
Matt Joiner committed
25
	"github.com/anacrolix/dht/v2/krpc"
26 27
)

Matt Joiner's avatar
Matt Joiner committed
28 29 30 31 32 33 34
// A Server defines parameters for a DHT node server that is able to send
// queries, and respond to the ones from the network. Each node has a globally
// unique identifier known as the "node ID." Node IDs are chosen at random
// from the same 160-bit space as BitTorrent infohashes and define the
// behaviour of the node. Zero valued Server does not have a valid ID and thus
// is unable to function properly. Use `NewServer(nil)` to initialize a
// default node.
35
type Server struct {
Matt Joiner's avatar
Matt Joiner committed
36 37 38
	id          int160
	socket      net.PacketConn
	resendDelay func() time.Duration
39

40
	mu           sync.RWMutex
41 42 43 44 45 46 47 48
	transactions map[transactionKey]*Transaction
	nextT        uint64 // unique "t" field for outbound queries
	table        table
	closed       missinggo.Event
	ipBlockList  iplist.Ranger
	tokenServer  tokenServer // Manages tokens we issue to our queriers.
	config       ServerConfig
	stats        ServerStats
49 50 51
	sendLimit    interface {
		Wait(ctx context.Context) error
		Allow() bool
Matt Joiner's avatar
Matt Joiner committed
52
		AllowStm(tx *stm.Tx) bool
53
	}
54 55
}

Matt Joiner's avatar
Matt Joiner committed
56 57
func (s *Server) numGoodNodes() (num int) {
	s.table.forNodes(func(n *node) bool {
58
		if n.IsGood() {
Matt Joiner's avatar
Matt Joiner committed
59 60 61 62 63 64 65
			num++
		}
		return true
	})
	return
}

Matt Joiner's avatar
Matt Joiner committed
66 67 68 69
func prettySince(t time.Time) string {
	if t.IsZero() {
		return "never"
	}
Matt Joiner's avatar
Matt Joiner committed
70 71 72
	d := time.Since(t)
	d /= time.Second
	d *= time.Second
73
	return fmt.Sprintf("%s ago", d)
Matt Joiner's avatar
Matt Joiner committed
74 75 76 77 78 79 80 81
}

func (s *Server) WriteStatus(w io.Writer) {
	fmt.Fprintf(w, "Listening on %s\n", s.Addr())
	s.mu.Lock()
	defer s.mu.Unlock()
	fmt.Fprintf(w, "Nodes in table: %d good, %d total\n", s.numGoodNodes(), s.numNodes())
	fmt.Fprintf(w, "Ongoing transactions: %d\n", len(s.transactions))
Matt Joiner's avatar
Matt Joiner committed
82
	fmt.Fprintf(w, "Server node ID: %x\n", s.id.Bytes())
Matt Joiner's avatar
Matt Joiner committed
83 84
	fmt.Fprintln(w)
	tw := tabwriter.NewWriter(w, 0, 0, 1, ' ', 0)
85
	fmt.Fprintf(tw, "b#\tnode id\taddr\tanntok\tlast query\tlast response\tcf\n")
Matt Joiner's avatar
Matt Joiner committed
86 87
	for i, b := range s.table.buckets {
		b.EachNode(func(n *node) bool {
Matt Joiner's avatar
Matt Joiner committed
88
			fmt.Fprintf(tw, "%d\t%x\t%s\t%v\t%s\t%s\t%d\n",
Matt Joiner's avatar
Matt Joiner committed
89
				i,
Matt Joiner's avatar
Matt Joiner committed
90
				n.id.Bytes(),
Matt Joiner's avatar
Matt Joiner committed
91
				n.addr,
92 93 94 95 96 97
				func() int {
					if n.announceToken == nil {
						return -1
					}
					return len(*n.announceToken)
				}(),
Matt Joiner's avatar
Matt Joiner committed
98 99 100 101 102 103 104 105
				prettySince(n.lastGotQuery),
				prettySince(n.lastGotResponse),
				n.consecutiveFailures,
			)
			return true
		})
	}
	tw.Flush()
106
	fmt.Fprintln(w)
Matt Joiner's avatar
Matt Joiner committed
107 108
}

Matt Joiner's avatar
Matt Joiner committed
109 110 111 112 113 114 115 116
// numNodes returns the total number of nodes in the routing table.
func (s *Server) numNodes() (num int) {
	count := 0
	s.table.forNodes(func(*node) bool {
		count++
		return true
	})
	return count
}

117
// Stats returns statistics for the server.
118
func (s *Server) Stats() ServerStats {
119 120
	s.mu.Lock()
	defer s.mu.Unlock()
121
	ss := s.stats
Matt Joiner's avatar
Matt Joiner committed
122 123
	ss.GoodNodes = s.numGoodNodes()
	ss.Nodes = s.numNodes()
124
	ss.OutstandingTransactions = len(s.transactions)
125
	return ss
126 127 128 129 130 131 132 133
}

// Addr returns the local address of the server's socket. Packets arriving at
// this address are processed by the server.
func (s *Server) Addr() net.Addr {
	conn := s.socket
	return conn.LocalAddr()
}

134 135 136 137 138 139 140 141 142
// NewDefaultServerConfig returns a config whose Conn comes from
// mustListen(":0"), with node-ID security checks disabled, GlobalBootstrapAddrs
// as the starting nodes, and a fresh connection tracker.
func NewDefaultServerConfig() *ServerConfig {
	cfg := ServerConfig{
		Conn:               mustListen(":0"),
		NoSecurity:         true,
		StartingNodes:      GlobalBootstrapAddrs,
		ConnectionTracking: conntrack.NewInstance(),
	}
	return &cfg
}

143 144 145
// NewServer initializes a new DHT node server from c and starts its read loop
// in a background goroutine. A nil c is replaced by NewDefaultServerConfig();
// c.Conn must be non-nil.
//
// c is updated in place before being copied into the Server: a zero NodeId is
// replaced by RandomNodeID() (passed through SecureNodeId when NoSecurity is
// false and PublicIP is set), and c.Logger is defaulted and wrapped.
//
// An error is returned if c.Conn is nil or if the token secret cannot be read
// from crypto/rand.
func NewServer(c *ServerConfig) (s *Server, err error) {
	if c == nil {
		c = NewDefaultServerConfig()
	}
	if c.Conn == nil {
		return nil, errors.New("non-nil Conn required")
	}
	if missinggo.IsZeroValue(c.NodeId) {
		c.NodeId = RandomNodeID()
		if !c.NoSecurity && c.PublicIP != nil {
			SecureNodeId(&c.NodeId, c.PublicIP)
		}
	}
	// If Logger is empty, emulate the old behaviour: Everything is logged to the default location,
	// and there are no debug messages.
	if c.Logger.LoggerImpl == nil {
		c.Logger = log.Default.WithFilter(func(m log.Msg) bool {
			return !m.HasValue(log.Debug)
		})
	}
	// Add log.Debug by default: messages that carry no level are tagged as debug.
	c.Logger = c.Logger.WithMap(func(m log.Msg) log.Msg {
		var l log.Level
		if m.GetValueByType(&l) {
			return m
		}
		return m.WithValues(log.Debug)
	})

	s = &Server{
		config:      *c,
		ipBlockList: c.IPBlocklist,
		tokenServer: tokenServer{
			maxIntervalDelta: 2,
			interval:         5 * time.Minute,
			secret:           make([]byte, 20),
		},
		transactions: make(map[transactionKey]*Transaction),
		table: table{
			k: 8,
		},
		sendLimit: defaultSendLimiter,
	}
	if s.config.ConnectionTracking == nil {
		s.config.ConnectionTracking = conntrack.NewInstance()
	}
	// The secret starts as 20 zero bytes. If the system RNG fails, refuse to
	// start rather than run with that predictable secret (the original code
	// ignored this error).
	if _, err = rand.Read(s.tokenServer.secret); err != nil {
		return nil, errors.Wrap(err, "generating token secret")
	}
	s.socket = c.Conn
	s.id = int160FromByteArray(c.NodeId)
	s.table.rootID = s.id
	s.resendDelay = s.config.QueryResendDelay
	if s.resendDelay == nil {
		s.resendDelay = defaultQueryResendDelay
	}
	go s.serveUntilClosed()
	return s, nil
}

202 203 204 205 206 207 208 209 210 211 212 213
// serveUntilClosed runs the read loop until it returns. An error after Close
// has been called is expected (the socket was closed) and ignored; any other
// error panics.
func (s *Server) serveUntilClosed() {
	serveErr := s.serve()
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.closed.IsSet() && serveErr != nil {
		panic(serveErr)
	}
}

Matt Joiner's avatar
Matt Joiner committed
214
// String describes the server by the local address of its socket.
func (s *Server) String() string {
	local := s.socket.LocalAddr()
	return fmt.Sprintf("dht server on %s", local)
}

// SetIPBlockList installs list as the block list: packets to or from any
// address within one of its ranges are dropped. A nil list disables blocking.
func (s *Server) SetIPBlockList(list iplist.Ranger) {
	s.mu.Lock()
	s.ipBlockList = list
	s.mu.Unlock()
}

// IPBlocklist returns the block list currently in effect, or nil if none.
//
// NOTE(review): unlike SetIPBlockList, this reads the field without taking s.mu.
func (s *Server) IPBlocklist() iplist.Ranger {
	current := s.ipBlockList
	return current
}

Matt Joiner's avatar
Matt Joiner committed
230
// processPacket decodes one datagram b received from addr and dispatches it.
// Packets that are not dicts, or that fail to decode, are counted and dropped;
// trailing bytes after a valid dict are tolerated. Queries are passed to
// handleQuery. Any other message is matched to an outstanding transaction by
// (addr, "t"); a match is handed the message and the transaction is then
// removed. Nothing happens once the server is closed.
func (s *Server) processPacket(b []byte, addr Addr) {
	// KRPC messages are bencoded dicts.
	if len(b) < 2 || b[0] != 'd' {
		readNotKRPCDict.Add(1)
		return
	}
	var d krpc.Msg
	switch err := bencode.Unmarshal(b, &d); err.(type) {
	case nil:
	case bencode.ErrUnusedTrailingBytes:
		// The dict itself decoded; there was merely extra data after it.
		expvars.Add("processed packets with trailing bytes", 1)
	default:
		// Undecodable packets (e.g. truncated, NUL-padded, or not bencode at
		// all) are only counted, never logged.
		readUnmarshalError.Add(1)
		return
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed.IsSet() {
		return
	}
	var sender *node
	if sid := d.SenderID(); sid != nil {
		sender, _ = s.getNode(addr, int160FromByteArray(*sid), !d.ReadOnly)
		if sender != nil && d.ReadOnly {
			sender.readOnly = true
		}
	}
	if d.Y == "q" {
		expvars.Add("received queries", 1)
		s.logger().Printf("received query %q from %v", d.Q, addr)
		s.handleQuery(addr, d)
		return
	}

	key := transactionKey{RemoteAddr: addr.String(), T: d.T}
	t, ok := s.transactions[key]
	if !ok {
		s.logger().Printf("received response for untracked transaction %q from %v", d.T, addr)
		return
	}
	s.logger().Printf("received response for transaction %q from %v", d.T, addr)
	go t.handleResponse(d)
	if sender != nil {
		sender.lastGotResponse = time.Now()
		sender.consecutiveFailures = 0
	}
	// Removing the transaction guarantees at most one response is delivered.
	s.deleteTransaction(key)
}

// serve reads datagrams from the socket and hands them to processPacket until
// a read fails, returning that error. Packets that fill the whole buffer, come
// from port 0, or come from a blocked IP are counted and skipped.
func (s *Server) serve() error {
	isBlocked := func(from net.Addr) bool {
		s.mu.Lock()
		defer s.mu.Unlock()
		return s.ipBlocked(missinggo.AddrIP(from))
	}
	var buf [0x10000]byte
	for {
		n, from, err := s.socket.ReadFrom(buf[:])
		if err != nil {
			return err
		}
		expvars.Add("packets read", 1)
		switch {
		case n == len(buf):
			// A read that fills the buffer may have been truncated.
			logonce.Stderr.Printf("received dht packet exceeds buffer size")
		case missinggo.AddrPort(from) == 0:
			readZeroPort.Add(1)
		case isBlocked(from):
			readBlocked.Add(1)
		default:
			s.processPacket(buf[:n], NewAddr(from))
		}
	}
}

// ipBlocked reports whether ip lies in a range of the block list. With no
// block list installed, nothing is blocked.
func (s *Server) ipBlocked(ip net.IP) (blocked bool) {
	if list := s.ipBlockList; list != nil {
		_, blocked = list.Lookup(ip)
	}
	return blocked
}

// Adds directly to the node table.
337
func (s *Server) AddNode(ni krpc.NodeInfo) error {
338 339
	id := int160FromByteArray(ni.ID)
	if id.IsZero() {
340
		return s.Ping(ni.Addr.UDP(), nil)
341
	}
342
	_, err := s.getNode(NewAddr(ni.Addr.UDP()), int160FromByteArray(ni.ID), true)
343
	return err
344 345
}

346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372
// wantsContain reports whether w is one of ws.
func wantsContain(ws []krpc.Want, w krpc.Want) bool {
	for i := range ws {
		if ws[i] == w {
			return true
		}
	}
	return false
}

// shouldReturnNodes reports whether a reply should include IPv4 "nodes". An
// explicit "want" list decides; without one, IPv4 sources receive them.
func shouldReturnNodes(queryWants []krpc.Want, querySource net.IP) bool {
	if len(queryWants) == 0 {
		return querySource.To4() != nil
	}
	return wantsContain(queryWants, krpc.WantNodes)
}

// shouldReturnNodes6 reports whether a reply should include "nodes6". An
// explicit "want" list decides; without one, non-IPv4 sources receive them.
func shouldReturnNodes6(queryWants []krpc.Want, querySource net.IP) bool {
	if len(queryWants) == 0 {
		return querySource.To4() == nil
	}
	return wantsContain(queryWants, krpc.WantNodes6)
}

// makeReturnNodes asks for the 8 good nodes closest to target whose addresses
// pass filter, for inclusion in a reply.
func (s *Server) makeReturnNodes(target int160, filter func(krpc.NodeAddr) bool) []krpc.NodeInfo {
	const returnNodesCount = 8
	return s.closestGoodNodeInfos(returnNodesCount, target, filter)
}

373 374 375 376 377 378 379 380 381
var krpcErrMissingArguments = krpc.Error{
	Code: krpc.ErrorCodeProtocolError,
	Msg:  "missing arguments dict",
}

func (s *Server) setReturnNodes(r *krpc.Return, queryMsg krpc.Msg, querySource Addr) *krpc.Error {
	if queryMsg.A == nil {
		return &krpcErrMissingArguments
	}
382
	target := int160FromByteArray(queryMsg.A.InfoHash)
Matt Joiner's avatar
Matt Joiner committed
383
	if shouldReturnNodes(queryMsg.A.Want, querySource.IP()) {
384 385
		r.Nodes = s.makeReturnNodes(target, func(na krpc.NodeAddr) bool { return na.IP.To4() != nil })
	}
Matt Joiner's avatar
Matt Joiner committed
386
	if shouldReturnNodes6(queryMsg.A.Want, querySource.IP()) {
387 388
		r.Nodes6 = s.makeReturnNodes(target, func(krpc.NodeAddr) bool { return true })
	}
389
	return nil
390 391
}

392 393
// handleQuery processes a query m received from source. It records
// statistics, refreshes the sender's lastGotQuery, and lets config.OnQuery
// veto further handling. Passive servers never reply. The supported methods
// are ping, get_peers, find_node and announce_peer; any other method is
// answered with krpc.ErrorMethodUnknown. Methods that need arguments answer
// a query with no "a" dict with a protocol error.
//
// TODO: Probably should write error messages back to senders if something is
// wrong.
func (s *Server) handleQuery(source Addr, m krpc.Msg) {
	// Statistics are recorded in a separate goroutine.
	go func() {
		expvars.Add(fmt.Sprintf("received query %q", m.Q), 1)
		if a := m.A; a != nil {
			if a.NoSeed != 0 {
				expvars.Add("received argument noseed", 1)
			}
			if a.Scrape != 0 {
				expvars.Add("received argument scrape", 1)
			}
		}
	}()
	if m.SenderID() != nil {
		if n, _ := s.getNode(source, int160FromByteArray(*m.SenderID()), !m.ReadOnly); n != nil {
			n.lastGotQuery = time.Now()
		}
	}
	if s.config.OnQuery != nil {
		propagate := s.config.OnQuery(&m, source.Raw())
		if !propagate {
			return
		}
	}
	// Don't respond.
	if s.config.Passive {
		return
	}
	// TODO: Should we disallow replying to ourself?
	args := m.A
	switch m.Q {
	case "ping":
		s.reply(source, m.T, krpc.Return{})
	case "get_peers":
		var r krpc.Return
		// TODO: Return values.
		if err := s.setReturnNodes(&r, m, source); err != nil {
			s.sendError(source, m.T, *err)
			break
		}
		r.Token = func() *string {
			t := s.createToken(source)
			return &t
		}()
		s.reply(source, m.T, r)
	case "find_node":
		var r krpc.Return
		if err := s.setReturnNodes(&r, m, source); err != nil {
			s.sendError(source, m.T, *err)
			break
		}
		s.reply(source, m.T, r)
	case "announce_peer":
		readAnnouncePeer.Add(1)
		// args comes straight off the wire. Without this check, an
		// announce_peer with no "a" dict nil-dereferences below and crashes
		// the read loop.
		if args == nil {
			s.sendError(source, m.T, krpcErrMissingArguments)
			return
		}
		if !s.validToken(args.Token, source) {
			expvars.Add("received announce_peer with invalid token", 1)
			return
		}
		expvars.Add("received announce_peer with valid token", 1)
		if h := s.config.OnAnnouncePeer; h != nil {
			var port int
			portOk := false
			if args.Port != nil {
				port = *args.Port
				portOk = true
			}
			// implied_port: use the packet's source port instead of "port".
			if args.ImpliedPort {
				port = source.Port()
				portOk = true
			}
			go h(metainfo.Hash(args.InfoHash), source.IP(), port, portOk)
		}
		s.reply(source, m.T, krpc.Return{})
	default:
		s.sendError(source, m.T, krpc.ErrorMethodUnknown)
	}
}

471
func (s *Server) sendError(addr Addr, t string, e krpc.Error) {
472 473 474 475 476 477 478 479 480
	m := krpc.Msg{
		T: t,
		Y: "e",
		E: &e,
	}
	b, err := bencode.Marshal(m)
	if err != nil {
		panic(err)
	}
Matt Joiner's avatar
Matt Joiner committed
481
	s.logger().Printf("sending error to %q: %v", addr, e)
Matt Joiner's avatar
Matt Joiner committed
482
	_, err = s.writeToNode(context.Background(), b, addr, false, true)
483
	if err != nil {
Matt Joiner's avatar
Matt Joiner committed
484
		s.logger().Printf("error replying to %q: %v", addr, err)
485 486 487
	}
}

488
func (s *Server) reply(addr Addr, t string, r krpc.Return) {
Matt Joiner's avatar
Matt Joiner committed
489
	r.ID = s.id.AsByteArray()
490
	m := krpc.Msg{
491 492 493 494
		T:  t,
		Y:  "r",
		R:  &r,
		IP: addr.KRPC(),
495 496 497 498 499
	}
	b, err := bencode.Marshal(m)
	if err != nil {
		panic(err)
	}
Matt Joiner's avatar
Matt Joiner committed
500
	log.Fmsg("replying to %q", addr).Log(s.logger())
Matt Joiner's avatar
Matt Joiner committed
501
	wrote, err := s.writeToNode(context.Background(), b, addr, false, true)
502
	if err != nil {
Matt Joiner's avatar
Matt Joiner committed
503
		s.config.Logger.Printf("error replying to %s: %s", addr, err)
504
	}
Matt Joiner's avatar
Matt Joiner committed
505 506 507
	if wrote {
		expvars.Add("replied to peer", 1)
	}
508 509
}

510
// Returns the node if it's in the routing table, adding it if appropriate.
511
func (s *Server) getNode(addr Addr, id int160, tryAdd bool) (*node, error) {
512
	if n := s.table.getNode(addr, id); n != nil {
513
		return n, nil
514
	}
515
	n := &node{nodeKey: nodeKey{
Matt Joiner's avatar
Matt Joiner committed
516
		id:   id,
517
		addr: addr,
518
	}}
519 520 521
	// Check that the node would be good to begin with. (It might have a bad
	// ID or banned address, or we fucked up the initial node field
	// invariant.)
522 523 524
	if err := s.nodeErr(n); err != nil {
		return nil, err
	}
525 526 527
	if !tryAdd {
		return nil, errors.New("node not present and add flag false")
	}
528 529 530 531 532 533 534 535 536 537 538 539 540
	b := s.table.bucketForID(id)
	if b.Len() >= s.table.k {
		if b.EachNode(func(n *node) bool {
			if s.nodeIsBad(n) {
				s.table.dropNode(n)
			}
			return b.Len() >= s.table.k
		}) {
			return nil, errors.New("no room in bucket")
		}
	}
	if err := s.table.addNode(n); err != nil {
		panic(fmt.Sprintf("expected to add node: %s", err))
541
	}
542
	return n, nil
Matt Joiner's avatar
Matt Joiner committed
543 544
}

545 546 547 548 549 550 551
// nodeIsBad reports whether nodeErr finds any reason to reject n.
func (s *Server) nodeIsBad(n *node) bool {
	err := s.nodeErr(n)
	return err != nil
}

func (s *Server) nodeErr(n *node) error {
	if n.id == s.id {
		return errors.New("is self")
Matt Joiner's avatar
Matt Joiner committed
552
	}
553
	if n.id.IsZero() {
554
		return errors.New("has zero id")
555 556
	}
	if !s.config.NoSecurity && !n.IsSecure() {
557 558 559 560 561 562 563
		return errors.New("not secure")
	}
	if n.IsGood() {
		return nil
	}
	if n.consecutiveFailures >= 3 {
		return fmt.Errorf("has %d consecutive failures", n.consecutiveFailures)
564
	}
Matt Joiner's avatar
Matt Joiner committed
565
	return nil
566 567
}

Matt Joiner's avatar
Matt Joiner committed
568
func (s *Server) writeToNode(ctx context.Context, b []byte, node Addr, wait, rate bool) (wrote bool, err error) {
569
	if list := s.ipBlockList; list != nil {
Matt Joiner's avatar
Matt Joiner committed
570
		if r, ok := list.Lookup(node.IP()); ok {
Matt Joiner's avatar
Matt Joiner committed
571
			err = fmt.Errorf("write to %v blocked by %v", node, r)
572 573 574
			return
		}
	}
Matt Joiner's avatar
Matt Joiner committed
575
	//s.config.Logger.WithValues(log.Debug).Printf("writing to %s: %q", node.String(), b)
Matt Joiner's avatar
Matt Joiner committed
576 577 578 579 580 581 582 583 584 585
	if rate {
		if wait {
			err = s.sendLimit.Wait(ctx)
			if err != nil {
				return false, err
			}
		} else {
			if !s.sendLimit.Allow() {
				return false, errors.New("rate limit exceeded")
			}
Matt Joiner's avatar
Matt Joiner committed
586 587
		}
	}
Matt Joiner's avatar
Matt Joiner committed
588
	n, err := s.socket.WriteTo(b, node.Raw())
589
	writes.Add(1)
Matt Joiner's avatar
Matt Joiner committed
590 591 592 593 594
	if rate {
		expvars.Add("rated writes", 1)
	} else {
		expvars.Add("unrated writes", 1)
	}
595
	if err != nil {
596
		writeErrors.Add(1)
Matt Joiner's avatar
Matt Joiner committed
597
		err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err)
598 599
		return
	}
600
	wrote = true
601 602 603 604 605 606 607 608 609
	if n != len(b) {
		err = io.ErrShortWrite
		return
	}
	return
}

// nextTransactionID returns a fresh "t" value for an outbound query: the
// uvarint encoding of s.nextT, which is then incremented.
func (s *Server) nextTransactionID() string {
	id := s.nextT
	s.nextT++
	var buf [binary.MaxVarintLen64]byte
	used := binary.PutUvarint(buf[:], id)
	return string(buf[:used])
}

Matt Joiner's avatar
Matt Joiner committed
615 616
func (s *Server) deleteTransaction(k transactionKey) {
	delete(s.transactions, k)
617 618
}

Matt Joiner's avatar
Matt Joiner committed
619 620
func (s *Server) addTransaction(k transactionKey, t *Transaction) {
	if _, ok := s.transactions[k]; ok {
621 622
		panic("transaction not unique")
	}
Matt Joiner's avatar
Matt Joiner committed
623
	s.transactions[k] = t
624 625 626 627
}

// ID returns the 20-byte server ID. This is the ID used to communicate with the
// DHT network.
Matt Joiner's avatar
Matt Joiner committed
628 629
func (s *Server) ID() [20]byte {
	return s.id.AsByteArray()
630 631
}

632 633 634 635 636 637 638 639
// createToken issues a token for addr via the token server; handleQuery puts
// it in get_peers replies.
func (s *Server) createToken(addr Addr) string {
	token := s.tokenServer.CreateToken(addr)
	return token
}

// validToken reports whether the token server accepts token as issued to addr.
func (s *Server) validToken(token string, addr Addr) bool {
	ok := s.tokenServer.ValidToken(token, addr)
	return ok
}

Matt Joiner's avatar
Matt Joiner committed
640 641 642 643 644 645 646 647
// connTrackEntryForAddr builds the connection-tracking entry for traffic
// between our socket and a: (network, local address, remote address).
func (s *Server) connTrackEntryForAddr(a Addr) conntrack.Entry {
	local := s.socket.LocalAddr()
	return conntrack.Entry{
		local.Network(),
		local.String(),
		a.String(),
	}
}

648 649
type numWrites int

650 651
func (s *Server) beginQuery(addr Addr, reason string, f func() numWrites) stm.Operation {
	return func(tx *stm.Tx) interface{} {
652 653 654
		tx.Assert(s.sendLimit.AllowStm(tx))
		cteh := s.config.ConnectionTracking.Allow(tx, s.connTrackEntryForAddr(addr), reason, -1)
		tx.Assert(cteh != nil)
655
		return func() {
656 657
			writes := f()
			finalizeCteh(cteh, writes)
658
		}
659 660 661
	}
}

662
func (s *Server) query(addr Addr, q string, a *krpc.MsgArgs, callback func(krpc.Msg, error)) error {
Matt Joiner's avatar
Matt Joiner committed
663 664 665 666
	if callback == nil {
		callback = func(krpc.Msg, error) {}
	}
	go func() {
667 668 669 670 671 672 673 674 675
		stm.Atomically(
			s.beginQuery(addr, fmt.Sprintf("send dht query %q", q),
				func() numWrites {
					m, writes, err := s.queryContext(context.Background(), addr, q, a)
					callback(m, err)
					return writes
				},
			),
		).(func())()
Matt Joiner's avatar
Matt Joiner committed
676 677 678 679
	}()
	return nil
}

Matt Joiner's avatar
Matt Joiner committed
680
func (s *Server) makeQueryBytes(q string, a *krpc.MsgArgs, t string) []byte {
681
	if a == nil {
Matt Joiner's avatar
Matt Joiner committed
682
		a = &krpc.MsgArgs{}
683
	}
Matt Joiner's avatar
Matt Joiner committed
684 685
	a.ID = s.ID()
	m := krpc.Msg{
Matt Joiner's avatar
Matt Joiner committed
686
		T: t,
Matt Joiner's avatar
Matt Joiner committed
687 688 689
		Y: "q",
		Q: q,
		A: a,
690
	}
Matt Joiner's avatar
Matt Joiner committed
691 692
	// BEP 43. Outgoing queries from passive nodes should contain "ro":1 in the top level
	// dictionary.
693
	if s.config.Passive {
Matt Joiner's avatar
Matt Joiner committed
694
		m.ReadOnly = true
695
	}
Matt Joiner's avatar
Matt Joiner committed
696
	b, err := bencode.Marshal(m)
697
	if err != nil {
Matt Joiner's avatar
Matt Joiner committed
698
		panic(err)
699
	}
Matt Joiner's avatar
Matt Joiner committed
700 701 702
	return b
}

703
func (s *Server) queryContext(ctx context.Context, addr Addr, q string, a *krpc.MsgArgs) (reply krpc.Msg, writes numWrites, err error) {
Matt Joiner's avatar
Matt Joiner committed
704
	defer func(started time.Time) {
705 706 707
		s.logger().WithValues(log.Debug, q).Printf(
			"queryContext(%v) returned after %v (err=%v, reply.Y=%v, reply.E=%v, writes=%v)",
			q, time.Since(started), err, reply.Y, reply.E, writes)
Matt Joiner's avatar
Matt Joiner committed
708
	}(time.Now())
Matt Joiner's avatar
Matt Joiner committed
709
	replyChan := make(chan krpc.Msg, 1)
Matt Joiner's avatar
Matt Joiner committed
710
	t := &Transaction{
Matt Joiner's avatar
Matt Joiner committed
711
		onResponse: func(m krpc.Msg) {
Matt Joiner's avatar
Matt Joiner committed
712
			replyChan <- m
Matt Joiner's avatar
Matt Joiner committed
713
		},
714
	}
Matt Joiner's avatar
Matt Joiner committed
715 716 717 718 719
	tk := transactionKey{
		RemoteAddr: addr.String(),
	}
	s.mu.Lock()
	tid := s.nextTransactionID()
720
	s.stats.OutboundQueriesAttempted++
Matt Joiner's avatar
Matt Joiner committed
721 722 723 724 725 726 727 728
	tk.T = tid
	s.addTransaction(tk, t)
	s.mu.Unlock()
	sendErr := make(chan error, 1)
	sendCtx, cancelSend := context.WithCancel(ctx)
	defer cancelSend()
	go s.transactionQuerySender(sendCtx, sendErr, s.makeQueryBytes(q, a, tid), &writes, addr)
	expvars.Add(fmt.Sprintf("outbound %s queries", q), 1)
Matt Joiner's avatar
Matt Joiner committed
729 730 731 732
	select {
	case reply = <-replyChan:
	case <-ctx.Done():
		err = ctx.Err()
Matt Joiner's avatar
Matt Joiner committed
733 734 735 736 737 738 739 740 741 742 743 744 745
	case err = <-sendErr:
	}
	s.mu.Lock()
	s.deleteTransaction(tk)
	if err != nil {
		for _, n := range s.table.addrNodes(addr) {
			n.consecutiveFailures++
		}
	}
	s.mu.Unlock()
	return
}

746
func (s *Server) transactionQuerySender(sendCtx context.Context, sendErr chan<- error, b []byte, writes *numWrites, addr Addr) {
Matt Joiner's avatar
Matt Joiner committed
747 748 749 750
	defer close(sendErr)
	err := transactionSender(
		sendCtx,
		func() error {
Matt Joiner's avatar
Matt Joiner committed
751
			wrote, err := s.writeToNode(sendCtx, b, addr, *writes == 0, *writes != 0)
Matt Joiner's avatar
Matt Joiner committed
752
			if wrote {
Matt Joiner's avatar
Matt Joiner committed
753
				*writes++
Matt Joiner's avatar
Matt Joiner committed
754 755 756 757 758 759 760 761
			}
			return err
		},
		s.resendDelay,
		maxTransactionSends,
	)
	if err != nil {
		sendErr <- err
Matt Joiner's avatar
Matt Joiner committed
762 763
		return
	}
Matt Joiner's avatar
Matt Joiner committed
764 765 766 767 768 769 770
	select {
	case <-sendCtx.Done():
		sendErr <- sendCtx.Err()
	case <-time.After(s.resendDelay()):
		sendErr <- errors.New("timed out")
	}

771 772 773
}

// Sends a ping query to the address given.
Matt Joiner's avatar
Matt Joiner committed
774
func (s *Server) Ping(node *net.UDPAddr, callback func(krpc.Msg, error)) error {
775 776
	s.mu.Lock()
	defer s.mu.Unlock()
Matt Joiner's avatar
Matt Joiner committed
777 778 779 780
	return s.ping(node, callback)
}

func (s *Server) ping(node *net.UDPAddr, callback func(krpc.Msg, error)) error {
Matt Joiner's avatar
Matt Joiner committed
781
	return s.query(NewAddr(node), "ping", nil, callback)
782 783
}

784
// announcePeer sends an announce_peer query for infoHash to node, presenting
// token. Either port must be non-zero or impliedPort must be set. A KRPC error
// reply is returned as err and counted in announceErrors. A successful
// announce is counted in the server stats.
func (s *Server) announcePeer(node Addr, infoHash int160, port int, token string, impliedPort bool) (m krpc.Msg, writes numWrites, err error) {
	if port == 0 && !impliedPort {
		return m, writes, errors.New("no port specified")
	}
	args := &krpc.MsgArgs{
		ImpliedPort: impliedPort,
		InfoHash:    infoHash.AsByteArray(),
		Port:        &port,
		Token:       token,
	}
	if m, writes, err = s.queryContext(context.TODO(), node, "announce_peer", args); err != nil {
		return
	}
	if err = m.Error(); err != nil {
		announceErrors.Add(1)
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.stats.SuccessfulOutboundAnnouncePeerQueries++
	return
}

// Add response nodes to node table.
812 813
func (s *Server) addResponseNodes(d krpc.Msg) {
	if d.R == nil {
814 815
		return
	}
816 817 818
	d.R.ForAllNodes(func(ni krpc.NodeInfo) {
		s.getNode(NewAddr(ni.Addr.UDP()), int160FromByteArray(ni.ID), true)
	})
819 820 821
}

// Sends a find_node query to addr. targetID is the node we're looking for.
822 823 824
func (s *Server) findNode(addr Addr, targetID int160, callback func(krpc.Msg, error)) (err error) {
	return s.query(addr, "find_node", &krpc.MsgArgs{
		Target: targetID.AsByteArray(),
825
		Want:   []krpc.Want{krpc.WantNodes, krpc.WantNodes6},
826
	}, func(m krpc.Msg, err error) {
827 828
		// Scrape peers from the response to put in the server's table before
		// handing the response back to the caller.
829 830 831 832
		s.mu.Lock()
		s.addResponseNodes(m)
		s.mu.Unlock()
		callback(m, err)
833 834 835 836 837 838 839
	})
}

// Returns how many nodes are in the node table.
func (s *Server) NumNodes() int {
	s.mu.Lock()
	defer s.mu.Unlock()
Matt Joiner's avatar
Matt Joiner committed
840
	return s.numNodes()
841 842 843
}

// Exports the current node table.
844
func (s *Server) Nodes() (nis []krpc.NodeInfo) {
845 846
	s.mu.Lock()
	defer s.mu.Unlock()
Matt Joiner's avatar
Matt Joiner committed
847 848
	s.table.forNodes(func(n *node) bool {
		nis = append(nis, krpc.NodeInfo{
849
			Addr: n.addr.KRPC(),
Matt Joiner's avatar
Matt Joiner committed
850 851 852 853
			ID:   n.id.AsByteArray(),
		})
		return true
	})
854 855 856 857 858 859
	return
}

// Stops the server network activity. This is all that's required to clean-up a Server.
func (s *Server) Close() {
	s.mu.Lock()
860 861 862
	defer s.mu.Unlock()
	s.closed.Set()
	s.socket.Close()
863 864
}

865
func (s *Server) getPeers(ctx context.Context, addr Addr, infoHash int160) (krpc.Msg, numWrites, error) {
Matt Joiner's avatar
Matt Joiner committed
866
	m, writes, err := s.queryContext(ctx, addr, "get_peers", &krpc.MsgArgs{
867
		InfoHash: infoHash.AsByteArray(),
868 869
		// TODO: Maybe IPv4-only Servers won't want IPv6 nodes?
		Want: []krpc.Want{krpc.WantNodes, krpc.WantNodes6},
870 871 872 873 874 875 876 877 878 879 880 881 882 883 884
	})
	s.mu.Lock()
	defer s.mu.Unlock()
	s.addResponseNodes(m)
	if m.R != nil {
		if m.R.Token == nil {
			expvars.Add("get_peers responses with no token", 1)
		} else if len(*m.R.Token) == 0 {
			expvars.Add("get_peers responses with empty token", 1)
		} else {
			expvars.Add("get_peers responses with token", 1)
		}
		if m.SenderID() != nil && m.R.Token != nil {
			if n, _ := s.getNode(addr, int160FromByteArray(*m.SenderID()), false); n != nil {
				n.announceToken = m.R.Token
885
			}
886
		}
887
	}
Matt Joiner's avatar
Matt Joiner committed
888
	return m, writes, err
Matt Joiner's avatar
Matt Joiner committed
889 890
}

891 892 893 894 895 896 897 898 899 900
func (s *Server) closestGoodNodeInfos(
	k int,
	targetID int160,
	filter func(krpc.NodeAddr) bool,
) (
	ret []krpc.NodeInfo,
) {
	for _, n := range s.closestNodes(k, targetID, func(n *node) bool {
		return n.IsGood() && filter(n.NodeInfo().Addr)
	}) {
Matt Joiner's avatar
Matt Joiner committed
901 902
		ret = append(ret, n.NodeInfo())
	}
903 904 905
	return
}

Matt Joiner's avatar
Matt Joiner committed
906 907
func (s *Server) closestNodes(k int, target int160, filter func(*node) bool) []*node {
	return s.table.closestNodes(k, target, filter)
908 909
}

910
func (s *Server) traversalStartingNodes() (nodes []addrMaybeId, err error) {
911
	s.mu.RLock()
Matt Joiner's avatar
Matt Joiner committed
912
	s.table.forNodes(func(n *node) bool {
913
		nodes = append(nodes, addrMaybeId{n.addr.KRPC(), &n.id})
Matt Joiner's avatar
Matt Joiner committed
914 915
		return true
	})
916
	s.mu.RUnlock()
917
	if len(nodes) > 0 {
Matt Joiner's avatar
Matt Joiner committed
918 919
		return
	}
920
	if s.config.StartingNodes != nil {
921
		addrs, err := s.config.StartingNodes()
922
		if err != nil {
923 924 925 926
			return nil, errors.Wrap(err, "getting starting nodes")
		}
		for _, a := range addrs {
			nodes = append(nodes, addrMaybeId{a.KRPC(), nil})
927 928
		}
	}
929
	if len(nodes) == 0 {
Matt Joiner's avatar
Matt Joiner committed
930 931 932
		err = errors.New("no initial nodes")
	}
	return
933
}
934 935 936 937 938 939 940 941