summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rudp/conn.go173
-rw-r--r--rudp/connect.go18
-rw-r--r--rudp/listen.go270
-rw-r--r--rudp/net.go41
-rw-r--r--rudp/peer.go207
-rw-r--r--rudp/process.go253
-rw-r--r--rudp/proxy/proxy.go39
-rw-r--r--rudp/recv.go259
-rw-r--r--rudp/rudp.go129
-rw-r--r--rudp/send.go358
-rw-r--r--rudp/udp.go13
11 files changed, 878 insertions, 882 deletions
diff --git a/rudp/conn.go b/rudp/conn.go
new file mode 100644
index 0000000..7e241a8
--- /dev/null
+++ b/rudp/conn.go
@@ -0,0 +1,173 @@
+package rudp
+
+import (
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// A Conn is a connection to a client or server.
+// All Conn's methods are safe for concurrent use.
+type Conn struct {
+ udpConn udpConn
+
+ id PeerID
+
+ pkts chan Pkt
+ errs chan error
+
+ timeout *time.Timer
+ ping *time.Ticker
+
+ closing uint32
+ closed chan struct{}
+ err error
+
+ mu sync.RWMutex
+ remoteID PeerID
+
+ chans [ChannelCount]pktChan // read/write
+}
+
+// ID returns the PeerID of the Conn.
+func (c *Conn) ID() PeerID { return c.id }
+
+// IsSrv reports whether the Conn is a connection to a server.
+func (c *Conn) IsSrv() bool { return c.ID() == PeerIDSrv }
+
+// Closed returns a channel which is closed when the Conn is closed.
+func (c *Conn) Closed() <-chan struct{} { return c.closed }
+
+// WhyClosed returns the error that caused the Conn to be closed or nil
+// if the Conn was closed using the Close method or by the peer.
+// WhyClosed returns nil if the Conn is not closed.
+func (c *Conn) WhyClosed() error {
+ select {
+ case <-c.Closed():
+ return c.err
+ default:
+ return nil
+ }
+}
+
+// LocalAddr returns the local network address.
+func (c *Conn) LocalAddr() net.Addr { return c.udpConn.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *Conn) RemoteAddr() net.Addr { return c.udpConn.RemoteAddr() }
+
+type pktChan struct {
+ // Only accessed by Conn.recvUDPPkts goroutine.
+ inRels *[0x8000][]byte
+ inRelSN seqnum
+ sendAck func() (<-chan struct{}, error)
+ ackBuf []byte
+
+ inSplitsMu sync.RWMutex
+ inSplits map[seqnum]*inSplit
+
+ ackChans sync.Map // map[seqnum]chan struct{}
+
+ outSplitMu sync.Mutex
+ outSplitSN seqnum
+
+ outRelMu sync.Mutex
+ outRelSN seqnum
+ outRelWin seqnum
+}
+
+type inSplit struct {
+ chunks [][]byte
+ got int
+ timeout *time.Timer
+}
+
+// Close closes the Conn.
+// Any blocked Send or Recv calls will return net.ErrClosed.
+func (c *Conn) Close() error {
+ return c.closeDisco(nil)
+}
+
+func (c *Conn) closeDisco(err error) error {
+ c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawCtl)
+ buf[1] = uint8(ctlDisco)
+ return 2
+ }, PktInfo{Unrel: true})()
+
+ return c.close(err)
+}
+
+func (c *Conn) close(err error) error {
+ if atomic.SwapUint32(&c.closing, 1) == 1 {
+ return net.ErrClosed
+ }
+
+ c.timeout.Stop()
+ c.ping.Stop()
+
+ c.err = err
+ defer close(c.closed)
+
+ return c.udpConn.Close()
+}
+
+func newConn(uc udpConn, id, remoteID PeerID) *Conn {
+ var c *Conn
+ c = &Conn{
+ udpConn: uc,
+
+ id: id,
+
+ pkts: make(chan Pkt),
+ errs: make(chan error),
+
+ timeout: time.AfterFunc(ConnTimeout, func() {
+ c.closeDisco(ErrTimedOut)
+ }),
+ ping: time.NewTicker(PingTimeout),
+
+ closed: make(chan struct{}),
+
+ remoteID: remoteID,
+ }
+
+ for i := range c.chans {
+ c.chans[i] = pktChan{
+ inRels: new([0x8000][]byte),
+ inRelSN: initSeqnum,
+
+ inSplits: make(map[seqnum]*inSplit),
+
+ outSplitSN: initSeqnum,
+
+ outRelSN: initSeqnum,
+ outRelWin: initSeqnum,
+ }
+ }
+
+ c.newAckBuf()
+
+ go c.sendPings(c.ping.C)
+ go c.recvUDPPkts()
+
+ return c
+}
+
+func (c *Conn) sendPings(ping <-chan time.Time) {
+ send := c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawCtl)
+ buf[1] = uint8(ctlPing)
+ return 2
+ }, PktInfo{})
+
+ for {
+ select {
+ case <-ping:
+ send()
+ case <-c.Closed():
+ return
+ }
+ }
+}
diff --git a/rudp/connect.go b/rudp/connect.go
new file mode 100644
index 0000000..548ab15
--- /dev/null
+++ b/rudp/connect.go
@@ -0,0 +1,18 @@
+package rudp
+
+import "net"
+
+type udpSrv struct {
+ net.Conn
+}
+
+func (us udpSrv) recvUDP() ([]byte, error) {
+ buf := make([]byte, maxUDPPktSize)
+ n, err := us.Read(buf)
+ return buf[:n], err
+}
+
+// Connect returns a Conn connected to conn.
+func Connect(conn net.Conn) *Conn {
+ return newConn(udpSrv{conn}, PeerIDSrv, PeerIDNil)
+}
diff --git a/rudp/listen.go b/rudp/listen.go
index 5b7154a..e1cacf4 100644
--- a/rudp/listen.go
+++ b/rudp/listen.go
@@ -2,155 +2,213 @@ package rudp
import (
"errors"
- "fmt"
"net"
"sync"
)
+func tryClose(ch chan struct{}) (ok bool) {
+ defer func() { recover() }()
+ close(ch)
+ return true
+}
+
+type udpClt struct {
+ l *Listener
+ id PeerID
+ addr net.Addr
+ pkts chan []byte
+ closed chan struct{}
+}
+
+func (c *udpClt) mkConn() {
+ conn := newConn(c, c.id, PeerIDSrv)
+ go func() {
+ <-conn.Closed()
+ c.l.wg.Done()
+ }()
+ conn.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawCtl)
+ buf[1] = uint8(ctlSetPeerID)
+ be.PutUint16(buf[2:4], uint16(conn.ID()))
+ return 4
+ }, PktInfo{})()
+ select {
+ case c.l.conns <- conn:
+ case <-c.l.closed:
+ conn.Close()
+ }
+}
+
+func (c *udpClt) Write(pkt []byte) (int, error) {
+ select {
+ case <-c.closed:
+ return 0, net.ErrClosed
+ default:
+ }
+
+ return c.l.pc.WriteTo(pkt, c.addr)
+}
+
+func (c *udpClt) recvUDP() ([]byte, error) {
+ select {
+ case pkt := <-c.pkts:
+ return pkt, nil
+ case <-c.closed:
+ return nil, net.ErrClosed
+ }
+}
+
+func (c *udpClt) Close() error {
+ if !tryClose(c.closed) {
+ return net.ErrClosed
+ }
+
+ c.l.mu.Lock()
+ defer c.l.mu.Unlock()
+
+ delete(c.l.ids, c.id)
+ delete(c.l.clts, c.addr.String())
+
+ return nil
+}
+
+func (c *udpClt) LocalAddr() net.Addr { return c.l.pc.LocalAddr() }
+func (c *udpClt) RemoteAddr() net.Addr { return c.addr }
+
+// All Listener's methods are safe for concurrent use.
type Listener struct {
- conn net.PacketConn
+ pc net.PacketConn
- clts chan cltPeer
- errs chan error
+ peerID PeerID
+ conns chan *Conn
+ errs chan error
+ closed chan struct{}
+ wg sync.WaitGroup
- mu sync.Mutex
- addr2peer map[string]cltPeer
- id2peer map[PeerID]cltPeer
- peerID PeerID
+ mu sync.RWMutex
+ ids map[PeerID]bool
+ clts map[string]*udpClt
}
-// Listen listens for packets on conn until it is closed.
-func Listen(conn net.PacketConn) *Listener {
+// Listen listens for connections on pc, pc is closed once the returned Listener
+// and all Conns connected through it are closed.
+func Listen(pc net.PacketConn) *Listener {
l := &Listener{
- conn: conn,
+ pc: pc,
- clts: make(chan cltPeer),
- errs: make(chan error),
+ conns: make(chan *Conn),
+ closed: make(chan struct{}),
- addr2peer: make(map[string]cltPeer),
- id2peer: make(map[PeerID]cltPeer),
+ ids: make(map[PeerID]bool),
+ clts: make(map[string]*udpClt),
}
- pkts := make(chan netPkt)
- go readNetPkts(l.conn, pkts, l.errs)
go func() {
- for pkt := range pkts {
- if err := l.processNetPkt(pkt); err != nil {
- l.errs <- err
+ for {
+ if err := l.processNetPkt(); err != nil {
+ if errors.Is(err, net.ErrClosed) {
+ break
+ }
+ select {
+ case l.errs <- err:
+ case <-l.closed:
+ }
}
}
-
- close(l.clts)
-
- for _, clt := range l.addr2peer {
- clt.Close()
- }
}()
return l
}
-// Accept waits for and returns a connecting Peer.
-// You should keep calling this until it returns net.ErrClosed
-// so it doesn't leak a goroutine.
-func (l *Listener) Accept() (*Peer, error) {
+// Accept waits for and returns the next incoming Conn or an error.
+func (l *Listener) Accept() (*Conn, error) {
select {
- case clt, ok := <-l.clts:
- if !ok {
- select {
- case err := <-l.errs:
- return nil, err
- default:
- return nil, net.ErrClosed
- }
- }
- close(clt.accepted)
- return clt.Peer, nil
+ case c := <-l.conns:
+ return c, nil
case err := <-l.errs:
return nil, err
+ case <-l.closed:
+ return nil, net.ErrClosed
}
}
-// Addr returns the net.PacketConn the Listener is listening on.
-func (l *Listener) Conn() net.PacketConn { return l.conn }
+// Close makes the Listener stop listening for new Conns.
+// Blocked Accept calls will return net.ErrClosed.
+// Already Accepted Conns are not closed.
+func (l *Listener) Close() error {
+ if !tryClose(l.closed) {
+ return net.ErrClosed
+ }
-var ErrOutOfPeerIDs = errors.New("out of peer ids")
+ go func() {
+ l.wg.Wait()
+ l.pc.Close()
+ }()
-type cltPeer struct {
- *Peer
- pkts chan<- netPkt
- accepted chan struct{} // close-only
+ return nil
}
-func (l *Listener) processNetPkt(pkt netPkt) error {
- l.mu.Lock()
- defer l.mu.Unlock()
-
- addrstr := pkt.SrcAddr.String()
+// Addr returns the Listener's network address.
+func (l *Listener) Addr() net.Addr { return l.pc.LocalAddr() }
- clt, ok := l.addr2peer[addrstr]
- if !ok {
- prev := l.peerID
- for l.id2peer[l.peerID].Peer != nil || l.peerID < PeerIDCltMin {
- if l.peerID == prev-1 {
- return ErrOutOfPeerIDs
- }
- l.peerID++
- }
+var ErrOutOfPeerIDs = errors.New("out of peer ids")
- pkts := make(chan netPkt, 256)
+func (l *Listener) processNetPkt() error {
+ buf := make([]byte, maxUDPPktSize)
+ n, addr, err := l.pc.ReadFrom(buf)
+ if err != nil {
+ return err
+ }
- clt = cltPeer{
- Peer: newPeer(l.conn, pkt.SrcAddr, l.peerID, PeerIDSrv),
- pkts: pkts,
- accepted: make(chan struct{}),
+ l.mu.RLock()
+ clt, ok := l.clts[addr.String()]
+ l.mu.RUnlock()
+ if !ok {
+ select {
+ case <-l.closed:
+ return nil
+ default:
}
- l.addr2peer[addrstr] = clt
- l.id2peer[clt.ID()] = clt
-
- data := make([]byte, 1+1+2)
- data[0] = uint8(rawTypeCtl)
- data[1] = uint8(ctlSetPeerID)
- be.PutUint16(data[2:4], uint16(clt.ID()))
- if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
- if errors.Is(err, net.ErrClosed) {
- return nil
- }
- return fmt.Errorf("can't set client peer id: %w", err)
+ clt, err = l.add(addr)
+ if err != nil {
+ return err
}
+ }
- go func() {
- select {
- case l.clts <- clt:
- case <-clt.Disco():
- }
-
- clt.processNetPkts(pkts)
- }()
+ select {
+ case clt.pkts <- buf[:n]:
+ case <-clt.closed:
+ }
- go func() {
- <-clt.Disco()
+ return nil
+}
- l.mu.Lock()
- close(pkts)
- delete(l.addr2peer, addrstr)
- delete(l.id2peer, clt.ID())
- l.mu.Unlock()
- }()
- }
+func (l *Listener) add(addr net.Addr) (*udpClt, error) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
- select {
- case <-clt.accepted:
- clt.pkts <- pkt
- default:
- select {
- case clt.pkts <- pkt:
- default:
- // It's OK to drop packets if the buffer is full
- // because MT RUDP can cope with packet loss.
+ start := l.peerID
+ l.peerID++
+ for l.peerID < PeerIDCltMin || l.ids[l.peerID] {
+ if l.peerID == start {
+ return nil, ErrOutOfPeerIDs
}
+ l.peerID++
+ }
+ l.ids[l.peerID] = true
+
+ clt := &udpClt{
+ l: l,
+ id: l.peerID,
+ addr: addr,
+ pkts: make(chan []byte),
+ closed: make(chan struct{}),
}
+ l.clts[addr.String()] = clt
- return nil
+ l.wg.Add(1)
+ go clt.mkConn()
+
+ return clt, nil
}
diff --git a/rudp/net.go b/rudp/net.go
deleted file mode 100644
index e2e7289..0000000
--- a/rudp/net.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package rudp
-
-import (
- "errors"
- "net"
-)
-
-// ErrClosed is deprecated, use net.ErrClosed instead.
-var ErrClosed = net.ErrClosed
-
-/*
-netPkt.Data format (big endian):
-
- ProtoID
- Src PeerID
- ChNo uint8 // Must be < ChannelCount.
- RawPkt.Data
-*/
-type netPkt struct {
- SrcAddr net.Addr
- Data []byte
-}
-
-func readNetPkts(conn net.PacketConn, pkts chan<- netPkt, errs chan<- error) {
- for {
- buf := make([]byte, MaxNetPktSize)
- n, addr, err := conn.ReadFrom(buf)
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- break
- }
-
- errs <- err
- continue
- }
-
- pkts <- netPkt{addr, buf[:n]}
- }
-
- close(pkts)
-}
diff --git a/rudp/peer.go b/rudp/peer.go
deleted file mode 100644
index 791249c..0000000
--- a/rudp/peer.go
+++ /dev/null
@@ -1,207 +0,0 @@
-package rudp
-
-import (
- "errors"
- "fmt"
- "net"
- "sync"
- "time"
-)
-
-const (
- // ConnTimeout is the amount of time after no packets being received
- // from a Peer that it is automatically disconnected.
- ConnTimeout = 30 * time.Second
-
- // ConnTimeout is the amount of time after no packets being sent
- // to a Peer that a CtlPing is automatically sent to prevent timeout.
- PingTimeout = 5 * time.Second
-)
-
-// A Peer is a connection to a client or server.
-type Peer struct {
- pc net.PacketConn
- addr net.Addr
- conn net.Conn
-
- disco chan struct{} // close-only
-
- id PeerID
-
- pkts chan Pkt
- errs chan error // don't close
- timedOut chan struct{} // close-only
-
- chans [ChannelCount]pktchan // read/write
-
- mu sync.RWMutex
- idOfPeer PeerID
- timeout *time.Timer
- ping *time.Ticker
-}
-
-// Conn returns the net.PacketConn used to communicate with the Peer.
-func (p *Peer) Conn() net.PacketConn { return p.pc }
-
-// Addr returns the address of the Peer.
-func (p *Peer) Addr() net.Addr { return p.addr }
-
-// Disco returns a channel that is closed when the Peer is closed.
-func (p *Peer) Disco() <-chan struct{} { return p.disco }
-
-// ID returns the ID of the Peer.
-func (p *Peer) ID() PeerID { return p.id }
-
-// IsSrv reports whether the Peer is a server.
-func (p *Peer) IsSrv() bool {
- return p.ID() == PeerIDSrv
-}
-
-// TimedOut reports whether the Peer has timed out.
-func (p *Peer) TimedOut() bool {
- select {
- case <-p.timedOut:
- return true
- default:
- return false
- }
-}
-
-type inSplit struct {
- chunks [][]byte
- size, got int
-}
-
-type pktchan struct {
- // Only accessed by Peer.processRawPkt.
- inSplit *[65536]*inSplit
- inRel *[65536][]byte
- inRelSN seqnum
-
- ackChans sync.Map // map[seqnum]chan struct{}
-
- outSplitMu sync.Mutex
- outSplitSN seqnum
-
- outRelMu sync.Mutex
- outRelSN seqnum
- outRelWin seqnum
-}
-
-// Recv recieves a packet from the Peer.
-// You should keep calling this until it returns net.ErrClosed
-// so it doesn't leak a goroutine.
-func (p *Peer) Recv() (Pkt, error) {
- select {
- case pkt, ok := <-p.pkts:
- if !ok {
- select {
- case err := <-p.errs:
- return Pkt{}, err
- default:
- return Pkt{}, net.ErrClosed
- }
- }
- return pkt, nil
- case err := <-p.errs:
- return Pkt{}, err
- }
-}
-
-// Close closes the Peer but does not send a disconnect packet.
-func (p *Peer) Close() error {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- select {
- case <-p.Disco():
- return net.ErrClosed
- default:
- }
-
- p.timeout.Stop()
- p.timeout = nil
- p.ping.Stop()
- p.ping = nil
-
- close(p.disco)
-
- return nil
-}
-
-func newPeer(pc net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer {
- p := &Peer{
- pc: pc,
- addr: addr,
- id: id,
- idOfPeer: idOfPeer,
-
- pkts: make(chan Pkt),
- disco: make(chan struct{}),
- errs: make(chan error),
- }
-
- if conn, ok := pc.(net.Conn); ok && conn.RemoteAddr() != nil {
- p.conn = conn
- }
-
- for i := range p.chans {
- p.chans[i] = pktchan{
- inSplit: new([65536]*inSplit),
- inRel: new([65536][]byte),
- inRelSN: seqnumInit,
-
- outSplitSN: seqnumInit,
- outRelSN: seqnumInit,
- outRelWin: seqnumInit,
- }
- }
-
- p.timedOut = make(chan struct{})
- p.timeout = time.AfterFunc(ConnTimeout, func() {
- close(p.timedOut)
-
- p.SendDisco(0, true)
- p.Close()
- })
-
- p.ping = time.NewTicker(PingTimeout)
- go p.sendPings(p.ping.C)
-
- return p
-}
-
-func (p *Peer) sendPings(ping <-chan time.Time) {
- pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}}
-
- for {
- select {
- case <-ping:
- if _, err := p.sendRaw(pkt); err != nil {
- if errors.Is(err, net.ErrClosed) {
- return
- }
- p.errs <- fmt.Errorf("can't send ping: %w", err)
- }
- case <-p.Disco():
- return
- }
- }
-}
-
-// Connect connects to addr using pc
-// and closes pc when the returned *Peer disconnects.
-func Connect(pc net.PacketConn, addr net.Addr) *Peer {
- srv := newPeer(pc, addr, PeerIDSrv, PeerIDNil)
-
- pkts := make(chan netPkt)
- go readNetPkts(pc, pkts, srv.errs)
- go srv.processNetPkts(pkts)
-
- go func() {
- <-srv.Disco()
- pc.Close()
- }()
-
- return srv
-}
diff --git a/rudp/process.go b/rudp/process.go
deleted file mode 100644
index 7238fe5..0000000
--- a/rudp/process.go
+++ /dev/null
@@ -1,253 +0,0 @@
-package rudp
-
-import (
- "errors"
- "fmt"
- "io"
- "net"
-)
-
-// A PktError is an error that occured while processing a packet.
-type PktError struct {
- Type string // "net", "raw" or "rel".
- Data []byte
- Err error
-}
-
-func (e PktError) Error() string {
- return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err)
-}
-
-func (e PktError) Unwrap() error { return e.Err }
-
-func (p *Peer) processNetPkts(pkts <-chan netPkt) {
- for pkt := range pkts {
- if err := p.processNetPkt(pkt); err != nil {
- p.errs <- PktError{"net", pkt.Data, err}
- }
- }
-
- close(p.pkts)
-}
-
-// A TrailingDataError reports a packet with trailing data,
-// it doesn't stop a packet from being processed.
-type TrailingDataError []byte
-
-func (e TrailingDataError) Error() string {
- return fmt.Sprintf("trailing data: %x", []byte(e))
-}
-
-func (p *Peer) processNetPkt(pkt netPkt) (err error) {
- if pkt.SrcAddr.String() != p.Addr().String() {
- return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
- }
-
- if len(pkt.Data) < MtHdrSize {
- return io.ErrUnexpectedEOF
- }
-
- if id := be.Uint32(pkt.Data[0:4]); id != protoID {
- return fmt.Errorf("unsupported protocol id: 0x%08x", id)
- }
-
- // src PeerID at pkt.Data[4:6]
-
- chno := pkt.Data[6]
- if chno >= ChannelCount {
- return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
- }
-
- p.mu.RLock()
- if p.timeout != nil {
- p.timeout.Reset(ConnTimeout)
- }
- p.mu.RUnlock()
-
- rpkt := rawPkt{
- Data: pkt.Data[MtHdrSize:],
- ChNo: chno,
- Unrel: true,
- }
- if err := p.processRawPkt(rpkt); err != nil {
- p.errs <- PktError{"raw", rpkt.Data, err}
- }
-
- return nil
-}
-
-func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
- errWrap := func(format string, a ...interface{}) {
- if err != nil {
- err = fmt.Errorf(format, append(a, err)...)
- }
- }
-
- c := &p.chans[pkt.ChNo]
-
- if len(pkt.Data) < 1 {
- return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
- }
- switch t := rawType(pkt.Data[0]); t {
- case rawTypeCtl:
- defer errWrap("ctl: %w")
-
- if len(pkt.Data) < 1+1 {
- return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
- }
- switch ct := ctlType(pkt.Data[1]); ct {
- case ctlAck:
- defer errWrap("ack: %w")
-
- if len(pkt.Data) < 1+1+2 {
- return io.ErrUnexpectedEOF
- }
-
- sn := seqnum(be.Uint16(pkt.Data[2:4]))
-
- if ack, ok := c.ackChans.LoadAndDelete(sn); ok {
- close(ack.(chan struct{}))
- }
-
- if len(pkt.Data) > 1+1+2 {
- return TrailingDataError(pkt.Data[1+1+2:])
- }
- case ctlSetPeerID:
- defer errWrap("set peer id: %w")
-
- if len(pkt.Data) < 1+1+2 {
- return io.ErrUnexpectedEOF
- }
-
- // Ensure no concurrent senders while peer id changes.
- p.mu.Lock()
- if p.idOfPeer != PeerIDNil {
- return errors.New("peer id already set")
- }
-
- p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4]))
- p.mu.Unlock()
-
- if len(pkt.Data) > 1+1+2 {
- return TrailingDataError(pkt.Data[1+1+2:])
- }
- case ctlPing:
- defer errWrap("ping: %w")
-
- if len(pkt.Data) > 1+1 {
- return TrailingDataError(pkt.Data[1+1:])
- }
- case ctlDisco:
- defer errWrap("disco: %w")
-
- p.Close()
-
- if len(pkt.Data) > 1+1 {
- return TrailingDataError(pkt.Data[1+1:])
- }
- default:
- return fmt.Errorf("unsupported ctl type: %d", ct)
- }
- case rawTypeOrig:
- p.pkts <- Pkt{
- Data: pkt.Data[1:],
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- }
- case rawTypeSplit:
- defer errWrap("split: %w")
-
- if len(pkt.Data) < 1+2+2+2 {
- return io.ErrUnexpectedEOF
- }
-
- sn := seqnum(be.Uint16(pkt.Data[1:3]))
- count := be.Uint16(pkt.Data[3:5])
- i := be.Uint16(pkt.Data[5:7])
-
- if i >= count {
- return nil
- }
-
- splits := p.chans[pkt.ChNo].inSplit
-
- // Delete old incomplete split packets
- // so new ones don't get corrupted.
- splits[sn-0x8000] = nil
-
- if splits[sn] == nil {
- splits[sn] = &inSplit{chunks: make([][]byte, count)}
- }
-
- s := splits[sn]
-
- if int(count) != len(s.chunks) {
- return fmt.Errorf("chunk count changed on split packet: %d", sn)
- }
-
- s.chunks[i] = pkt.Data[7:]
- s.size += len(s.chunks[i])
- s.got++
-
- if s.got == len(s.chunks) {
- data := make([]byte, 0, s.size)
- for _, chunk := range s.chunks {
- data = append(data, chunk...)
- }
-
- p.pkts <- Pkt{
- Data: data,
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- }
-
- splits[sn] = nil
- }
- case rawTypeRel:
- defer errWrap("rel: %w")
-
- if len(pkt.Data) < 1+2 {
- return io.ErrUnexpectedEOF
- }
-
- sn := seqnum(be.Uint16(pkt.Data[1:3]))
-
- ack := make([]byte, 1+1+2)
- ack[0] = uint8(rawTypeCtl)
- ack[1] = uint8(ctlAck)
- be.PutUint16(ack[2:4], uint16(sn))
- if _, err := p.sendRaw(rawPkt{
- Data: ack,
- ChNo: pkt.ChNo,
- Unrel: true,
- }); err != nil {
- if errors.Is(err, net.ErrClosed) {
- return nil
- }
- return fmt.Errorf("can't ack %d: %w", sn, err)
- }
-
- if sn-c.inRelSN >= 0x8000 {
- return nil // Already received.
- }
-
- c.inRel[sn] = pkt.Data[3:]
-
- for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ {
- rpkt := rawPkt{
- Data: c.inRel[c.inRelSN],
- ChNo: pkt.ChNo,
- Unrel: false,
- }
- c.inRel[c.inRelSN] = nil
-
- if err := p.processRawPkt(rpkt); err != nil {
- p.errs <- PktError{"rel", rpkt.Data, err}
- }
- }
- default:
- return fmt.Errorf("unsupported pkt type: %d", t)
- }
-
- return nil
-}
diff --git a/rudp/proxy/proxy.go b/rudp/proxy/proxy.go
index a80b448..b8ca9f4 100644
--- a/rudp/proxy/proxy.go
+++ b/rudp/proxy/proxy.go
@@ -25,62 +25,55 @@ func main() {
os.Exit(1)
}
- srvaddr, err := net.ResolveUDPAddr("udp", os.Args[1])
+ pc, err := net.ListenPacket("udp", os.Args[2])
if err != nil {
log.Fatal(err)
}
+ defer pc.Close()
- lc, err := net.ListenPacket("udp", os.Args[2])
- if err != nil {
- log.Fatal(err)
- }
- defer lc.Close()
-
- l := rudp.Listen(lc)
+ l := rudp.Listen(pc)
for {
clt, err := l.Accept()
if err != nil {
- log.Print(err)
+ log.Print("accept: ", err)
continue
}
- log.Print(clt.Addr(), " connected")
+ log.Print(clt.ID(), ": connected")
- conn, err := net.DialUDP("udp", nil, srvaddr)
+ conn, err := net.Dial("udp", os.Args[1])
if err != nil {
log.Print(err)
continue
}
- srv := rudp.Connect(conn, conn.RemoteAddr())
+ srv := rudp.Connect(conn)
go proxy(clt, srv)
go proxy(srv, clt)
}
}
-func proxy(src, dest *rudp.Peer) {
+func proxy(src, dest *rudp.Conn) {
+ s := fmt.Sprint(src.ID(), " (", src.RemoteAddr(), "): ")
+
for {
pkt, err := src.Recv()
if err != nil {
if errors.Is(err, net.ErrClosed) {
- msg := src.Addr().String() + " disconnected"
- if src.TimedOut() {
- msg += " (timed out)"
+ if err := src.WhyClosed(); err != nil {
+ log.Print(s, "disconnected: ", err)
+ } else {
+ log.Print(s, "disconnected")
}
- log.Print(msg)
-
break
}
- log.Print(err)
+ log.Print(s, err)
continue
}
- if _, err := dest.Send(pkt); err != nil {
- log.Print(err)
- }
+ dest.Send(pkt)
}
- dest.SendDisco(0, true)
dest.Close()
}
diff --git a/rudp/recv.go b/rudp/recv.go
new file mode 100644
index 0000000..f5ac236
--- /dev/null
+++ b/rudp/recv.go
@@ -0,0 +1,259 @@
+package rudp
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+// Recv receives a Pkt from the Conn.
+func (c *Conn) Recv() (Pkt, error) {
+ select {
+ case pkt := <-c.pkts:
+ return pkt, nil
+ case err := <-c.errs:
+ return Pkt{}, err
+ case <-c.Closed():
+ return Pkt{}, net.ErrClosed
+ }
+}
+
+func (c *Conn) gotPkt(pkt Pkt) {
+ select {
+ case c.pkts <- pkt:
+ case <-c.Closed():
+ }
+}
+
+func (c *Conn) gotErr(kind string, data []byte, err error) {
+ select {
+ case c.errs <- fmt.Errorf("%s: %x: %w", kind, data, err):
+ case <-c.Closed():
+ }
+}
+
+func (c *Conn) recvUDPPkts() {
+ for {
+ pkt, err := c.udpConn.recvUDP()
+ if err != nil {
+ c.closeDisco(err)
+ break
+ }
+
+ if err := c.processUDPPkt(pkt); err != nil {
+ c.gotErr("udp", pkt, err)
+ }
+ }
+}
+
+func (c *Conn) processUDPPkt(pkt []byte) error {
+ if c.timeout.Stop() {
+ c.timeout.Reset(ConnTimeout)
+ }
+
+ if len(pkt) < 6 {
+ return io.ErrUnexpectedEOF
+ }
+
+ if id := be.Uint32(pkt[0:4]); id != protoID {
+ return fmt.Errorf("unsupported protocol id: 0x%08x", id)
+ }
+
+ ch := Channel(pkt[6])
+ if ch >= ChannelCount {
+ return TooBigChError(ch)
+ }
+
+ if err := c.processRawPkt(pkt[7:], PktInfo{Channel: ch, Unrel: true}); err != nil {
+ c.gotErr("raw", pkt, err)
+ }
+
+ return nil
+}
+
+// A TrailingDataError reports trailing data after a packet.
+type TrailingDataError []byte
+
+func (e TrailingDataError) Error() string {
+ return fmt.Sprintf("trailing data: %x", []byte(e))
+}
+
+func (c *Conn) processRawPkt(data []byte, pi PktInfo) (err error) {
+ errWrap := func(format string, a ...interface{}) {
+ if err != nil {
+ err = fmt.Errorf(format+": %w", append(a, err)...)
+ }
+ }
+
+ eof := new(byte)
+ defer func() {
+ switch r := recover(); r {
+ case nil:
+ case eof:
+ err = io.ErrUnexpectedEOF
+ default:
+ panic(r)
+ }
+ }()
+
+ off := 0
+ eat := func(n int) []byte {
+ i := off
+ off += n
+ if i > len(data) {
+ panic(eof)
+ }
+ return data[i:off]
+ }
+
+ ch := &c.chans[pi.Channel]
+
+ switch t := rawType(eat(1)[0]); t {
+ case rawCtl:
+ defer errWrap("ctl")
+
+ switch ct := ctlType(eat(1)[0]); ct {
+ case ctlAck:
+ defer errWrap("ack")
+
+ sn := seqnum(be.Uint16(eat(2)))
+
+ if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
+ close(ack.(chan struct{}))
+ }
+ case ctlSetPeerID:
+ defer errWrap("set peer id")
+
+ c.mu.Lock()
+ if c.remoteID != PeerIDNil {
+ return errors.New("peer id already set")
+ }
+
+ c.remoteID = PeerID(be.Uint16(eat(2)))
+ c.mu.Unlock()
+
+ c.newAckBuf()
+ case ctlPing:
+ defer errWrap("ping")
+ case ctlDisco:
+ defer errWrap("disco")
+
+ c.close(nil)
+ default:
+ return fmt.Errorf("unsupported ctl type: %d", ct)
+ }
+
+ if off < len(data) {
+ return TrailingDataError(data[off:])
+ }
+ case rawOrig:
+ c.gotPkt(Pkt{
+ Reader: bytes.NewReader(data[off:]),
+ PktInfo: pi,
+ })
+ case rawSplit:
+ defer errWrap("split")
+
+ sn := seqnum(be.Uint16(eat(2)))
+ n := be.Uint16(eat(2))
+ i := be.Uint16(eat(2))
+
+ defer errWrap("%d", sn)
+
+ if i >= n {
+ return fmt.Errorf("chunk number (%d) > chunk count (%d)", i, n)
+ }
+
+ ch.inSplitsMu.RLock()
+ s := ch.inSplits[sn]
+ ch.inSplitsMu.RUnlock()
+
+ if s == nil {
+ s = &inSplit{chunks: make([][]byte, n)}
+ if pi.Unrel {
+ s.timeout = time.AfterFunc(ConnTimeout, func() {
+ ch.inSplitsMu.Lock()
+ delete(ch.inSplits, sn)
+ ch.inSplitsMu.Unlock()
+ })
+ }
+
+ ch.inSplitsMu.Lock()
+ ch.inSplits[sn] = s
+ ch.inSplitsMu.Unlock()
+ }
+
+ if int(n) != len(s.chunks) {
+ return fmt.Errorf("chunk count changed from %d to %d", len(s.chunks), n)
+ }
+
+ if s.chunks[i] == nil {
+ s.chunks[i] = data[off:]
+ s.got++
+ }
+
+ if s.got < len(s.chunks) {
+ if s.timeout != nil && s.timeout.Stop() {
+ s.timeout.Reset(ConnTimeout)
+ }
+ return
+ }
+
+ if s.timeout != nil {
+ s.timeout.Stop()
+ }
+
+ ch.inSplitsMu.Lock()
+ delete(ch.inSplits, sn)
+ ch.inSplitsMu.Unlock()
+
+ c.gotPkt(Pkt{
+ Reader: (*net.Buffers)(&s.chunks),
+ PktInfo: pi,
+ })
+ case rawRel:
+ defer errWrap("rel")
+
+ sn := seqnum(be.Uint16(eat(2)))
+
+ defer errWrap("%d", sn)
+
+ be.PutUint16(ch.ackBuf, uint16(sn))
+ ch.sendAck()
+
+ if sn-ch.inRelSN >= 0x8000 {
+ // Already received.
+ return nil
+ }
+
+ ch.inRels[sn&0x7fff] = data[off:]
+
+ i := func() seqnum { return ch.inRelSN & 0x7fff }
+ for ; ch.inRels[i()] != nil; ch.inRelSN++ {
+ data := ch.inRels[i()]
+ ch.inRels[i()] = nil
+ if err := c.processRawPkt(data, PktInfo{Channel: pi.Channel}); err != nil {
+ c.gotErr("rel", data, err)
+ }
+ }
+ default:
+ return fmt.Errorf("unsupported pkt type: %d", t)
+ }
+
+ return nil
+}
+
+func (c *Conn) newAckBuf() {
+ for i := range c.chans {
+ ch := &c.chans[i]
+ ch.sendAck = c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawCtl)
+ buf[1] = uint8(ctlAck)
+ ch.ackBuf = buf[2:4]
+ return 4
+ }, PktInfo{Channel: Channel(i), Unrel: true})
+ }
+}
diff --git a/rudp/rudp.go b/rudp/rudp.go
index 6b96b56..faf67d9 100644
--- a/rudp/rudp.go
+++ b/rudp/rudp.go
@@ -1,21 +1,44 @@
/*
Package rudp implements the low-level Minetest protocol described at
https://dev.minetest.net/Network_Protocol#Low-level_protocol.
-
-All exported functions and methods in this package are safe for concurrent use
-by multiple goroutines.
*/
package rudp
-import "encoding/binary"
+import (
+ "encoding/binary"
+ "errors"
+ "io"
+ "time"
+)
var be = binary.BigEndian
-// protoID must be at the start of every network packet.
+/*
+UDP packet format:
+
+ protoID
+ src PeerID
+ channel uint8
+ rawType...
+*/
+
+var ErrTimedOut = errors.New("timed out")
+
+const (
+ ConnTimeout = 30 * time.Second
+ PingTimeout = 5 * time.Second
+)
+
+const (
+ MaxRelPktSize = 32439825
+ MaxUnrelPktSize = 32636430
+)
+
+// protoID must be at the start of every UDP packet.
const protoID uint32 = 0x4f457403
-// PeerIDs aren't actually used to identify peers, network addresses are,
-// these just exist for backward compatability.
+// PeerIDs aren't actually used to identify peers, IP addresses and ports are,
+// these just exist for backward compatibility.
type PeerID uint16
const (
@@ -29,79 +52,59 @@ const (
PeerIDCltMin
)
-// ChannelCount is the maximum channel number + 1.
-const ChannelCount = 3
-
-/*
-rawPkt.Data format (big endian):
-
- rawType
- switch rawType {
- case rawTypeCtl:
- ctlType
- switch ctlType {
- case ctlAck:
- // Tells peer you received a rawTypeRel
- // and it doesn't need to resend it.
- seqnum
- case ctlSetPeerId:
- // Tells peer to send packets with this Src PeerID.
- PeerId
- case ctlPing:
- // Sent to prevent timeout.
- case ctlDisco:
- // Tells peer that you disconnected.
- }
- case rawTypeOrig:
- Pkt.(Data)
- case rawTypeSplit:
- // Packet larger than MaxNetPktSize split into smaller packets.
- // Packets with I >= Count should be ignored.
- // Once all Count chunks are recieved, they are sorted by I and
- // concatenated to make a Pkt.(Data).
- seqnum // Identifies split packet.
- Count, I uint16
- Chunk...
- case rawTypeRel:
- // Resent until a ctlAck with same seqnum is recieved.
- // seqnums are sequencial and start at seqnumInit,
- // These should be processed in seqnum order.
- seqnum
- rawPkt.Data
- }
-*/
-type rawPkt struct {
- Data []byte
- ChNo uint8
- Unrel bool
-}
-
type rawType uint8
const (
- rawTypeCtl rawType = iota
- rawTypeOrig
- rawTypeSplit
- rawTypeRel
+ rawCtl rawType = iota
+ // ctlType...
+
+ rawOrig
+ // data...
+
+ rawSplit
+ // seqnum
+ // n, i uint16
+ // data...
+
+ rawRel
+ // seqnum
+ // rawType...
)
type ctlType uint8
const (
ctlAck ctlType = iota
+ // seqnum
+
ctlSetPeerID
- ctlPing
+ // PeerID
+
+ ctlPing // Sent to prevent timeout.
+
ctlDisco
)
type Pkt struct {
- Data []byte
- ChNo uint8
+ io.Reader
+ PktInfo
+}
+
+// Reliable packets in a channel are be received in the order they are sent in.
+// A Channel must be less than ChannelCount.
+type Channel uint8
+
+const ChannelCount Channel = 3
+
+type PktInfo struct {
+ Channel
+
+ // Unrel (unreliable) packets may be dropped, duplicated or reordered.
Unrel bool
}
// seqnums are sequence numbers used to maintain reliable packet order
-// and to identify split packets.
+// and identify split packets.
type seqnum uint16
-const seqnumInit seqnum = 65500
+const initSeqnum seqnum = 65500
diff --git a/rudp/send.go b/rudp/send.go
index c43d056..5522ee0 100644
--- a/rudp/send.go
+++ b/rudp/send.go
@@ -1,241 +1,221 @@
package rudp
import (
+ "bytes"
"errors"
"fmt"
- "math"
+ "io"
"net"
"sync"
+ "sync/atomic"
"time"
)
-const (
- // protoID + src PeerID + channel number
- MtHdrSize = 4 + 2 + 1
-
- // rawTypeOrig
- OrigHdrSize = 1
-
- // rawTypeSpilt + seqnum + chunk count + chunk number
- SplitHdrSize = 1 + 2 + 2 + 2
-
- // rawTypeRel + seqnum
- RelHdrSize = 1 + 2
-)
-
-const (
- MaxNetPktSize = 512
-
- MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize
- MaxRelRawPktSize = MaxUnrelRawPktSize - RelHdrSize
-
- MaxRelPktSize = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16
- MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16
-)
-
var ErrPktTooBig = errors.New("can't send pkt: too big")
-var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount")
-
-// Send sends a packet to the Peer.
-// It returns a channel that's closed when all chunks are acked or an error.
-// The ack channel is nil if pkt.Unrel is true.
-func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
- if pkt.ChNo >= ChannelCount {
- return nil, ErrChNoTooBig
- }
-
- hdrSize := MtHdrSize
- if !pkt.Unrel {
- hdrSize += RelHdrSize
- }
-
- if hdrSize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize {
- c := &p.chans[pkt.ChNo]
-
- c.outSplitMu.Lock()
- sn := c.outSplitSN
- c.outSplitSN++
- c.outSplitMu.Unlock()
- chunks := split(pkt.Data, MaxNetPktSize-(hdrSize+SplitHdrSize))
+// A TooBigChError reports a Channel greater than or equal to ChannelCount.
+type TooBigChError Channel
- if len(chunks) > math.MaxUint16 {
- return nil, ErrPktTooBig
- }
+func (e TooBigChError) Error() string {
+ return fmt.Sprintf("channel >= ChannelCount (%d): %d", ChannelCount, e)
+}
- var wg sync.WaitGroup
+// Send sends a Pkt to the Conn.
+// Ack is closed when the packet is acknowledged.
+// Ack is nil if pkt.Unrel is true or err != nil.
+func (c *Conn) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+ if pkt.Channel >= ChannelCount {
+ return nil, TooBigChError(pkt.Channel)
+ }
- for i, chunk := range chunks {
- data := make([]byte, SplitHdrSize+len(chunk))
- data[0] = uint8(rawTypeSplit)
- be.PutUint16(data[1:3], uint16(sn))
- be.PutUint16(data[3:5], uint16(len(chunks)))
- be.PutUint16(data[5:7], uint16(i))
- copy(data[SplitHdrSize:], chunk)
+ var e error
+ send := c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawOrig)
- wg.Add(1)
- ack, err := p.sendRaw(rawPkt{
- Data: data,
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- })
+ nn := 1
+ for nn < len(buf) {
+ n, err := pkt.Read(buf[nn:])
+ nn += n
if err != nil {
- return nil, err
- }
- if !pkt.Unrel {
- go func() {
- <-ack
- wg.Done()
- }()
+ e = err
+ return nn
}
}
- if pkt.Unrel {
- return nil, nil
- } else {
- ack := make(chan struct{})
-
- go func() {
- wg.Wait()
- close(ack)
- }()
+ if _, e = pkt.Read(nil); e != nil {
+ return nn
+ }
- return ack, nil
+ pkt.Reader = io.MultiReader(
+ bytes.NewReader([]byte(buf[1:nn])),
+ pkt.Reader,
+ )
+ return nn
+ }, pkt.PktInfo)
+ if e != nil {
+ if e == io.EOF {
+ return send()
}
+ return nil, e
}
- return p.sendRaw(rawPkt{
- Data: append([]byte{uint8(rawTypeOrig)}, pkt.Data...),
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- })
-}
+ var (
+ sn seqnum
+ i uint16
+
+ sends []func() (<-chan struct{}, error)
+ )
+
+ for {
+ var (
+ b []byte
+ e error
+ )
+ send := c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawSplit)
+
+ n, err := io.ReadFull(pkt, buf[7:])
+ if err != nil && err != io.ErrUnexpectedEOF {
+ e = err
+ return 0
+ }
-// sendRaw sends a raw packet to the Peer.
-func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) {
- if pkt.ChNo >= ChannelCount {
- return nil, ErrChNoTooBig
- }
+ be.PutUint16(buf[5:7], i)
+ if i++; i == 0 {
+ e = ErrPktTooBig
+ return 0
+ }
- p.mu.RLock()
- defer p.mu.RUnlock()
+ b = buf
+ return 7 + n
+ }, pkt.PktInfo)
+ if e != nil {
+ if e == io.EOF {
+ break
+ }
+ return nil, e
+ }
- select {
- case <-p.Disco():
- return nil, net.ErrClosed
- default:
+ sends = append(sends, func() (<-chan struct{}, error) {
+ be.PutUint16(b[1:3], uint16(sn))
+ be.PutUint16(b[3:5], i)
+ return send()
+ })
}
- if !pkt.Unrel {
- return p.sendRel(pkt)
- }
+ ch := &c.chans[pkt.Channel]
- data := make([]byte, MtHdrSize+len(pkt.Data))
- be.PutUint32(data[0:4], protoID)
- be.PutUint16(data[4:6], uint16(p.idOfPeer))
- data[6] = pkt.ChNo
- copy(data[MtHdrSize:], pkt.Data)
+ ch.outSplitMu.Lock()
+ sn = ch.outSplitSN
+ ch.outSplitSN++
+ ch.outSplitMu.Unlock()
- if len(data) > MaxNetPktSize {
- return nil, ErrPktTooBig
- }
+ var wg sync.WaitGroup
- if p.conn != nil {
- _, err = p.conn.Write(data)
- } else {
- _, err = p.pc.WriteTo(data, p.Addr())
- }
- if err != nil {
- return nil, err
+ for _, send := range sends {
+ ack, err := send()
+ if err != nil {
+ return nil, err
+ }
+ if !pkt.Unrel {
+ wg.Add(1)
+ go func() {
+ <-ack
+ wg.Done()
+ }()
+ }
}
- p.ping.Reset(PingTimeout)
+ if !pkt.Unrel {
+ ack := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(ack)
+ }()
+ return ack, nil
+ }
return nil, nil
}
-// sendRel sends a reliable raw packet to the Peer.
-func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) {
- if pkt.Unrel {
- panic("pkt.Unrel is true")
- }
-
- c := &p.chans[pkt.ChNo]
+func (c *Conn) sendRaw(read func([]byte) int, pi PktInfo) func() (<-chan struct{}, error) {
+ if pi.Unrel {
+ buf := make([]byte, maxUDPPktSize)
+ be.PutUint32(buf[0:4], protoID)
+ c.mu.RLock()
+ be.PutUint16(buf[4:6], uint16(c.remoteID))
+ c.mu.RUnlock()
+ buf[6] = uint8(pi.Channel)
+ buf = buf[:7+read(buf[7:])]
+
+ return func() (<-chan struct{}, error) {
+ if _, err := c.udpConn.Write(buf); err != nil {
+ c.close(err)
+ return nil, net.ErrClosed
+ }
- c.outRelMu.Lock()
- defer c.outRelMu.Unlock()
+ c.ping.Reset(PingTimeout)
+ if atomic.LoadUint32(&c.closing) == 1 {
+ c.ping.Stop()
+ }
- sn := c.outRelSN
- for ; sn-c.outRelWin >= 0x8000; c.outRelWin++ {
- if ack, ok := c.ackChans.Load(c.outRelWin); ok {
- <-ack.(chan struct{})
+ return nil, nil
}
}
- c.outRelSN++
-
- rwack := make(chan struct{}) // close-only
- c.ackChans.Store(sn, rwack)
- ack = rwack
-
- data := make([]byte, RelHdrSize+len(pkt.Data))
- data[0] = uint8(rawTypeRel)
- be.PutUint16(data[1:3], uint16(sn))
- copy(data[RelHdrSize:], pkt.Data)
- rel := rawPkt{
- Data: data,
- ChNo: pkt.ChNo,
- Unrel: true,
- }
-
- if _, err := p.sendRaw(rel); err != nil {
- c.ackChans.Delete(sn)
- return nil, err
- }
-
- go func() {
- for {
- select {
- case <-time.After(500 * time.Millisecond):
- if _, err := p.sendRaw(rel); err != nil {
- if errors.Is(err, net.ErrClosed) {
- return
- }
- p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err)
+ pi.Unrel = true
+ var snBuf []byte
+ send := c.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawRel)
+ snBuf = buf[1:3]
+ return 3 + read(buf[3:])
+ }, pi)
+
+ return func() (<-chan struct{}, error) {
+ ch := &c.chans[pi.Channel]
+
+ ch.outRelMu.Lock()
+ defer ch.outRelMu.Unlock()
+
+ sn := ch.outRelSN
+ be.PutUint16(snBuf, uint16(sn))
+ for ; sn-ch.outRelWin >= 0x8000; ch.outRelWin++ {
+ if ack, ok := ch.ackChans.Load(ch.outRelWin); ok {
+ select {
+ case <-ack.(chan struct{}):
+ case <-c.Closed():
}
- case <-ack:
- return
- case <-p.Disco():
- return
}
}
- }()
-
- return ack, nil
-}
-
-// SendDisco sends a disconnect packet to the Peer but does not close it.
-// It returns a channel that's closed when it's acked or an error.
-// The ack channel is nil if unrel is true.
-func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) {
- return p.sendRaw(rawPkt{
- Data: []byte{uint8(rawTypeCtl), uint8(ctlDisco)},
- ChNo: chno,
- Unrel: unrel,
- })
-}
-func split(data []byte, chunksize int) [][]byte {
- chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize)
+ ack := make(chan struct{})
+ ch.ackChans.Store(sn, ack)
- for i := 0; i < len(data); i += chunksize {
- end := i + chunksize
- if end > len(data) {
- end = len(data)
+ if _, err := send(); err != nil {
+ if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
+ close(ack.(chan struct{}))
+ }
+ return nil, err
}
+ ch.outRelSN++
+
+ go func() {
+ t := time.NewTimer(500 * time.Millisecond)
+ defer t.Stop()
+
+ for {
+ select {
+ case <-ack:
+ return
+ case <-t.C:
+ send()
+ t.Reset(500 * time.Millisecond)
+ case <-c.Closed():
+ return
+ }
+ }
+ }()
- chunks = append(chunks, data[i:end])
+ return ack, nil
}
-
- return chunks
}
diff --git a/rudp/udp.go b/rudp/udp.go
new file mode 100644
index 0000000..503f1d4
--- /dev/null
+++ b/rudp/udp.go
@@ -0,0 +1,13 @@
+package rudp
+
+import "net"
+
+const maxUDPPktSize = 512
+
+type udpConn interface {
+ recvUDP() ([]byte, error)
+ Write([]byte) (int, error)
+ Close() error
+ LocalAddr() net.Addr
+ RemoteAddr() net.Addr
+}