aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHimbeerserverDE <himbeerserverde@gmail.com>2021-08-28 15:06:18 +0200
committerHimbeerserverDE <himbeerserverde@gmail.com>2021-08-28 15:06:18 +0200
commit4bd36bff71f73425397043f538e819448973928f (patch)
tree5014f2b4ba06ead584a96cbe6dfad8ebe7267aa9
parente45b9940473ad52555296005f4beef22aacc8276 (diff)
Don't wait for ack if the Peer has disconnected
-rw-r--r--README.md6
-rw-r--r--client_conn.go134
-rw-r--r--main.go47
-rw-r--r--server_conn.go8
4 files changed, 154 insertions, 41 deletions
diff --git a/README.md b/README.md
index 8451fe7..3f97f1f 100644
--- a/README.md
+++ b/README.md
@@ -8,10 +8,16 @@ Go 1.16 or higher is required. Run
to download and compile the project. A MTConnector executable
will be created in your $GOBIN directory.
## Usage
+### Starting
Run `$GOBIN/MTConnector`. The configuration file and other required
files are created automatically in the directory the executable
(or symlink to said executable) is in, so make sure to move the
executable to the desired location or use a symlink.
+### Stopping
+MTConnector reacts to SIGINT, SIGTERM and SIGHUP. It stops listening
+for new connections, kicks all clients, disconnects from all servers
+and exits. If some clients aren't responding, MTConnector waits until
+they have timed out.
## Configuration
The configuration file name and format are described in [doc/config.md](doc/config.md)
**All internal servers need to allow empty passwords and must not be reachable from the internet!**
diff --git a/client_conn.go b/client_conn.go
index d665cc5..ced6923 100644
--- a/client_conn.go
+++ b/client_conn.go
@@ -95,32 +95,52 @@ func handleClt(cc *clientConn) {
if cmd.SerializeVer != latestSerializeVer {
cc.log("<--", "invalid serializeVer")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnsupportedVer})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
if cmd.MaxProtoVer < latestProtoVer {
cc.log("<--", "invalid protoVer")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnsupportedVer})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
if len(cmd.PlayerName) == 0 || len(cmd.PlayerName) > maxPlayerNameLen {
cc.log("<--", "invalid player name length")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.BadName})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
if ok, _ := regexp.MatchString(playerNameChars, cmd.PlayerName); !ok {
cc.log("<--", "invalid player name")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.BadNameChars})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -131,8 +151,12 @@ func handleClt(cc *clientConn) {
if ok {
cc.log("<--", "already connected")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.AlreadyConnected})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
playersMu.Unlock()
break
@@ -144,8 +168,13 @@ func handleClt(cc *clientConn) {
if cc.name == "singleplayer" {
cc.log("<--", "name is singleplayer")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.BadName})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -153,8 +182,13 @@ func handleClt(cc *clientConn) {
if len(players) >= conf.UserLimit {
cc.log("<--", "player limit reached")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.TooManyClts})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -176,8 +210,13 @@ func handleClt(cc *clientConn) {
if cc.auth.method != mt.FirstSRP {
cc.log("-->", "unauthorized password change")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -186,16 +225,26 @@ func handleClt(cc *clientConn) {
if cmd.EmptyPasswd && conf.RequirePasswd {
cc.log("<--", "empty password disallowed")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.EmptyPasswd})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
if err := authIface.SetPasswd(cc.name, cmd.Salt, cmd.Verifier); err != nil {
cc.log("<--", "set password fail")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.SrvErr})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -247,8 +296,13 @@ func handleClt(cc *clientConn) {
}
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -263,8 +317,13 @@ func handleClt(cc *clientConn) {
if err != nil {
cc.log("<--", "SRP data retrieval fail")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.SrvErr})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -274,8 +333,13 @@ func handleClt(cc *clientConn) {
if err != nil || cc.auth.srpB == nil {
cc.log("<--", "SRP safety check fail")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -299,8 +363,13 @@ func handleClt(cc *clientConn) {
}
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
@@ -328,8 +397,13 @@ func handleClt(cc *clientConn) {
cc.log("<--", "invalid password")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.WrongPasswd})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
break
}
case *mt.ToSrvInit2:
diff --git a/main.go b/main.go
index b2bf21c..bad9974 100644
--- a/main.go
+++ b/main.go
@@ -47,16 +47,28 @@ func main() {
sig := make(chan os.Signal)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
<-sig
+ l.close()
mu.Lock()
defer mu.Unlock()
+ var wg sync.WaitGroup
+ wg.Add(len(clts))
+
for cc := range clts {
- ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.Shutdown})
- <-ack
- cc.Close()
+ go func() {
+ ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.Shutdown})
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
+ wg.Done()
+ }()
}
+ wg.Wait()
os.Exit(0)
}()
@@ -64,6 +76,7 @@ func main() {
cc, err := l.accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
+ log.Print("{←|⇶} stop listening")
break
}
@@ -94,8 +107,12 @@ func main() {
Reason: mt.Custom,
Custom: "No servers are configured.",
})
- <-ack
- cc.Close()
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
return
}
@@ -106,24 +123,36 @@ func main() {
Reason: mt.Custom,
Custom: "Server address resolution failed.",
})
- <-ack
- cc.Close()
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
return
}
conn, err := net.DialUDP("udp", nil, addr)
if err != nil {
cc.log("<--", "connection fail")
+
ack, _ := cc.SendCmd(&mt.ToCltDisco{
Reason: mt.Custom,
Custom: "Server connection failed.",
})
- <-ack
- cc.Close()
+
+ select {
+ case <-cc.Closed():
+ case <-ack:
+ cc.Close()
+ }
+
return
}
connect(conn, cc)
}()
}
+
+ select {}
}
diff --git a/server_conn.go b/server_conn.go
index 5cd58bf..9a8e6f6 100644
--- a/server_conn.go
+++ b/server_conn.go
@@ -141,8 +141,12 @@ func handleSrv(sc *serverConn) {
sc.log("<--", fmt.Sprintf("deny access %+v", cmd))
if sc.client() != nil {
ack, _ := sc.client().SendCmd(cmd)
- <-ack
- sc.client().Close()
+
+ select {
+ case <-sc.client().Closed():
+ case <-ack:
+ sc.client().Close()
+ }
}
case *mt.ToCltAcceptAuth:
sc.auth.method = 0