summaryrefslogtreecommitdiff
path: root/rudp/recv.go
diff options
context:
space:
mode:
Diffstat (limited to 'rudp/recv.go')
-rw-r--r--rudp/recv.go259
1 files changed, 259 insertions, 0 deletions
diff --git a/rudp/recv.go b/rudp/recv.go
new file mode 100644
index 0000000..f5ac236
--- /dev/null
+++ b/rudp/recv.go
@@ -0,0 +1,259 @@
+package rudp
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+// Recv receives a Pkt from the Conn.
+func (c *Conn) Recv() (Pkt, error) {
+ select {
+ case pkt := <-c.pkts:
+ return pkt, nil
+ case err := <-c.errs:
+ return Pkt{}, err
+ case <-c.Closed():
+ return Pkt{}, net.ErrClosed
+ }
+}
+
+func (c *Conn) gotPkt(pkt Pkt) {
+ select {
+ case c.pkts <- pkt:
+ case <-c.Closed():
+ }
+}
+
+func (c *Conn) gotErr(kind string, data []byte, err error) {
+ select {
+ case c.errs <- fmt.Errorf("%s: %x: %w", kind, data, err):
+ case <-c.Closed():
+ }
+}
+
+func (c *Conn) recvUDPPkts() {
+ for {
+ pkt, err := c.udpConn.recvUDP()
+ if err != nil {
+ c.closeDisco(err)
+ break
+ }
+
+ if err := c.processUDPPkt(pkt); err != nil {
+ c.gotErr("udp", pkt, err)
+ }
+ }
+}
+
+func (c *Conn) processUDPPkt(pkt []byte) error {
+ if c.timeout.Stop() {
+ c.timeout.Reset(ConnTimeout)
+ }
+
+ if len(pkt) < 6 {
+ return io.ErrUnexpectedEOF
+ }
+
+ if id := be.Uint32(pkt[0:4]); id != protoID {
+ return fmt.Errorf("unsupported protocol id: 0x%08x", id)
+ }
+
+ ch := Channel(pkt[6])
+ if ch >= ChannelCount {
+ return TooBigChError(ch)
+ }
+
+ if err := c.processRawPkt(pkt[7:], PktInfo{Channel: ch, Unrel: true}); err != nil {
+ c.gotErr("raw", pkt, err)
+ }
+
+ return nil
+}
+
+// A TrailingDataError reports trailing data after a packet.
+type TrailingDataError []byte
+
+func (e TrailingDataError) Error() string {
+ return fmt.Sprintf("trailing data: %x", []byte(e))
+}
+
+func (c *Conn) processRawPkt(data []byte, pi PktInfo) (err error) {
+ errWrap := func(format string, a ...interface{}) {
+ if err != nil {
+ err = fmt.Errorf(format+": %w", append(a, err)...)
+ }
+ }
+
+ eof := new(byte)
+ defer func() {
+ switch r := recover(); r {
+ case nil:
+ case eof:
+ err = io.ErrUnexpectedEOF
+ default:
+ panic(r)
+ }
+ }()
+
+ off := 0
+ eat := func(n int) []byte {
+ i := off
+ off += n
+ if i > len(data) {
+ panic(eof)
+ }
+ return data[i:off]
+ }
+
+ ch := &c.chans[pi.Channel]
+
+ switch t := rawType(eat(1)[0]); t {
+ case rawCtl:
+ defer errWrap("ctl")
+
+ switch ct := ctlType(eat(1)[0]); ct {
+ case ctlAck:
+ defer errWrap("ack")
+
+ sn := seqnum(be.Uint16(eat(2)))
+
+ if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
+ close(ack.(chan struct{}))
+ }
+ case ctlSetPeerID:
+ defer errWrap("set peer id")
+
+ c.mu.Lock()
+ if c.remoteID != PeerIDNil {
+ return errors.New("peer id already set")
+ }
+
+ c.remoteID = PeerID(be.Uint16(eat(2)))
+ c.mu.Unlock()
+
+ c.newAckBuf()
+ case ctlPing:
+ defer errWrap("ping")
+ case ctlDisco:
+ defer errWrap("disco")
+
+ c.close(nil)
+ default:
+ return fmt.Errorf("unsupported ctl type: %d", ct)
+ }
+
+ if off < len(data) {
+ return TrailingDataError(data[off:])
+ }
+ case rawOrig:
+ c.gotPkt(Pkt{
+ Reader: bytes.NewReader(data[off:]),
+ PktInfo: pi,
+ })
+ case rawSplit:
+ defer errWrap("split")
+
+ sn := seqnum(be.Uint16(eat(2)))
+ n := be.Uint16(eat(2))
+ i := be.Uint16(eat(2))
+
+ defer errWrap("%d", sn)
+
+ if i >= n {
+ return fmt.Errorf("chunk number (%d) > chunk count (%d)", i, n)
+ }
+
+ ch.inSplitsMu.RLock()
+ s := ch.inSplits[sn]
+ ch.inSplitsMu.RUnlock()
+
+ if s == nil {
+ s = &inSplit{chunks: make([][]byte, n)}
+ if pi.Unrel {
+ s.timeout = time.AfterFunc(ConnTimeout, func() {
+ ch.inSplitsMu.Lock()
+ delete(ch.inSplits, sn)
+ ch.inSplitsMu.Unlock()
+ })
+ }
+
+ ch.inSplitsMu.Lock()
+ ch.inSplits[sn] = s
+ ch.inSplitsMu.Unlock()
+ }
+
+ if int(n) != len(s.chunks) {
+ return fmt.Errorf("chunk count changed from %d to %d", len(s.chunks), n)
+ }
+
+ if s.chunks[i] == nil {
+ s.chunks[i] = data[off:]
+ s.got++
+ }
+
+ if s.got < len(s.chunks) {
+ if s.timeout != nil && s.timeout.Stop() {
+ s.timeout.Reset(ConnTimeout)
+ }
+ return
+ }
+
+ if s.timeout != nil {
+ s.timeout.Stop()
+ }
+
+ ch.inSplitsMu.Lock()
+ delete(ch.inSplits, sn)
+ ch.inSplitsMu.Unlock()
+
+ c.gotPkt(Pkt{
+ Reader: (*net.Buffers)(&s.chunks),
+ PktInfo: pi,
+ })
+ case rawRel:
+ defer errWrap("rel")
+
+ sn := seqnum(be.Uint16(eat(2)))
+
+ defer errWrap("%d", sn)
+
+ be.PutUint16(ch.ackBuf, uint16(sn))
+ ch.sendAck()
+
+ if sn-ch.inRelSN >= 0x8000 {
+ // Already received.
+ return nil
+ }
+
+ ch.inRels[sn&0x7fff] = data[off:]
+
+ i := func() seqnum { return ch.inRelSN & 0x7fff }
+ for ; ch.inRels[i()] != nil; ch.inRelSN++ {
+ data := ch.inRels[i()]
+ ch.inRels[i()] = nil
+ if err := c.processRawPkt(data, PktInfo{Channel: pi.Channel}); err != nil {
+ c.gotErr("rel", data, err)
+ }
+ }
+ default:
+ return fmt.Errorf("unsupported pkt type: %d", t)
+ }
+
+ return nil
+}
+
+func (c *Conn) newAckBuf() {
+ for i := range c.chans {
+ ch := &c.chans[i]
+ ch.sendAck = c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawCtl)
+ buf[1] = uint8(ctlAck)
+ ch.ackBuf = buf[2:4]
+ return 4
+ }, PktInfo{Channel: Channel(i), Unrel: true})
+ }
+}