diff options
Diffstat (limited to 'client_conn.go')
-rw-r--r-- | client_conn.go | 55 |
1 files changed, 37 insertions, 18 deletions
diff --git a/client_conn.go b/client_conn.go index 8eb5dcd..9b0309f 100644 --- a/client_conn.go +++ b/client_conn.go @@ -26,11 +26,13 @@ const ( type clientConn struct { mt.Peer srv *serverConn + mu sync.RWMutex - state clientState - name string - initCh chan struct{} - hopMu sync.Mutex + cstate clientState + cstateMu sync.RWMutex + name string + initCh chan struct{} + hopMu sync.Mutex auth struct { method mt.AuthMethods @@ -58,7 +60,26 @@ type clientConn struct { modChs map[string]struct{} } -func (cc *clientConn) server() *serverConn { return cc.srv } +func (cc *clientConn) server() *serverConn { + cc.mu.RLock() + defer cc.mu.RUnlock() + + return cc.srv +} + +func (cc *clientConn) state() clientState { + cc.cstateMu.RLock() + defer cc.cstateMu.RUnlock() + + return cc.cstate +} + +func (cc *clientConn) setState(state clientState) { + cc.cstateMu.Lock() + defer cc.cstateMu.Unlock() + + cc.cstate = state +} func (cc *clientConn) init() <-chan struct{} { return cc.initCh } @@ -102,13 +123,12 @@ func handleClt(cc *clientConn) { switch cmd := pkt.Cmd.(type) { case *mt.ToSrvInit: - if cc.state > csCreated { + if cc.state() > csCreated { cc.log("-->", "duplicate init") break } - cc.state = csInit - + cc.setState(csInit) if cmd.SerializeVer != latestSerializeVer { cc.log("<--", "invalid serializeVer") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnsupportedVer}) @@ -223,7 +243,7 @@ func handleClt(cc *clientConn) { Username: cc.name, }) case *mt.ToSrvFirstSRP: - if cc.state == csInit { + if cc.state() == csInit { if cc.auth.method != mt.FirstSRP { cc.log("-->", "unauthorized password change") ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData}) @@ -276,13 +296,12 @@ func handleClt(cc *clientConn) { SudoAuthMethods: mt.SRP, }) } else { - if cc.state < csSudo { + if cc.state() < csSudo { cc.log("-->", "unauthorized sudo action") break } - cc.state-- - + cc.setState(cc.state() - 1) if err := authIface.SetPasswd(cc.name, cmd.Salt, cmd.Verifier); err != nil { cc.log("<--", "change password fail") cc.SendCmd(&mt.ToCltChatMsg{ @@ -301,9 +320,9 @@ func handleClt(cc *clientConn) { }) } case *mt.ToSrvSRPBytesA: - wantSudo := cc.state == csActive + wantSudo := cc.state() == csActive - if cc.state != csInit && cc.state != csActive { + if cc.state() != csInit && cc.state() != csActive { cc.log("-->", "unexpected authentication") break } @@ -368,9 +387,9 @@ func handleClt(cc *clientConn) { B: cc.auth.srpB, }) case *mt.ToSrvSRPBytesM: - wantSudo := cc.state == csActive + wantSudo := cc.state() == csActive - if cc.state != csInit && cc.state != csActive { + if cc.state() != csInit && cc.state() != csActive { cc.log("-->", "unexpected authentication") break } @@ -401,7 +420,7 @@ func handleClt(cc *clientConn) { }{} if wantSudo { - cc.state++ + cc.setState(cc.state() + 1) cc.SendCmd(&mt.ToCltAcceptSudoMode{}) } else { cc.SendCmd(&mt.ToCltAcceptAuth{ @@ -485,7 +504,7 @@ func handleClt(cc *clientConn) { cc.versionStr = cmd.Version cc.formspecVer = cmd.Formspec - cc.state++ + cc.setState(cc.state() + 1) close(cc.initCh) case *mt.ToSrvInteract: if cc.server() == nil { |