summaryrefslogtreecommitdiff
path: root/rudp/send.go
diff options
context:
space:
mode:
authoranon5 <anon5clam@protonmail.com>2021-03-22 18:37:36 +0000
committeranon5 <anon5clam@protonmail.com>2021-03-22 18:37:36 +0000
commit433955e45ef35476ab6863c4aff06b6ddd0bec67 (patch)
tree354f9f143c4ed56c79519519cf9d7af09802fd9a /rudp/send.go
parent1bae1f4f7ea371086d14b641d6a45aaa8feccb75 (diff)
rudp: partial rewrite with new API supporting io.Readers
Diffstat (limited to 'rudp/send.go')
-rw-r--r--rudp/send.go358
1 files changed, 169 insertions, 189 deletions
diff --git a/rudp/send.go b/rudp/send.go
index c43d056..5522ee0 100644
--- a/rudp/send.go
+++ b/rudp/send.go
@@ -1,241 +1,221 @@
package rudp
import (
+ "bytes"
"errors"
"fmt"
- "math"
+ "io"
"net"
"sync"
+ "sync/atomic"
"time"
)
-const (
- // protoID + src PeerID + channel number
- MtHdrSize = 4 + 2 + 1
-
- // rawTypeOrig
- OrigHdrSize = 1
-
- // rawTypeSpilt + seqnum + chunk count + chunk number
- SplitHdrSize = 1 + 2 + 2 + 2
-
- // rawTypeRel + seqnum
- RelHdrSize = 1 + 2
-)
-
-const (
- MaxNetPktSize = 512
-
- MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize
- MaxRelRawPktSize = MaxUnrelRawPktSize - RelHdrSize
-
- MaxRelPktSize = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16
- MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16
-)
-
var ErrPktTooBig = errors.New("can't send pkt: too big")
-var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount")
-
-// Send sends a packet to the Peer.
-// It returns a channel that's closed when all chunks are acked or an error.
-// The ack channel is nil if pkt.Unrel is true.
-func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
- if pkt.ChNo >= ChannelCount {
- return nil, ErrChNoTooBig
- }
-
- hdrSize := MtHdrSize
- if !pkt.Unrel {
- hdrSize += RelHdrSize
- }
-
- if hdrSize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize {
- c := &p.chans[pkt.ChNo]
-
- c.outSplitMu.Lock()
- sn := c.outSplitSN
- c.outSplitSN++
- c.outSplitMu.Unlock()
- chunks := split(pkt.Data, MaxNetPktSize-(hdrSize+SplitHdrSize))
+// A TooBigChError reports a Channel greater than or equal to ChannelCount.
+type TooBigChError Channel
- if len(chunks) > math.MaxUint16 {
- return nil, ErrPktTooBig
- }
+func (e TooBigChError) Error() string {
+ return fmt.Sprintf("channel >= ChannelCount (%d): %d", ChannelCount, e)
+}
- var wg sync.WaitGroup
+// Send sends a Pkt to the Conn.
+// Ack is closed when the packet is acknowledged.
+// Ack is nil if pkt.Unrel is true or err != nil.
+func (c *Conn) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+ if pkt.Channel >= ChannelCount {
+ return nil, TooBigChError(pkt.Channel)
+ }
- for i, chunk := range chunks {
- data := make([]byte, SplitHdrSize+len(chunk))
- data[0] = uint8(rawTypeSplit)
- be.PutUint16(data[1:3], uint16(sn))
- be.PutUint16(data[3:5], uint16(len(chunks)))
- be.PutUint16(data[5:7], uint16(i))
- copy(data[SplitHdrSize:], chunk)
+ var e error
+ send := c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawOrig)
- wg.Add(1)
- ack, err := p.sendRaw(rawPkt{
- Data: data,
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- })
+ nn := 1
+ for nn < len(buf) {
+ n, err := pkt.Read(buf[nn:])
+ nn += n
if err != nil {
- return nil, err
- }
- if !pkt.Unrel {
- go func() {
- <-ack
- wg.Done()
- }()
+ e = err
+ return nn
}
}
- if pkt.Unrel {
- return nil, nil
- } else {
- ack := make(chan struct{})
-
- go func() {
- wg.Wait()
- close(ack)
- }()
+ if _, e = pkt.Read(nil); e != nil {
+ return nn
+ }
- return ack, nil
+ pkt.Reader = io.MultiReader(
+ bytes.NewReader([]byte(buf[1:nn])),
+ pkt.Reader,
+ )
+ return nn
+ }, pkt.PktInfo)
+ if e != nil {
+ if e == io.EOF {
+ return send()
}
+ return nil, e
}
- return p.sendRaw(rawPkt{
- Data: append([]byte{uint8(rawTypeOrig)}, pkt.Data...),
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- })
-}
+ var (
+ sn seqnum
+ i uint16
+
+ sends []func() (<-chan struct{}, error)
+ )
+
+ for {
+ var (
+ b []byte
+ e error
+ )
+ send := c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawSplit)
+
+ n, err := io.ReadFull(pkt, buf[7:])
+ if err != nil && err != io.ErrUnexpectedEOF {
+ e = err
+ return 0
+ }
-// sendRaw sends a raw packet to the Peer.
-func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) {
- if pkt.ChNo >= ChannelCount {
- return nil, ErrChNoTooBig
- }
+ be.PutUint16(buf[5:7], i)
+ if i++; i == 0 {
+ e = ErrPktTooBig
+ return 0
+ }
- p.mu.RLock()
- defer p.mu.RUnlock()
+ b = buf
+ return 7 + n
+ }, pkt.PktInfo)
+ if e != nil {
+ if e == io.EOF {
+ break
+ }
+ return nil, e
+ }
- select {
- case <-p.Disco():
- return nil, net.ErrClosed
- default:
+ sends = append(sends, func() (<-chan struct{}, error) {
+ be.PutUint16(b[1:3], uint16(sn))
+ be.PutUint16(b[3:5], i)
+ return send()
+ })
}
- if !pkt.Unrel {
- return p.sendRel(pkt)
- }
+ ch := &c.chans[pkt.Channel]
- data := make([]byte, MtHdrSize+len(pkt.Data))
- be.PutUint32(data[0:4], protoID)
- be.PutUint16(data[4:6], uint16(p.idOfPeer))
- data[6] = pkt.ChNo
- copy(data[MtHdrSize:], pkt.Data)
+ ch.outSplitMu.Lock()
+ sn = ch.outSplitSN
+ ch.outSplitSN++
+ ch.outSplitMu.Unlock()
- if len(data) > MaxNetPktSize {
- return nil, ErrPktTooBig
- }
+ var wg sync.WaitGroup
- if p.conn != nil {
- _, err = p.conn.Write(data)
- } else {
- _, err = p.pc.WriteTo(data, p.Addr())
- }
- if err != nil {
- return nil, err
+ for _, send := range sends {
+ ack, err := send()
+ if err != nil {
+ return nil, err
+ }
+ if !pkt.Unrel {
+ wg.Add(1)
+ go func() {
+ <-ack
+ wg.Done()
+ }()
+ }
}
- p.ping.Reset(PingTimeout)
+ if !pkt.Unrel {
+ ack := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(ack)
+ }()
+ return ack, nil
+ }
return nil, nil
}
-// sendRel sends a reliable raw packet to the Peer.
-func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) {
- if pkt.Unrel {
- panic("pkt.Unrel is true")
- }
-
- c := &p.chans[pkt.ChNo]
+func (c *Conn) sendRaw(read func([]byte) int, pi PktInfo) func() (<-chan struct{}, error) {
+ if pi.Unrel {
+ buf := make([]byte, maxUDPPktSize)
+ be.PutUint32(buf[0:4], protoID)
+ c.mu.RLock()
+ be.PutUint16(buf[4:6], uint16(c.remoteID))
+ c.mu.RUnlock()
+ buf[6] = uint8(pi.Channel)
+ buf = buf[:7+read(buf[7:])]
+
+ return func() (<-chan struct{}, error) {
+ if _, err := c.udpConn.Write(buf); err != nil {
+ c.close(err)
+ return nil, net.ErrClosed
+ }
- c.outRelMu.Lock()
- defer c.outRelMu.Unlock()
+ c.ping.Reset(PingTimeout)
+ if atomic.LoadUint32(&c.closing) == 1 {
+ c.ping.Stop()
+ }
- sn := c.outRelSN
- for ; sn-c.outRelWin >= 0x8000; c.outRelWin++ {
- if ack, ok := c.ackChans.Load(c.outRelWin); ok {
- <-ack.(chan struct{})
+ return nil, nil
}
}
- c.outRelSN++
-
- rwack := make(chan struct{}) // close-only
- c.ackChans.Store(sn, rwack)
- ack = rwack
-
- data := make([]byte, RelHdrSize+len(pkt.Data))
- data[0] = uint8(rawTypeRel)
- be.PutUint16(data[1:3], uint16(sn))
- copy(data[RelHdrSize:], pkt.Data)
- rel := rawPkt{
- Data: data,
- ChNo: pkt.ChNo,
- Unrel: true,
- }
-
- if _, err := p.sendRaw(rel); err != nil {
- c.ackChans.Delete(sn)
- return nil, err
- }
-
- go func() {
- for {
- select {
- case <-time.After(500 * time.Millisecond):
- if _, err := p.sendRaw(rel); err != nil {
- if errors.Is(err, net.ErrClosed) {
- return
- }
- p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err)
+ pi.Unrel = true
+ var snBuf []byte
+ send := c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawRel)
+ snBuf = buf[1:3]
+ return 3 + read(buf[3:])
+ }, pi)
+
+ return func() (<-chan struct{}, error) {
+ ch := &c.chans[pi.Channel]
+
+ ch.outRelMu.Lock()
+ defer ch.outRelMu.Unlock()
+
+ sn := ch.outRelSN
+ be.PutUint16(snBuf, uint16(sn))
+ for ; sn-ch.outRelWin >= 0x8000; ch.outRelWin++ {
+ if ack, ok := ch.ackChans.Load(ch.outRelWin); ok {
+ select {
+ case <-ack.(chan struct{}):
+ case <-c.Closed():
}
- case <-ack:
- return
- case <-p.Disco():
- return
}
}
- }()
-
- return ack, nil
-}
-
-// SendDisco sends a disconnect packet to the Peer but does not close it.
-// It returns a channel that's closed when it's acked or an error.
-// The ack channel is nil if unrel is true.
-func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) {
- return p.sendRaw(rawPkt{
- Data: []byte{uint8(rawTypeCtl), uint8(ctlDisco)},
- ChNo: chno,
- Unrel: unrel,
- })
-}
-func split(data []byte, chunksize int) [][]byte {
- chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize)
+ ack := make(chan struct{})
+ ch.ackChans.Store(sn, ack)
- for i := 0; i < len(data); i += chunksize {
- end := i + chunksize
- if end > len(data) {
- end = len(data)
+ if _, err := send(); err != nil {
+ if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
+ close(ack.(chan struct{}))
+ }
+ return nil, err
}
+ ch.outRelSN++
+
+ go func() {
+ t := time.NewTimer(500 * time.Millisecond)
+ defer t.Stop()
+
+ for {
+ select {
+ case <-ack:
+ return
+ case <-t.C:
+ send()
+ t.Reset(500 * time.Millisecond)
+ case <-c.Closed():
+ return
+ }
+ }
+ }()
- chunks = append(chunks, data[i:end])
+ return ack, nil
}
-
- return chunks
}