diff options
author | HimbeerserverDE <himbeerserverde@gmail.com> | 2021-09-03 13:26:44 +0200 |
---|---|---|
committer | HimbeerserverDE <himbeerserverde@gmail.com> | 2021-09-03 13:26:44 +0200 |
commit | f4ccac8d93cfcbfe5ef1c6b4cc9c92220958ae71 (patch) | |
tree | f8b70a40303b88c34bfc1e915e187012d26a3212 | |
parent | 55ba6a3805956764f9261c70a992e16486b7b3eb (diff) |
Fix race conditions (#36)
-rw-r--r-- | client_conn.go | 55 | ||||
-rw-r--r-- | connect.go | 3 | ||||
-rw-r--r-- | content.go | 22 | ||||
-rw-r--r-- | server_conn.go | 31 |
4 files changed, 81 insertions, 30 deletions
diff --git a/client_conn.go b/client_conn.go index 8eb5dcd..9b0309f 100644 --- a/client_conn.go +++ b/client_conn.go @@ -26,11 +26,13 @@ const ( type clientConn struct { mt.Peer srv *serverConn + mu sync.RWMutex - state clientState - name string - initCh chan struct{} - hopMu sync.Mutex + cstate clientState + cstateMu sync.RWMutex + name string + initCh chan struct{} + hopMu sync.Mutex auth struct { method mt.AuthMethods @@ -58,7 +60,26 @@ type clientConn struct { modChs map[string]struct{} } -func (cc *clientConn) server() *serverConn { return cc.srv } +func (cc *clientConn) server() *serverConn { + cc.mu.RLock() + defer cc.mu.RUnlock() + + return cc.srv +} + +func (cc *clientConn) state() clientState { + cc.cstateMu.RLock() + defer cc.cstateMu.RUnlock() + + return cc.cstate +} + +func (cc *clientConn) setState(state clientState) { + cc.cstateMu.Lock() + defer cc.cstateMu.Unlock() + + cc.cstate = state +} func (cc *clientConn) init() <-chan struct{} { return cc.initCh } @@ -102,13 +123,12 @@ func handleClt(cc *clientConn) { switch cmd := pkt.Cmd.(type) { case *mt.ToSrvInit: - if cc.state > csCreated { + if cc.state() > csCreated { cc.log("-->", "duplicate init") break } - cc.state = csInit - + cc.setState(csInit) if cmd.SerializeVer != latestSerializeVer { cc.log("<--", "invalid serializeVer") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnsupportedVer}) @@ -223,7 +243,7 @@ func handleClt(cc *clientConn) { Username: cc.name, }) case *mt.ToSrvFirstSRP: - if cc.state == csInit { + if cc.state() == csInit { if cc.auth.method != mt.FirstSRP { cc.log("-->", "unauthorized password change") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData}) @@ -276,13 +296,12 @@ func handleClt(cc *clientConn) { SudoAuthMethods: mt.SRP, }) } else { - if cc.state < csSudo { + if cc.state() < csSudo { cc.log("-->", "unauthorized sudo action") break } - cc.state-- - + cc.setState(cc.state() - 1) if err := authIface.SetPasswd(cc.name, cmd.Salt, cmd.Verifier); err != nil { cc.log("<--", "change password fail") cc.SendCmd(&mt.ToCltChatMsg{ @@ -301,9 +320,9 @@ func handleClt(cc *clientConn) { }) } case *mt.ToSrvSRPBytesA: - wantSudo := cc.state == csActive + wantSudo := cc.state() == csActive - if cc.state != csInit && cc.state != csActive { + if cc.state() != csInit && cc.state() != csActive { cc.log("-->", "unexpected authentication") break } @@ -368,9 +387,9 @@ func handleClt(cc *clientConn) { B: cc.auth.srpB, }) case *mt.ToSrvSRPBytesM: - wantSudo := cc.state == csActive + wantSudo := cc.state() == csActive - if cc.state != csInit && cc.state != csActive { + if cc.state() != csInit && cc.state() != csActive { cc.log("-->", "unexpected authentication") break } @@ -401,7 +420,7 @@ func handleClt(cc *clientConn) { }{} if wantSudo { - cc.state++ + cc.setState(cc.state() + 1) cc.SendCmd(&mt.ToCltAcceptSudoMode{}) } else { cc.SendCmd(&mt.ToCltAcceptAuth{ @@ -485,7 +504,7 @@ func handleClt(cc *clientConn) { cc.versionStr = cmd.Version cc.formspecVer = cmd.Formspec - cc.state++ + cc.setState(cc.state() + 1) close(cc.initCh) case *mt.ToSrvInteract: if cc.server() == nil { @@ -19,7 +19,10 @@ func connect(conn net.Conn, name string, cc *clientConn) *serverConn { playerList: make(map[string]struct{}), } sc.log("-->", "connect") + + cc.mu.Lock() cc.srv = sc + cc.mu.Unlock() go handleSrv(sc) return sc @@ -24,7 +24,8 @@ type mediaFile struct { type contentConn struct { mt.Peer - state clientState + cstate clientState + cstateMu sync.RWMutex name, userName string doneCh chan struct{} @@ -41,6 +42,20 @@ type contentConn struct { media []mediaFile } +func (cc *contentConn) state() clientState { + cc.cstateMu.RLock() + defer cc.cstateMu.RUnlock() + + return cc.cstate +} + +func (cc *contentConn) setState(state clientState) { + cc.cstateMu.Lock() + defer cc.cstateMu.Unlock() + + cc.cstate = state +} + func (cc *contentConn) done() <-chan struct{} { return cc.doneCh } func (cc *contentConn) log(dir, msg string) { @@ -51,7 +66,7 @@ func handleContent(cc *contentConn) { defer close(cc.doneCh) go func() { - for cc.state == csCreated { + for cc.state() == csCreated { cc.SendCmd(&mt.ToSrvInit{ SerializeVer: latestSerializeVer, MinProtoVer: latestProtoVer, @@ -84,8 +99,7 @@ func handleContent(cc *contentConn) { break } - cc.state++ - + cc.setState(cc.state() + 1) if cmd.AuthMethods&mt.FirstSRP != 0 { cc.auth.method = mt.FirstSRP } else { diff --git a/server_conn.go b/server_conn.go index 45e11fa..3e0bac5 100644 --- a/server_conn.go +++ b/server_conn.go @@ -6,6 +6,7 @@ import ( "log" "net" "strings" + "sync" "time" "github.com/HimbeerserverDE/srp" @@ -17,9 +18,10 @@ type serverConn struct { mt.Peer clt *clientConn - state clientState - name string - initCh chan struct{} + cstate clientState + cstateMu sync.RWMutex + name string + initCh chan struct{} auth struct { method mt.AuthMethods @@ -41,6 +43,20 @@ type serverConn struct { func (sc *serverConn) client() *clientConn { return sc.clt } +func (sc *serverConn) state() clientState { + sc.cstateMu.RLock() + defer sc.cstateMu.RUnlock() + + return sc.cstate +} + +func (sc *serverConn) setState(state clientState) { + sc.cstateMu.Lock() + defer sc.cstateMu.Unlock() + + sc.cstate = state +} + func (sc *serverConn) init() <-chan struct{} { return sc.initCh } func (sc *serverConn) log(dir, msg string) { @@ -57,7 +73,7 @@ func handleSrv(sc *serverConn) { } go func() { - for sc.state == csCreated && sc.client() != nil { + for sc.state() == csCreated && sc.client() != nil { sc.SendCmd(&mt.ToSrvInit{ SerializeVer: latestSerializeVer, MinProtoVer: latestProtoVer, @@ -108,8 +124,7 @@ func handleSrv(sc *serverConn) { break } - sc.state++ - + sc.setState(sc.state() + 1) if cmd.AuthMethods&mt.FirstSRP != 0 { sc.auth.method = mt.FirstSRP } else { @@ -190,7 +205,7 @@ func handleSrv(sc *serverConn) { sc.log("<--", "deny sudo") case *mt.ToCltAcceptSudoMode: sc.log("<--", "accept sudo") - sc.state++ + sc.setState(sc.state() + 1) case *mt.ToCltAnnounceMedia: sc.SendCmd(&mt.ToSrvReqMedia{}) @@ -204,7 +219,7 @@ func handleSrv(sc *serverConn) { }) sc.log("<->", "handshake completed") - sc.state++ + sc.setState(sc.state() + 1) close(sc.initCh) case *mt.ToCltInv: var oldInv mt.Inv |