diff options
author | anon5 <anon5clam@protonmail.com> | 2021-03-22 18:37:36 +0000 |
---|---|---|
committer | anon5 <anon5clam@protonmail.com> | 2021-03-22 18:37:36 +0000 |
commit | 433955e45ef35476ab6863c4aff06b6ddd0bec67 (patch) | |
tree | 354f9f143c4ed56c79519519cf9d7af09802fd9a /rudp/send.go | |
parent | 1bae1f4f7ea371086d14b641d6a45aaa8feccb75 (diff) |
rudp: partial rewrite with new API supporting io.Readers
Diffstat (limited to 'rudp/send.go')
-rw-r--r-- | rudp/send.go | 358 |
1 files changed, 169 insertions, 189 deletions
diff --git a/rudp/send.go b/rudp/send.go index c43d056..5522ee0 100644 --- a/rudp/send.go +++ b/rudp/send.go @@ -1,241 +1,221 @@ package rudp import ( + "bytes" "errors" "fmt" - "math" + "io" "net" "sync" + "sync/atomic" "time" ) -const ( - // protoID + src PeerID + channel number - MtHdrSize = 4 + 2 + 1 - - // rawTypeOrig - OrigHdrSize = 1 - - // rawTypeSpilt + seqnum + chunk count + chunk number - SplitHdrSize = 1 + 2 + 2 + 2 - - // rawTypeRel + seqnum - RelHdrSize = 1 + 2 -) - -const ( - MaxNetPktSize = 512 - - MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize - MaxRelRawPktSize = MaxUnrelRawPktSize - RelHdrSize - - MaxRelPktSize = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16 - MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16 -) - var ErrPktTooBig = errors.New("can't send pkt: too big") -var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount") - -// Send sends a packet to the Peer. -// It returns a channel that's closed when all chunks are acked or an error. -// The ack channel is nil if pkt.Unrel is true. -func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { - if pkt.ChNo >= ChannelCount { - return nil, ErrChNoTooBig - } - - hdrSize := MtHdrSize - if !pkt.Unrel { - hdrSize += RelHdrSize - } - - if hdrSize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize { - c := &p.chans[pkt.ChNo] - - c.outSplitMu.Lock() - sn := c.outSplitSN - c.outSplitSN++ - c.outSplitMu.Unlock() - chunks := split(pkt.Data, MaxNetPktSize-(hdrSize+SplitHdrSize)) +// A TooBigChError reports a Channel greater than or equal to ChannelCount. +type TooBigChError Channel - if len(chunks) > math.MaxUint16 { - return nil, ErrPktTooBig - } +func (e TooBigChError) Error() string { + return fmt.Sprintf("channel >= ChannelCount (%d): %d", ChannelCount, e) +} - var wg sync.WaitGroup +// Send sends a Pkt to the Conn. +// Ack is closed when the packet is acknowledged. +// Ack is nil if pkt.Unrel is true or err != nil. +func (c *Conn) Send(pkt Pkt) (ack <-chan struct{}, err error) { + if pkt.Channel >= ChannelCount { + return nil, TooBigChError(pkt.Channel) + } - for i, chunk := range chunks { - data := make([]byte, SplitHdrSize+len(chunk)) - data[0] = uint8(rawTypeSplit) - be.PutUint16(data[1:3], uint16(sn)) - be.PutUint16(data[3:5], uint16(len(chunks))) - be.PutUint16(data[5:7], uint16(i)) - copy(data[SplitHdrSize:], chunk) + var e error + send := c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawOrig) - wg.Add(1) - ack, err := p.sendRaw(rawPkt{ - Data: data, - ChNo: pkt.ChNo, - Unrel: pkt.Unrel, - }) + nn := 1 + for nn < len(buf) { + n, err := pkt.Read(buf[nn:]) + nn += n if err != nil { - return nil, err - } - if !pkt.Unrel { - go func() { - <-ack - wg.Done() - }() + e = err + return nn } } - if pkt.Unrel { - return nil, nil - } else { - ack := make(chan struct{}) - - go func() { - wg.Wait() - close(ack) - }() + if _, e = pkt.Read(nil); e != nil { + return nn + } - return ack, nil + pkt.Reader = io.MultiReader( + bytes.NewReader([]byte(buf[1:nn])), + pkt.Reader, + ) + return nn + }, pkt.PktInfo) + if e != nil { + if e == io.EOF { + return send() } + return nil, e } - return p.sendRaw(rawPkt{ - Data: append([]byte{uint8(rawTypeOrig)}, pkt.Data...), - ChNo: pkt.ChNo, - Unrel: pkt.Unrel, - }) -} + var ( + sn seqnum + i uint16 + + sends []func() (<-chan struct{}, error) + ) + + for { + var ( + b []byte + e error + ) + send := c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawSplit) + + n, err := io.ReadFull(pkt, buf[7:]) + if err != nil && err != io.ErrUnexpectedEOF { + e = err + return 0 + } -// sendRaw sends a raw packet to the Peer. -func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) { - if pkt.ChNo >= ChannelCount { - return nil, ErrChNoTooBig - } + be.PutUint16(buf[5:7], i) + if i++; i == 0 { + e = ErrPktTooBig + return 0 + } - p.mu.RLock() - defer p.mu.RUnlock() + b = buf + return 7 + n + }, pkt.PktInfo) + if e != nil { + if e == io.EOF { + break + } + return nil, e + } - select { - case <-p.Disco(): - return nil, net.ErrClosed - default: + sends = append(sends, func() (<-chan struct{}, error) { + be.PutUint16(b[1:3], uint16(sn)) + be.PutUint16(b[3:5], i) + return send() + }) } - if !pkt.Unrel { - return p.sendRel(pkt) - } + ch := &c.chans[pkt.Channel] - data := make([]byte, MtHdrSize+len(pkt.Data)) - be.PutUint32(data[0:4], protoID) - be.PutUint16(data[4:6], uint16(p.idOfPeer)) - data[6] = pkt.ChNo - copy(data[MtHdrSize:], pkt.Data) + ch.outSplitMu.Lock() + sn = ch.outSplitSN + ch.outSplitSN++ + ch.outSplitMu.Unlock() - if len(data) > MaxNetPktSize { - return nil, ErrPktTooBig - } + var wg sync.WaitGroup - if p.conn != nil { - _, err = p.conn.Write(data) - } else { - _, err = p.pc.WriteTo(data, p.Addr()) - } - if err != nil { - return nil, err + for _, send := range sends { + ack, err := send() + if err != nil { + return nil, err + } + if !pkt.Unrel { + wg.Add(1) + go func() { + <-ack + wg.Done() + }() + } } - p.ping.Reset(PingTimeout) + if !pkt.Unrel { + ack := make(chan struct{}) + go func() { + wg.Wait() + close(ack) + }() + return ack, nil + } return nil, nil } -// sendRel sends a reliable raw packet to the Peer. -func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) { - if pkt.Unrel { - panic("pkt.Unrel is true") - } - - c := &p.chans[pkt.ChNo] +func (c *Conn) sendRaw(read func([]byte) int, pi PktInfo) func() (<-chan struct{}, error) { + if pi.Unrel { + buf := make([]byte, maxUDPPktSize) + be.PutUint32(buf[0:4], protoID) + c.mu.RLock() + be.PutUint16(buf[4:6], uint16(c.remoteID)) + c.mu.RUnlock() + buf[6] = uint8(pi.Channel) + buf = buf[:7+read(buf[7:])] + + return func() (<-chan struct{}, error) { + if _, err := c.udpConn.Write(buf); err != nil { + c.close(err) + return nil, net.ErrClosed + } - c.outRelMu.Lock() - defer c.outRelMu.Unlock() + c.ping.Reset(PingTimeout) + if atomic.LoadUint32(&c.closing) == 1 { + c.ping.Stop() + } - sn := c.outRelSN - for ; sn-c.outRelWin >= 0x8000; c.outRelWin++ { - if ack, ok := c.ackChans.Load(c.outRelWin); ok { - <-ack.(chan struct{}) + return nil, nil } } - c.outRelSN++ - - rwack := make(chan struct{}) // close-only - c.ackChans.Store(sn, rwack) - ack = rwack - - data := make([]byte, RelHdrSize+len(pkt.Data)) - data[0] = uint8(rawTypeRel) - be.PutUint16(data[1:3], uint16(sn)) - copy(data[RelHdrSize:], pkt.Data) - rel := rawPkt{ - Data: data, - ChNo: pkt.ChNo, - Unrel: true, - } - - if _, err := p.sendRaw(rel); err != nil { - c.ackChans.Delete(sn) - return nil, err - } - - go func() { - for { - select { - case <-time.After(500 * time.Millisecond): - if _, err := p.sendRaw(rel); err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err) + pi.Unrel = true + var snBuf []byte + send := c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawRel) + snBuf = buf[1:3] + return 3 + read(buf[3:]) + }, pi) + + return func() (<-chan struct{}, error) { + ch := &c.chans[pi.Channel] + + ch.outRelMu.Lock() + defer ch.outRelMu.Unlock() + + sn := ch.outRelSN + be.PutUint16(snBuf, uint16(sn)) + for ; sn-ch.outRelWin >= 0x8000; ch.outRelWin++ { + if ack, ok := ch.ackChans.Load(ch.outRelWin); ok { + select { + case <-ack.(chan struct{}): + case <-c.Closed(): } - case <-ack: - return - case <-p.Disco(): - return } } - }() - - return ack, nil -} - -// SendDisco sends a disconnect packet to the Peer but does not close it. -// It returns a channel that's closed when it's acked or an error. -// The ack channel is nil if unrel is true. -func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) { - return p.sendRaw(rawPkt{ - Data: []byte{uint8(rawTypeCtl), uint8(ctlDisco)}, - ChNo: chno, - Unrel: unrel, - }) -} -func split(data []byte, chunksize int) [][]byte { - chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize) + ack := make(chan struct{}) + ch.ackChans.Store(sn, ack) - for i := 0; i < len(data); i += chunksize { - end := i + chunksize - if end > len(data) { - end = len(data) + if _, err := send(); err != nil { + if ack, ok := ch.ackChans.LoadAndDelete(sn); ok { + close(ack.(chan struct{})) + } + return nil, err } + ch.outRelSN++ + + go func() { + t := time.NewTimer(500 * time.Millisecond) + defer t.Stop() + + for { + select { + case <-ack: + return + case <-t.C: + send() + t.Reset(500 * time.Millisecond) + case <-c.Closed(): + return + } + } + }() - chunks = append(chunks, data[i:end]) + return ack, nil } - - return chunks } |