diff options
Diffstat (limited to 'rudp/recv.go')
-rw-r--r-- | rudp/recv.go | 259 |
1 files changed, 259 insertions, 0 deletions
diff --git a/rudp/recv.go b/rudp/recv.go new file mode 100644 index 0000000..f5ac236 --- /dev/null +++ b/rudp/recv.go @@ -0,0 +1,259 @@ +package rudp + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Recv receives a Pkt from the Conn. +func (c *Conn) Recv() (Pkt, error) { + select { + case pkt := <-c.pkts: + return pkt, nil + case err := <-c.errs: + return Pkt{}, err + case <-c.Closed(): + return Pkt{}, net.ErrClosed + } +} + +func (c *Conn) gotPkt(pkt Pkt) { + select { + case c.pkts <- pkt: + case <-c.Closed(): + } +} + +func (c *Conn) gotErr(kind string, data []byte, err error) { + select { + case c.errs <- fmt.Errorf("%s: %x: %w", kind, data, err): + case <-c.Closed(): + } +} + +func (c *Conn) recvUDPPkts() { + for { + pkt, err := c.udpConn.recvUDP() + if err != nil { + c.closeDisco(err) + break + } + + if err := c.processUDPPkt(pkt); err != nil { + c.gotErr("udp", pkt, err) + } + } +} + +func (c *Conn) processUDPPkt(pkt []byte) error { + if c.timeout.Stop() { + c.timeout.Reset(ConnTimeout) + } + + if len(pkt) < 6 { + return io.ErrUnexpectedEOF + } + + if id := be.Uint32(pkt[0:4]); id != protoID { + return fmt.Errorf("unsupported protocol id: 0x%08x", id) + } + + ch := Channel(pkt[6]) + if ch >= ChannelCount { + return TooBigChError(ch) + } + + if err := c.processRawPkt(pkt[7:], PktInfo{Channel: ch, Unrel: true}); err != nil { + c.gotErr("raw", pkt, err) + } + + return nil +} + +// A TrailingDataError reports trailing data after a packet. +type TrailingDataError []byte + +func (e TrailingDataError) Error() string { + return fmt.Sprintf("trailing data: %x", []byte(e)) +} + +func (c *Conn) processRawPkt(data []byte, pi PktInfo) (err error) { + errWrap := func(format string, a ...interface{}) { + if err != nil { + err = fmt.Errorf(format+": %w", append(a, err)...) + } + } + + eof := new(byte) + defer func() { + switch r := recover(); r { + case nil: + case eof: + err = io.ErrUnexpectedEOF + default: + panic(r) + } + }() + + off := 0 + eat := func(n int) []byte { + i := off + off += n + if i > len(data) { + panic(eof) + } + return data[i:off] + } + + ch := &c.chans[pi.Channel] + + switch t := rawType(eat(1)[0]); t { + case rawCtl: + defer errWrap("ctl") + + switch ct := ctlType(eat(1)[0]); ct { + case ctlAck: + defer errWrap("ack") + + sn := seqnum(be.Uint16(eat(2))) + + if ack, ok := ch.ackChans.LoadAndDelete(sn); ok { + close(ack.(chan struct{})) + } + case ctlSetPeerID: + defer errWrap("set peer id") + + c.mu.Lock() + if c.remoteID != PeerIDNil { + return errors.New("peer id already set") + } + + c.remoteID = PeerID(be.Uint16(eat(2))) + c.mu.Unlock() + + c.newAckBuf() + case ctlPing: + defer errWrap("ping") + case ctlDisco: + defer errWrap("disco") + + c.close(nil) + default: + return fmt.Errorf("unsupported ctl type: %d", ct) + } + + if off < len(data) { + return TrailingDataError(data[off:]) + } + case rawOrig: + c.gotPkt(Pkt{ + Reader: bytes.NewReader(data[off:]), + PktInfo: pi, + }) + case rawSplit: + defer errWrap("split") + + sn := seqnum(be.Uint16(eat(2))) + n := be.Uint16(eat(2)) + i := be.Uint16(eat(2)) + + defer errWrap("%d", sn) + + if i >= n { + return fmt.Errorf("chunk number (%d) > chunk count (%d)", i, n) + } + + ch.inSplitsMu.RLock() + s := ch.inSplits[sn] + ch.inSplitsMu.RUnlock() + + if s == nil { + s = &inSplit{chunks: make([][]byte, n)} + if pi.Unrel { + s.timeout = time.AfterFunc(ConnTimeout, func() { + ch.inSplitsMu.Lock() + delete(ch.inSplits, sn) + ch.inSplitsMu.Unlock() + }) + } + + ch.inSplitsMu.Lock() + ch.inSplits[sn] = s + ch.inSplitsMu.Unlock() + } + + if int(n) != len(s.chunks) { + return fmt.Errorf("chunk count changed from %d to %d", len(s.chunks), n) + } + + if s.chunks[i] == nil { + s.chunks[i] = data[off:] + s.got++ + } + + if s.got < len(s.chunks) { + if s.timeout != nil && s.timeout.Stop() { + s.timeout.Reset(ConnTimeout) + } + return + } + + if s.timeout != nil { + s.timeout.Stop() + } + + ch.inSplitsMu.Lock() + delete(ch.inSplits, sn) + ch.inSplitsMu.Unlock() + + c.gotPkt(Pkt{ + Reader: (*net.Buffers)(&s.chunks), + PktInfo: pi, + }) + case rawRel: + defer errWrap("rel") + + sn := seqnum(be.Uint16(eat(2))) + + defer errWrap("%d", sn) + + be.PutUint16(ch.ackBuf, uint16(sn)) + ch.sendAck() + + if sn-ch.inRelSN >= 0x8000 { + // Already received. + return nil + } + + ch.inRels[sn&0x7fff] = data[off:] + + i := func() seqnum { return ch.inRelSN & 0x7fff } + for ; ch.inRels[i()] != nil; ch.inRelSN++ { + data := ch.inRels[i()] + ch.inRels[i()] = nil + if err := c.processRawPkt(data, PktInfo{Channel: pi.Channel}); err != nil { + c.gotErr("rel", data, err) + } + } + default: + return fmt.Errorf("unsupported pkt type: %d", t) + } + + return nil +} + +func (c *Conn) newAckBuf() { + for i := range c.chans { + ch := &c.chans[i] + ch.sendAck = c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawCtl) + buf[1] = uint8(ctlAck) + ch.ackBuf = buf[2:4] + return 4 + }, PktInfo{Channel: Channel(i), Unrel: true}) + } +} |