diff options
Diffstat (limited to 'rudp/listen.go')
-rw-r--r-- | rudp/listen.go | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/rudp/listen.go b/rudp/listen.go new file mode 100644 index 0000000..871a591 --- /dev/null +++ b/rudp/listen.go @@ -0,0 +1,153 @@ +package rudp + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "sync" +) + +type Listener struct { + conn net.PacketConn + + clts chan cltPeer + errs chan error + + mu sync.Mutex + addr2peer map[string]cltPeer + id2peer map[PeerID]cltPeer + peerid PeerID +} + +// Listen listens for packets on conn until it is closed. +func Listen(conn net.PacketConn) *Listener { + l := &Listener{ + conn: conn, + + clts: make(chan cltPeer), + errs: make(chan error), + + addr2peer: make(map[string]cltPeer), + id2peer: make(map[PeerID]cltPeer), + } + + pkts := make(chan netPkt) + go readNetPkts(l.conn, pkts, l.errs) + go func() { + for pkt := range pkts { + if err := l.processNetPkt(pkt); err != nil { + l.errs <- err + } + } + + close(l.clts) + + for _, clt := range l.addr2peer { + clt.Close() + } + }() + + return l +} + +// Accept waits for and returns a connecting Peer. +// You should keep calling this until it returns ErrClosed +// so it doesn't leak a goroutine. +func (l *Listener) Accept() (*Peer, error) { + select { + case clt, ok := <-l.clts: + if !ok { + select { + case err := <-l.errs: + return nil, err + default: + return nil, ErrClosed + } + } + close(clt.accepted) + return clt.Peer, nil + case err := <-l.errs: + return nil, err + } +} + +// Addr returns the net.PacketConn the Listener is listening on. +func (l *Listener) Conn() net.PacketConn { return l.conn } + +var ErrOutOfPeerIDs = errors.New("out of peer ids") + +type cltPeer struct { + *Peer + pkts chan<- netPkt + accepted chan struct{} // close-only +} + +func (l *Listener) processNetPkt(pkt netPkt) error { + l.mu.Lock() + defer l.mu.Unlock() + + addrstr := pkt.SrcAddr.String() + + clt, ok := l.addr2peer[addrstr] + if !ok { + prev := l.peerid + for l.id2peer[l.peerid].Peer != nil || l.peerid < PeerIDCltMin { + if l.peerid == prev-1 { + return ErrOutOfPeerIDs + } + l.peerid++ + } + + pkts := make(chan netPkt, 256) + + clt = cltPeer{ + Peer: newPeer(l.conn, pkt.SrcAddr, l.peerid, PeerIDSrv), + pkts: pkts, + accepted: make(chan struct{}), + } + + l.addr2peer[addrstr] = clt + l.id2peer[clt.ID()] = clt + + data := make([]byte, 1+1+2) + data[0] = uint8(rawTypeCtl) + data[1] = uint8(ctlSetPeerID) + binary.BigEndian.PutUint16(data[2:4], uint16(clt.ID())) + if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil { + return fmt.Errorf("can't set client peer id: %w", err) + } + + go func() { + select { + case l.clts <- clt: + case <-clt.Disco(): + } + + clt.processNetPkts(pkts) + }() + + go func() { + <-clt.Disco() + + l.mu.Lock() + close(pkts) + delete(l.addr2peer, addrstr) + delete(l.id2peer, clt.ID()) + l.mu.Unlock() + }() + } + + select { + case <-clt.accepted: + clt.pkts <- pkt + default: + select { + case clt.pkts <- pkt: + default: + return fmt.Errorf("ignoring net pkt from %s because buf is full", addrstr) + } + } + + return nil +} |