summaryrefslogtreecommitdiff
path: root/rudp/listen.go
blob: 871a5912adcc18a0ad005c39f01454e0543cd98f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package rudp

import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"sync"
)

// Listener accepts client peers connecting to a net.PacketConn.
// Incoming packets are demultiplexed by source address; each new
// address is assigned a free client PeerID and handed out by Accept.
type Listener struct {
	conn net.PacketConn

	clts chan cltPeer // new peers waiting for Accept; never closed
	errs chan error   // read/routing errors, returned by Accept

	// closed is closed once the packet reader stops delivering packets.
	// Unlike clts, it is safe to close while per-peer goroutines may
	// still be trying to hand their peer to Accept.
	closed chan struct{}

	mu        sync.Mutex // guards addr2peer, id2peer and peerid
	addr2peer map[string]cltPeer
	id2peer   map[PeerID]cltPeer
	peerid    PeerID // most recently assigned client ID; the free-ID search starts here
}

// Listen listens for packets on conn until it is closed.
//
// Packets are read and routed on background goroutines. Errors are
// delivered through Accept over an unbuffered channel, so Accept must
// keep being called or packet processing stalls.
func Listen(conn net.PacketConn) *Listener {
	l := &Listener{
		conn: conn,

		clts:   make(chan cltPeer),
		errs:   make(chan error),
		closed: make(chan struct{}),

		addr2peer: make(map[string]cltPeer),
		id2peer:   make(map[PeerID]cltPeer),
	}

	pkts := make(chan netPkt)
	go readNetPkts(l.conn, pkts, l.errs)
	go func() {
		for pkt := range pkts {
			if err := l.processNetPkt(pkt); err != nil {
				l.errs <- err
			}
		}

		// Wake Accept and any peer goroutines still waiting to be
		// accepted. clts itself is deliberately never closed: a peer
		// goroutine blocked sending on it would panic.
		close(l.closed)

		// Snapshot under the lock: peers' disconnect goroutines delete
		// from addr2peer concurrently (e.g. as a result of these Close
		// calls), and ranging over a map during writes is a data race.
		l.mu.Lock()
		clts := make([]cltPeer, 0, len(l.addr2peer))
		for _, clt := range l.addr2peer {
			clts = append(clts, clt)
		}
		l.mu.Unlock()

		for _, clt := range clts {
			clt.Close()
		}
	}()

	return l
}

// Accept waits for and returns a connecting Peer.
// You should keep calling this until it returns ErrClosed
// so it doesn't leak a goroutine.
//
// After the Listener has stopped, Accept returns a pending error if one
// is immediately available and ErrClosed otherwise.
func (l *Listener) Accept() (*Peer, error) {
	select {
	case clt := <-l.clts:
		// Tells processNetPkt to stop dropping packets for this peer
		// when its buffer is full.
		close(clt.accepted)
		return clt.Peer, nil
	case err := <-l.errs:
		return nil, err
	case <-l.closed:
		select {
		case err := <-l.errs:
			return nil, err
		default:
			return nil, ErrClosed
		}
	}
}

// Conn returns the net.PacketConn the Listener is listening on.
func (l *Listener) Conn() net.PacketConn { return l.conn }

// ErrOutOfPeerIDs is returned through Accept when a packet arrives from
// a new address but every client PeerID is already in use.
var ErrOutOfPeerIDs = errors.New("out of peer ids")

// cltPeer is a client Peer as tracked by the Listener.
type cltPeer struct {
	*Peer
	pkts     chan<- netPkt // packets routed to this peer; consumed by processNetPkts
	accepted chan struct{} // close-only; closed by Accept once the peer is handed out
}

// processNetPkt routes pkt to the client peer at its source address,
// creating a new peer first if the address is not known yet.
//
// It returns ErrOutOfPeerIDs if no client PeerID is free, an error if
// the new peer's ID could not be sent to it, or an error if pkt was
// dropped because the peer is not accepted yet and its buffer is full.
func (l *Listener) processNetPkt(pkt netPkt) error {
	l.mu.Lock()
	defer l.mu.Unlock()

	addrstr := pkt.SrcAddr.String()

	clt, ok := l.addr2peer[addrstr]
	if !ok {
		var err error
		if clt, err = l.newCltPeer(pkt.SrcAddr, addrstr); err != nil {
			return err
		}
	}

	select {
	case <-clt.accepted:
		// Accepted peers are never dropped from; this waits for buffer
		// space while l.mu is held.
		clt.pkts <- pkt
	default:
		select {
		case clt.pkts <- pkt:
		default:
			return fmt.Errorf("ignoring net pkt from %s because buf is full", addrstr)
		}
	}

	return nil
}

// newCltPeer allocates a free client PeerID, sends it to the remote
// side, registers the new peer and starts its goroutines.
// The caller must hold l.mu.
func (l *Listener) newCltPeer(addr net.Addr, addrstr string) (cltPeer, error) {
	// Scan for a free ID starting at the last one assigned, skipping IDs
	// below PeerIDCltMin. Reaching the ID just before the starting point
	// (after wrapping around) means every ID has been tried.
	prev := l.peerid
	for l.id2peer[l.peerid].Peer != nil || l.peerid < PeerIDCltMin {
		if l.peerid == prev-1 {
			return cltPeer{}, ErrOutOfPeerIDs
		}
		l.peerid++
	}

	pkts := make(chan netPkt, 256)

	clt := cltPeer{
		Peer:     newPeer(l.conn, addr, l.peerid, PeerIDSrv),
		pkts:     pkts,
		accepted: make(chan struct{}),
	}

	// Tell the client which ID it was assigned.
	data := make([]byte, 1+1+2)
	data[0] = uint8(rawTypeCtl)
	data[1] = uint8(ctlSetPeerID)
	binary.BigEndian.PutUint16(data[2:4], uint16(clt.ID()))
	if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
		// The peer is not registered and none of its goroutines run, so
		// nothing else would ever clean it up. Close it to release
		// whatever newPeer set up; its ID stays free for the next try.
		clt.Close()
		return cltPeer{}, fmt.Errorf("can't set client peer id: %w", err)
	}

	l.addr2peer[addrstr] = clt
	l.id2peer[clt.ID()] = clt

	go func() {
		// Hand the peer to Accept unless it disconnects or the Listener
		// stops first. l.closed is used instead of closing clts, which
		// would make this send panic.
		select {
		case l.clts <- clt:
		case <-clt.Disco():
		case <-l.closed:
		}

		clt.processNetPkts(pkts)
	}()

	go func() {
		<-clt.Disco()

		// Unregister under the lock so processNetPkt never sends on
		// pkts after it has been closed.
		l.mu.Lock()
		close(pkts)
		delete(l.addr2peer, addrstr)
		delete(l.id2peer, clt.ID())
		l.mu.Unlock()
	}()

	return clt, nil
}