summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--LICENSE21
-rw-r--r--rudp/listen.go153
-rw-r--r--rudp/net.go44
-rw-r--r--rudp/peer.go193
-rw-r--r--rudp/process.go259
-rw-r--r--rudp/proto.go103
-rw-r--r--rudp/proxy/proxy.go85
-rw-r--r--rudp/send.go248
8 files changed, 1106 insertions, 0 deletions
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..3febf69
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 anon5
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/rudp/listen.go b/rudp/listen.go
new file mode 100644
index 0000000..871a591
--- /dev/null
+++ b/rudp/listen.go
@@ -0,0 +1,153 @@
+package rudp
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+)
+
+type Listener struct {
+ conn net.PacketConn
+
+ clts chan cltPeer
+ errs chan error
+
+ mu sync.Mutex
+ addr2peer map[string]cltPeer
+ id2peer map[PeerID]cltPeer
+ peerid PeerID
+}
+
+// Listen listens for packets on conn until it is closed.
+func Listen(conn net.PacketConn) *Listener {
+ l := &Listener{
+ conn: conn,
+
+ clts: make(chan cltPeer),
+ errs: make(chan error),
+
+ addr2peer: make(map[string]cltPeer),
+ id2peer: make(map[PeerID]cltPeer),
+ }
+
+ pkts := make(chan netPkt)
+ go readNetPkts(l.conn, pkts, l.errs)
+ go func() {
+ for pkt := range pkts {
+ if err := l.processNetPkt(pkt); err != nil {
+ l.errs <- err
+ }
+ }
+
+ close(l.clts)
+
+ for _, clt := range l.addr2peer {
+ clt.Close()
+ }
+ }()
+
+ return l
+}
+
+// Accept waits for and returns a connecting Peer.
+// You should keep calling this until it returns ErrClosed
+// so it doesn't leak a goroutine.
+func (l *Listener) Accept() (*Peer, error) {
+ select {
+ case clt, ok := <-l.clts:
+ if !ok {
+ select {
+ case err := <-l.errs:
+ return nil, err
+ default:
+ return nil, ErrClosed
+ }
+ }
+ close(clt.accepted)
+ return clt.Peer, nil
+ case err := <-l.errs:
+ return nil, err
+ }
+}
+
+// Addr returns the net.PacketConn the Listener is listening on.
+func (l *Listener) Conn() net.PacketConn { return l.conn }
+
+var ErrOutOfPeerIDs = errors.New("out of peer ids")
+
+type cltPeer struct {
+ *Peer
+ pkts chan<- netPkt
+ accepted chan struct{} // close-only
+}
+
+func (l *Listener) processNetPkt(pkt netPkt) error {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ addrstr := pkt.SrcAddr.String()
+
+ clt, ok := l.addr2peer[addrstr]
+ if !ok {
+ prev := l.peerid
+ for l.id2peer[l.peerid].Peer != nil || l.peerid < PeerIDCltMin {
+ if l.peerid == prev-1 {
+ return ErrOutOfPeerIDs
+ }
+ l.peerid++
+ }
+
+ pkts := make(chan netPkt, 256)
+
+ clt = cltPeer{
+ Peer: newPeer(l.conn, pkt.SrcAddr, l.peerid, PeerIDSrv),
+ pkts: pkts,
+ accepted: make(chan struct{}),
+ }
+
+ l.addr2peer[addrstr] = clt
+ l.id2peer[clt.ID()] = clt
+
+ data := make([]byte, 1+1+2)
+ data[0] = uint8(rawTypeCtl)
+ data[1] = uint8(ctlSetPeerID)
+ binary.BigEndian.PutUint16(data[2:4], uint16(clt.ID()))
+ if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
+ return fmt.Errorf("can't set client peer id: %w", err)
+ }
+
+ go func() {
+ select {
+ case l.clts <- clt:
+ case <-clt.Disco():
+ }
+
+ clt.processNetPkts(pkts)
+ }()
+
+ go func() {
+ <-clt.Disco()
+
+ l.mu.Lock()
+ close(pkts)
+ delete(l.addr2peer, addrstr)
+ delete(l.id2peer, clt.ID())
+ l.mu.Unlock()
+ }()
+ }
+
+ select {
+ case <-clt.accepted:
+ clt.pkts <- pkt
+ default:
+ select {
+ case clt.pkts <- pkt:
+ default:
+ return fmt.Errorf("ignoring net pkt from %s because buf is full", addrstr)
+ }
+ }
+
+ return nil
+}
diff --git a/rudp/net.go b/rudp/net.go
new file mode 100644
index 0000000..421a3e7
--- /dev/null
+++ b/rudp/net.go
@@ -0,0 +1,44 @@
+package rudp
+
+import (
+ "errors"
+ "net"
+ "strings"
+)
+
+// TODO: Use net.ErrClosed when Go 1.16 is released.
+var ErrClosed = errors.New("use of closed peer")
+
+/*
+netPkt.Data format (big endian):
+
+ ProtoID
+ Src PeerID
+ ChNo uint8 // Must be < ChannelCount.
+ RawPkt.Data
+*/
+type netPkt struct {
+ SrcAddr net.Addr
+ Data []byte
+}
+
+func readNetPkts(conn net.PacketConn, pkts chan<- netPkt, errs chan<- error) {
+ for {
+ buf := make([]byte, MaxNetPktSize)
+ n, addr, err := conn.ReadFrom(buf)
+ if err != nil {
+ // TODO: Change to this when Go 1.16 is released:
+ // if errors.Is(err, net.ErrClosed) {
+ if strings.Contains(err.Error(), "use of closed network connection") {
+ break
+ }
+
+ errs <- err
+ continue
+ }
+
+ pkts <- netPkt{addr, buf[:n]}
+ }
+
+ close(pkts)
+}
diff --git a/rudp/peer.go b/rudp/peer.go
new file mode 100644
index 0000000..feb0ff9
--- /dev/null
+++ b/rudp/peer.go
@@ -0,0 +1,193 @@
+package rudp
+
+import (
+ "fmt"
+ "net"
+ "sync"
+ "time"
+)
+
+const (
+ // ConnTimeout is the amount of time after no packets being received
+ // from a Peer that it is automatically disconnected.
+ ConnTimeout = 30 * time.Second
+
+ // ConnTimeout is the amount of time after no packets being sent
+ // to a Peer that a CtlPing is automatically sent to prevent timeout.
+ PingTimeout = 5 * time.Second
+)
+
+// A Peer is a connection to a client or server.
+type Peer struct {
+ conn net.PacketConn
+ addr net.Addr
+
+ disco chan struct{} // close-only
+
+ id PeerID
+
+ pkts chan Pkt
+ errs chan error // don't close
+ timedout chan struct{} // close-only
+
+ chans [ChannelCount]pktchan // read/write
+
+ mu sync.RWMutex
+ idOfPeer PeerID
+ timeout *time.Timer
+ ping *time.Ticker
+}
+
+type pktchan struct {
+ // Only accessed by Peer.processRawPkt.
+ insplit map[seqnum][][]byte
+ inrel map[seqnum][]byte
+ inrelsn seqnum
+
+ ackchans sync.Map // map[seqnum]chan struct{}
+
+ outsplitmu sync.Mutex
+ outsplitsn seqnum
+
+ outrelmu sync.Mutex
+ outrelsn seqnum
+ outrelwin seqnum
+}
+
+// Conn returns the net.PacketConn used to communicate with the Peer.
+func (p *Peer) Conn() net.PacketConn { return p.conn }
+
+// Addr returns the address of the Peer.
+func (p *Peer) Addr() net.Addr { return p.addr }
+
+// Disco returns a channel that is closed when the Peer is closed.
+func (p *Peer) Disco() <-chan struct{} { return p.disco }
+
+// ID returns the ID of the Peer.
+func (p *Peer) ID() PeerID { return p.id }
+
+// IsSrv reports whether the Peer is a server.
+func (p *Peer) IsSrv() bool {
+ return p.ID() == PeerIDSrv
+}
+
+// TimedOut reports whether the Peer has timed out.
+func (p *Peer) TimedOut() bool {
+ select {
+ case <-p.timedout:
+ return true
+ default:
+ return false
+ }
+}
+
+// Recv recieves a packet from the Peer.
+// You should keep calling this until it returns ErrClosed
+// so it doesn't leak a goroutine.
+func (p *Peer) Recv() (Pkt, error) {
+ select {
+ case pkt, ok := <-p.pkts:
+ if !ok {
+ select {
+ case err := <-p.errs:
+ return Pkt{}, err
+ default:
+ return Pkt{}, ErrClosed
+ }
+ }
+ return pkt, nil
+ case err := <-p.errs:
+ return Pkt{}, err
+ }
+}
+
+// Close closes the Peer but does not send a disconnect packet.
+func (p *Peer) Close() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ select {
+ case <-p.Disco():
+ return ErrClosed
+ default:
+ }
+
+ p.timeout.Stop()
+ p.timeout = nil
+ p.ping.Stop()
+ p.ping = nil
+
+ close(p.disco)
+
+ return nil
+}
+
+func newPeer(conn net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer {
+ p := &Peer{
+ conn: conn,
+ addr: addr,
+ id: id,
+ idOfPeer: idOfPeer,
+
+ pkts: make(chan Pkt),
+ disco: make(chan struct{}),
+ errs: make(chan error),
+ }
+
+ for i := range p.chans {
+ p.chans[i] = pktchan{
+ insplit: make(map[seqnum][][]byte),
+ inrel: make(map[seqnum][]byte),
+ inrelsn: seqnumInit,
+
+ outsplitsn: seqnumInit,
+ outrelsn: seqnumInit,
+ outrelwin: seqnumInit,
+ }
+ }
+
+ p.timedout = make(chan struct{})
+ p.timeout = time.AfterFunc(ConnTimeout, func() {
+ close(p.timedout)
+
+ p.SendDisco(0, true)
+ p.Close()
+ })
+
+ p.ping = time.NewTicker(PingTimeout)
+ go p.sendPings(p.ping.C)
+
+ return p
+}
+
+func (p *Peer) sendPings(ping <-chan time.Time) {
+ pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}}
+
+ for {
+ select {
+ case <-ping:
+ if _, err := p.sendRaw(pkt); err != nil {
+ p.errs <- fmt.Errorf("can't send ping: %w", err)
+ }
+ case <-p.Disco():
+ return
+ }
+ }
+}
+
+// Connect connects to the server on conn
+// and closes conn when the Peer disconnects.
+func Connect(conn net.PacketConn, addr net.Addr) *Peer {
+ srv := newPeer(conn, addr, PeerIDSrv, PeerIDNil)
+
+ pkts := make(chan netPkt)
+ go readNetPkts(conn, pkts, srv.errs)
+ go srv.processNetPkts(pkts)
+
+ go func() {
+ <-srv.Disco()
+ conn.Close()
+ }()
+
+ return srv
+}
diff --git a/rudp/process.go b/rudp/process.go
new file mode 100644
index 0000000..c36af81
--- /dev/null
+++ b/rudp/process.go
@@ -0,0 +1,259 @@
+package rudp
+
+import (
+ "encoding/binary"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// A PktError is an error that occured while processing a packet.
+type PktError struct {
+ Type string // "net", "raw" or "rel".
+ Data []byte
+ Err error
+}
+
+func (e PktError) Error() string {
+ return "error processing " + e.Type + " pkt: " +
+ hex.EncodeToString(e.Data) + ": " +
+ e.Err.Error()
+}
+
+func (e PktError) Unwrap() error { return e.Err }
+
+func (p *Peer) processNetPkts(pkts <-chan netPkt) {
+ for pkt := range pkts {
+ if err := p.processNetPkt(pkt); err != nil {
+ p.errs <- PktError{"net", pkt.Data, err}
+ }
+ }
+
+ close(p.pkts)
+}
+
+// A TrailingDataError reports a packet with trailing data,
+// it doesn't stop a packet from being processed.
+type TrailingDataError []byte
+
+func (e TrailingDataError) Error() string {
+ return "trailing data: " + hex.EncodeToString([]byte(e))
+}
+
+func (p *Peer) processNetPkt(pkt netPkt) (err error) {
+ if pkt.SrcAddr.String() != p.Addr().String() {
+ return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
+ }
+
+ if len(pkt.Data) < MtHdrSize {
+ return io.ErrUnexpectedEOF
+ }
+
+ if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
+ return fmt.Errorf("unsupported protocol id: 0x%08x", id)
+ }
+
+ // src PeerID at pkt.Data[4:6]
+
+ chno := pkt.Data[6]
+ if chno >= ChannelCount {
+ return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
+ }
+
+ p.mu.RLock()
+ if p.timeout != nil {
+ p.timeout.Reset(ConnTimeout)
+ }
+ p.mu.RUnlock()
+
+ rpkt := rawPkt{
+ Data: pkt.Data[MtHdrSize:],
+ ChNo: chno,
+ Unrel: true,
+ }
+ if err := p.processRawPkt(rpkt); err != nil {
+ p.errs <- PktError{"raw", rpkt.Data, err}
+ }
+
+ return nil
+}
+
+func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
+ errWrap := func(format string, a ...interface{}) {
+ if err != nil {
+ err = fmt.Errorf(format, append(a, err)...)
+ }
+ }
+
+ c := &p.chans[pkt.ChNo]
+
+ if len(pkt.Data) < 1 {
+ return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
+ }
+ switch t := rawType(pkt.Data[0]); t {
+ case rawTypeCtl:
+ defer errWrap("ctl: %w")
+
+ if len(pkt.Data) < 1+1 {
+ return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
+ }
+ switch ct := ctlType(pkt.Data[1]); ct {
+ case ctlAck:
+ defer errWrap("ack: %w")
+
+ if len(pkt.Data) < 1+1+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
+
+ if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
+ close(ack.(chan struct{}))
+ }
+
+ if len(pkt.Data) > 1+1+2 {
+ return TrailingDataError(pkt.Data[1+1+2:])
+ }
+ case ctlSetPeerID:
+ defer errWrap("set peer id: %w")
+
+ if len(pkt.Data) < 1+1+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ // Ensure no concurrent senders while peer id changes.
+ p.mu.Lock()
+ if p.idOfPeer != PeerIDNil {
+ return errors.New("peer id already set")
+ }
+
+ p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
+ p.mu.Unlock()
+
+ if len(pkt.Data) > 1+1+2 {
+ return TrailingDataError(pkt.Data[1+1+2:])
+ }
+ case ctlPing:
+ defer errWrap("ping: %w")
+
+ if len(pkt.Data) > 1+1 {
+ return TrailingDataError(pkt.Data[1+1:])
+ }
+ case ctlDisco:
+ defer errWrap("disco: %w")
+
+ if err := p.Close(); err != nil {
+ return fmt.Errorf("can't close: %w", err)
+ }
+
+ if len(pkt.Data) > 1+1 {
+ return TrailingDataError(pkt.Data[1+1:])
+ }
+ default:
+ return fmt.Errorf("unsupported ctl type: %d", ct)
+ }
+ case rawTypeOrig:
+ p.pkts <- Pkt{
+ Data: pkt.Data[1:],
+ ChNo: pkt.ChNo,
+ Unrel: pkt.Unrel,
+ }
+ case rawTypeSplit:
+ defer errWrap("split: %w")
+
+ if len(pkt.Data) < 1+2+2+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+ count := binary.BigEndian.Uint16(pkt.Data[3:5])
+ i := binary.BigEndian.Uint16(pkt.Data[5:7])
+
+ if i >= count {
+ return nil
+ }
+
+ splitpkts := p.chans[pkt.ChNo].insplit
+
+ // Delete old incomplete split packets
+ // so new ones don't get corrupted.
+ delete(splitpkts, sn-0x8000)
+
+ if splitpkts[sn] == nil {
+ splitpkts[sn] = make([][]byte, count)
+ }
+
+ chunks := splitpkts[sn]
+
+ if int(count) != len(chunks) {
+ return fmt.Errorf("chunk count changed on seqnum: %d", sn)
+ }
+
+ chunks[i] = pkt.Data[7:]
+
+ for _, chunk := range chunks {
+ if chunk == nil {
+ return nil
+ }
+ }
+
+ var data []byte
+ for _, chunk := range chunks {
+ data = append(data, chunk...)
+ }
+
+ p.pkts <- Pkt{
+ Data: data,
+ ChNo: pkt.ChNo,
+ Unrel: pkt.Unrel,
+ }
+
+ delete(splitpkts, sn)
+ case rawTypeRel:
+ defer errWrap("rel: %w")
+
+ if len(pkt.Data) < 1+2 {
+ return io.ErrUnexpectedEOF
+ }
+
+ sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+
+ ackdata := make([]byte, 1+1+2)
+ ackdata[0] = uint8(rawTypeCtl)
+ ackdata[1] = uint8(ctlAck)
+ binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
+ ack := rawPkt{
+ Data: ackdata,
+ ChNo: pkt.ChNo,
+ Unrel: true,
+ }
+ if _, err := p.sendRaw(ack); err != nil {
+ return fmt.Errorf("can't ack %d: %w", sn, err)
+ }
+
+ if sn-c.inrelsn >= 0x8000 {
+ return nil // Already received.
+ }
+
+ c.inrel[sn] = pkt.Data[3:]
+
+ for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
+ data := c.inrel[c.inrelsn]
+ delete(c.inrel, c.inrelsn)
+
+ rpkt := rawPkt{
+ Data: data,
+ ChNo: pkt.ChNo,
+ Unrel: false,
+ }
+ if err := p.processRawPkt(rpkt); err != nil {
+ p.errs <- PktError{"rel", rpkt.Data, err}
+ }
+ }
+ default:
+ return fmt.Errorf("unsupported pkt type: %d", t)
+ }
+
+ return nil
+}
diff --git a/rudp/proto.go b/rudp/proto.go
new file mode 100644
index 0000000..04176b2
--- /dev/null
+++ b/rudp/proto.go
@@ -0,0 +1,103 @@
+/*
+Package rudp implements the low-level Minetest protocol described at
+https://dev.minetest.net/Network_Protocol#Low-level_protocol.
+
+All exported functions and methods in this package are safe for concurrent use
+by multiple goroutines.
+*/
+package rudp
+
+// protoID must be at the start of every network packet.
+const protoID uint32 = 0x4f457403
+
+// PeerIDs aren't actually used to identify peers, network addresses are,
+// these just exist for backward compatability.
+type PeerID uint16
+
+const (
+ // Used by clients before the server sets their ID.
+ PeerIDNil PeerID = iota
+
+ // The server always has this ID.
+ PeerIDSrv
+
+ // Lowest ID the server can assign to a client.
+ PeerIDCltMin
+)
+
+// ChannelCount is the maximum channel number + 1.
+const ChannelCount = 3
+
+/*
+rawPkt.Data format (big endian):
+
+ rawType
+ switch rawType {
+ case rawTypeCtl:
+ ctlType
+ switch ctlType {
+ case ctlAck:
+ // Tells peer you received a rawTypeRel
+ // and it doesn't need to resend it.
+ seqnum
+ case ctlSetPeerId:
+ // Tells peer to send packets with this Src PeerID.
+ PeerId
+ case ctlPing:
+ // Sent to prevent timeout.
+ case ctlDisco:
+ // Tells peer that you disconnected.
+ }
+ case rawTypeOrig:
+ Pkt.(Data)
+ case rawTypeSplit:
+ // Packet larger than MaxNetPktSize split into smaller packets.
+ // Packets with I >= Count should be ignored.
+ // Once all Count chunks are recieved, they are sorted by I and
+ // concatenated to make a Pkt.(Data).
+ seqnum // Identifies split packet.
+ Count, I uint16
+ Chunk...
+ case rawTypeRel:
+ // Resent until a ctlAck with same seqnum is recieved.
+ // seqnums are sequencial and start at seqnumInit,
+ // These should be processed in seqnum order.
+ seqnum
+ rawPkt.Data
+ }
+*/
+type rawPkt struct {
+ Data []byte
+ ChNo uint8
+ Unrel bool
+}
+
+type rawType uint8
+
+const (
+ rawTypeCtl rawType = iota
+ rawTypeOrig
+ rawTypeSplit
+ rawTypeRel
+)
+
+type ctlType uint8
+
+const (
+ ctlAck ctlType = iota
+ ctlSetPeerID
+ ctlPing
+ ctlDisco
+)
+
+type Pkt struct {
+ Data []byte
+ ChNo uint8
+ Unrel bool
+}
+
+// seqnums are sequence numbers used to maintain reliable packet order
+// and to identify split packets.
+type seqnum uint16
+
+const seqnumInit seqnum = 65500
diff --git a/rudp/proxy/proxy.go b/rudp/proxy/proxy.go
new file mode 100644
index 0000000..6fc14ec
--- /dev/null
+++ b/rudp/proxy/proxy.go
@@ -0,0 +1,85 @@
+/*
+Proxy is a Minetest RUDP proxy server
+supporting multiple concurrent connections.
+
+Usage:
+ proxy dial:port listen:port
+where dial:port is the server address
+and listen:port is the address to listen on.
+*/
+package main
+
+import (
+ "fmt"
+ "log"
+ "net"
+ "os"
+
+ "github.com/anon55555/mt/rudp"
+)
+
+func main() {
+ if len(os.Args) != 3 {
+ fmt.Fprintln(os.Stderr, "usage: proxy dial:port listen:port")
+ os.Exit(1)
+ }
+
+ srvaddr, err := net.ResolveUDPAddr("udp", os.Args[1])
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ lc, err := net.ListenPacket("udp", os.Args[2])
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer lc.Close()
+
+ l := rudp.Listen(lc)
+ for {
+ clt, err := l.Accept()
+ if err != nil {
+ log.Print(err)
+ continue
+ }
+
+ log.Print(clt.Addr(), " connected")
+
+ conn, err := net.DialUDP("udp", nil, srvaddr)
+ if err != nil {
+ log.Print(err)
+ continue
+ }
+ srv := rudp.Connect(conn, conn.RemoteAddr())
+
+ go proxy(clt, srv)
+ go proxy(srv, clt)
+ }
+}
+
+func proxy(src, dest *rudp.Peer) {
+ for {
+ pkt, err := src.Recv()
+ if err != nil {
+ if err == rudp.ErrClosed {
+ msg := src.Addr().String() + " disconnected"
+ if src.TimedOut() {
+ msg += " (timed out)"
+ }
+ log.Print(msg)
+
+ break
+ }
+
+ log.Print(err)
+ continue
+ }
+
+ if _, err := dest.Send(pkt); err != nil {
+ log.Print(err)
+ }
+ }
+
+ dest.SendDisco(0, true)
+ dest.Close()
+}
diff --git a/rudp/send.go b/rudp/send.go
new file mode 100644
index 0000000..3cfcda4
--- /dev/null
+++ b/rudp/send.go
@@ -0,0 +1,248 @@
+package rudp
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "math"
+ "net"
+ "sync"
+ "time"
+)
+
+const (
+ // protoID + src PeerID + channel number
+ MtHdrSize = 4 + 2 + 1
+
+ // rawTypeOrig
+ OrigHdrSize = 1
+
+ // rawTypeSpilt + seqnum + chunk count + chunk number
+ SplitHdrSize = 1 + 2 + 2 + 2
+
+ // rawTypeRel + seqnum
+ RelHdrSize = 1 + 2
+)
+
+const (
+ MaxNetPktSize = 512
+
+ MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize
+ MaxRelRawPktSize = MaxUnrelRawPktSize - RelHdrSize
+
+ MaxRelPktSize = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16
+ MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16
+)
+
+var ErrPktTooBig = errors.New("can't send pkt: too big")
+var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount")
+
+// Send sends a packet to the Peer.
+// It returns a channel that's closed when all chunks are acked or an error.
+// The ack channel is nil if pkt.Unrel is true.
+func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+ if pkt.ChNo >= ChannelCount {
+ return nil, ErrChNoTooBig
+ }
+
+ hdrsize := MtHdrSize
+ if !pkt.Unrel {
+ hdrsize += RelHdrSize
+ }
+
+ if hdrsize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize {
+ c := &p.chans[pkt.ChNo]
+
+ c.outsplitmu.Lock()
+ sn := c.outsplitsn
+ c.outsplitsn++
+ c.outsplitmu.Unlock()
+
+ chunks := split(pkt.Data, MaxNetPktSize-(hdrsize+SplitHdrSize))
+
+ if len(chunks) > math.MaxUint16 {
+ return nil, ErrPktTooBig
+ }
+
+ var wg sync.WaitGroup
+
+ for i, chunk := range chunks {
+ data := make([]byte, SplitHdrSize+len(chunk))
+ data[0] = uint8(rawTypeSplit)
+ binary.BigEndian.PutUint16(data[1:3], uint16(sn))
+ binary.BigEndian.PutUint16(data[3:5], uint16(len(chunks)))
+ binary.BigEndian.PutUint16(data[5:7], uint16(i))
+ copy(data[SplitHdrSize:], chunk)
+
+ wg.Add(1)
+ ack, err := p.sendRaw(rawPkt{
+ Data: data,
+ ChNo: pkt.ChNo,
+ Unrel: pkt.Unrel,
+ })
+ if err != nil {
+ return nil, err
+ }
+ if !pkt.Unrel {
+ if ack == nil {
+ panic("ack is nil")
+ }
+ go func() {
+ <-ack
+ wg.Done()
+ }()
+ }
+ }
+
+ if pkt.Unrel {
+ return nil, nil
+ } else {
+ ack := make(chan struct{})
+
+ go func() {
+ wg.Wait()
+ close(ack)
+ }()
+
+ return ack, nil
+ }
+ }
+
+ return p.sendRaw(rawPkt{
+ Data: append([]byte{uint8(rawTypeOrig)}, pkt.Data...),
+ ChNo: pkt.ChNo,
+ Unrel: pkt.Unrel,
+ })
+}
+
+// sendRaw sends a raw packet to the Peer.
+func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) {
+ if pkt.ChNo >= ChannelCount {
+ return nil, ErrChNoTooBig
+ }
+
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ select {
+ case <-p.Disco():
+ return nil, ErrClosed
+ default:
+ }
+
+ if !pkt.Unrel {
+ return p.sendRel(pkt)
+ }
+
+ data := make([]byte, MtHdrSize+len(pkt.Data))
+ binary.BigEndian.PutUint32(data[0:4], protoID)
+ binary.BigEndian.PutUint16(data[4:6], uint16(p.idOfPeer))
+ data[6] = pkt.ChNo
+ copy(data[MtHdrSize:], pkt.Data)
+
+ if len(data) > MaxNetPktSize {
+ return nil, ErrPktTooBig
+ }
+
+ _, err = p.Conn().WriteTo(data, p.Addr())
+ if errors.Is(err, net.ErrWriteToConnected) {
+ conn, ok := p.Conn().(net.Conn)
+ if !ok {
+ return nil, err
+ }
+ _, err = conn.Write(data)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ p.ping.Reset(PingTimeout)
+
+ return nil, nil
+}
+
+// sendRel sends a reliable raw packet to the Peer.
+func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) {
+ if pkt.Unrel {
+ panic("mt/rudp: sendRel: pkt.Unrel is true")
+ }
+
+ c := &p.chans[pkt.ChNo]
+
+ c.outrelmu.Lock()
+ defer c.outrelmu.Unlock()
+
+ sn := c.outrelsn
+ for ; sn-c.outrelwin >= 0x8000; c.outrelwin++ {
+ if ack, ok := c.ackchans.Load(c.outrelwin); ok {
+ <-ack.(chan struct{})
+ }
+ }
+ c.outrelsn++
+
+ rwack := make(chan struct{}) // close-only
+ c.ackchans.Store(sn, rwack)
+ ack = rwack
+
+ reldata := make([]byte, RelHdrSize+len(pkt.Data))
+ reldata[0] = uint8(rawTypeRel)
+ binary.BigEndian.PutUint16(reldata[1:3], uint16(sn))
+ copy(reldata[RelHdrSize:], pkt.Data)
+ relpkt := rawPkt{
+ Data: reldata,
+ ChNo: pkt.ChNo,
+ Unrel: true,
+ }
+
+ if _, err := p.sendRaw(relpkt); err != nil {
+ c.ackchans.Delete(sn)
+
+ return nil, err
+ }
+
+ go func() {
+ resend := time.NewTicker(500 * time.Millisecond)
+ defer resend.Stop()
+
+ for {
+ select {
+ case <-resend.C:
+ if _, err := p.sendRaw(relpkt); err != nil {
+ p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err)
+ }
+ case <-ack:
+ return
+ case <-p.Disco():
+ return
+ }
+ }
+ }()
+
+ return ack, nil
+}
+
+// SendDisco sends a disconnect packet to the Peer but does not close it.
+// It returns a channel that's closed when it's acked or an error.
+// The ack channel is nil if unrel is true.
+func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) {
+ return p.sendRaw(rawPkt{
+ Data: []byte{uint8(rawTypeCtl), uint8(ctlDisco)},
+ ChNo: chno,
+ Unrel: unrel,
+ })
+}
+
+func split(data []byte, chunksize int) [][]byte {
+ chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize)
+
+ for i := 0; i < len(data); i += chunksize {
+ end := i + chunksize
+ if end > len(data) {
+ end = len(data)
+ }
+
+ chunks = append(chunks, data[i:end])
+ }
+
+ return chunks
+}