diff options
author | HimbeerserverDE <himbeerserverde@gmail.com> | 2021-08-28 15:06:18 +0200 |
---|---|---|
committer | HimbeerserverDE <himbeerserverde@gmail.com> | 2021-08-28 15:06:18 +0200 |
commit | 4bd36bff71f73425397043f538e819448973928f (patch) | |
tree | 5014f2b4ba06ead584a96cbe6dfad8ebe7267aa9 | |
parent | e45b9940473ad52555296005f4beef22aacc8276 (diff) |
Don't wait for ack if the Peer has disconnected
-rw-r--r-- | README.md | 6 | ||||
-rw-r--r-- | client_conn.go | 134 | ||||
-rw-r--r-- | main.go | 47 | ||||
-rw-r--r-- | server_conn.go | 8 |
4 files changed, 154 insertions, 41 deletions
@@ -8,10 +8,16 @@ Go 1.16 or higher is required. Run to download and compile the project. A MTConnector executable will be created in your $GOBIN directory. ## Usage +### Starting Run `$GOBIN/MTConnector`. The configuration file and other required files are created automatically in the directory the executable (or symlink to said executable) is in, so make sure to move the executable to the desired location or use a symlink. +### Stopping +MTConnector reacts to SIGINT, SIGTERM and SIGHUP. It stops listening +for new connections, kicks all clients, disconnects from all servers +and exits. If some clients aren't responding, MTConnector waits until +they have timed out. ## Configuration The configuration file name and format are described in [doc/config.md](doc/config.md) **All internal servers need to allow empty passwords and must not be reachable from the internet!** diff --git a/client_conn.go b/client_conn.go index d665cc5..ced6923 100644 --- a/client_conn.go +++ b/client_conn.go @@ -95,32 +95,52 @@ func handleClt(cc *clientConn) { if cmd.SerializeVer != latestSerializeVer { cc.log("<--", "invalid serializeVer") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnsupportedVer}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } if cmd.MaxProtoVer < latestProtoVer { cc.log("<--", "invalid protoVer") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnsupportedVer}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } if len(cmd.PlayerName) == 0 || len(cmd.PlayerName) > maxPlayerNameLen { cc.log("<--", "invalid player name length") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.BadName}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } if ok, _ := regexp.MatchString(playerNameChars, cmd.PlayerName); !ok { cc.log("<--", "invalid player name") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.BadNameChars}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -131,8 +151,12 @@ func handleClt(cc *clientConn) { if ok { cc.log("<--", "already connected") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.AlreadyConnected}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } playersMu.Unlock() break @@ -144,8 +168,13 @@ func handleClt(cc *clientConn) { if cc.name == "singleplayer" { cc.log("<--", "name is singleplayer") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.BadName}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -153,8 +182,13 @@ func handleClt(cc *clientConn) { if len(players) >= conf.UserLimit { cc.log("<--", "player limit reached") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.TooManyClts}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -176,8 +210,13 @@ func handleClt(cc *clientConn) { if cc.auth.method != mt.FirstSRP { cc.log("-->", "unauthorized password change") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -186,16 +225,26 @@ func handleClt(cc *clientConn) { if cmd.EmptyPasswd && conf.RequirePasswd { cc.log("<--", "empty password disallowed") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.EmptyPasswd}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } if err := authIface.SetPasswd(cc.name, cmd.Salt, cmd.Verifier); err != nil { cc.log("<--", "set password fail") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.SrvErr}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -247,8 +296,13 @@ func handleClt(cc *clientConn) { } ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -263,8 +317,13 @@ func handleClt(cc *clientConn) { if err != nil { cc.log("<--", "SRP data retrieval fail") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.SrvErr}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -274,8 +333,13 @@ func handleClt(cc *clientConn) { if err != nil || cc.auth.srpB == nil { cc.log("<--", "SRP safety check fail") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -299,8 +363,13 @@ func handleClt(cc *clientConn) { } ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } @@ -328,8 +397,13 @@ func handleClt(cc *clientConn) { cc.log("<--", "invalid password") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.WrongPasswd}) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + break } case *mt.ToSrvInit2: @@ -47,16 +47,28 @@ func main() { sig := make(chan os.Signal) signal.Notify(sig, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) <-sig + l.close() mu.Lock() defer mu.Unlock() + var wg sync.WaitGroup + wg.Add(len(clts)) + for cc := range clts { - ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.Shutdown}) - <-ack - cc.Close() + go func() { + ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.Shutdown}) + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + + wg.Done() + }() } + wg.Wait() os.Exit(0) }() @@ -64,6 +76,7 @@ func main() { cc, err := l.accept() if err != nil { if errors.Is(err, net.ErrClosed) { + log.Print("{←|⇶} stop listening") break } @@ -94,8 +107,12 @@ func main() { Reason: mt.Custom, Custom: "No servers are configured.", }) - <-ack - cc.Close() + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + return } @@ -106,24 +123,36 @@ func main() { Reason: mt.Custom, Custom: "Server address resolution failed.", }) - <-ack - cc.Close() + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + return } conn, err := net.DialUDP("udp", nil, addr) if err != nil { cc.log("<--", "connection fail") + ack, _ := cc.SendCmd(&mt.ToCltDisco{ Reason: mt.Custom, Custom: "Server connection failed.", }) - <-ack - cc.Close() + + select { + case <-cc.Closed(): + case <-ack: + cc.Close() + } + return } connect(conn, cc) }() } + + select {} } diff --git a/server_conn.go b/server_conn.go index 5cd58bf..9a8e6f6 100644 --- a/server_conn.go +++ b/server_conn.go @@ -141,8 +141,12 @@ func handleSrv(sc *serverConn) { sc.log("<--", fmt.Sprintf("deny access %+v", cmd)) if sc.client() != nil { ack, _ := sc.client().SendCmd(cmd) - <-ack - sc.client().Close() + + select { + case <-sc.client().Closed(): + case <-ack: + sc.client().Close() + } } case *mt.ToCltAcceptAuth: sc.auth.method = 0 |