aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client_conn.go55
-rw-r--r--connect.go3
-rw-r--r--content.go22
-rw-r--r--server_conn.go31
4 files changed, 81 insertions, 30 deletions
diff --git a/client_conn.go b/client_conn.go
index 8eb5dcd..9b0309f 100644
--- a/client_conn.go
+++ b/client_conn.go
@@ -26,11 +26,13 @@ const (
type clientConn struct {
mt.Peer
srv *serverConn
+ mu sync.RWMutex
- state clientState
- name string
- initCh chan struct{}
- hopMu sync.Mutex
+ cstate clientState
+ cstateMu sync.RWMutex
+ name string
+ initCh chan struct{}
+ hopMu sync.Mutex
auth struct {
method mt.AuthMethods
@@ -58,7 +60,26 @@ type clientConn struct {
modChs map[string]struct{}
}
-func (cc *clientConn) server() *serverConn { return cc.srv }
+func (cc *clientConn) server() *serverConn {
+ cc.mu.RLock()
+ defer cc.mu.RUnlock()
+
+ return cc.srv
+}
+
+func (cc *clientConn) state() clientState {
+ cc.cstateMu.RLock()
+ defer cc.cstateMu.RUnlock()
+
+ return cc.cstate
+}
+
+func (cc *clientConn) setState(state clientState) {
+ cc.cstateMu.Lock()
+ defer cc.cstateMu.Unlock()
+
+ cc.cstate = state
+}
func (cc *clientConn) init() <-chan struct{} { return cc.initCh }
@@ -102,13 +123,12 @@ func handleClt(cc *clientConn) {
switch cmd := pkt.Cmd.(type) {
case *mt.ToSrvInit:
- if cc.state > csCreated {
+ if cc.state() > csCreated {
cc.log("-->", "duplicate init")
break
}
- cc.state = csInit
-
+ cc.setState(csInit)
if cmd.SerializeVer != latestSerializeVer {
cc.log("<--", "invalid serializeVer")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnsupportedVer})
@@ -223,7 +243,7 @@ func handleClt(cc *clientConn) {
Username: cc.name,
})
case *mt.ToSrvFirstSRP:
- if cc.state == csInit {
+ if cc.state() == csInit {
if cc.auth.method != mt.FirstSRP {
cc.log("-->", "unauthorized password change")
ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.UnexpectedData})
@@ -276,13 +296,12 @@ func handleClt(cc *clientConn) {
SudoAuthMethods: mt.SRP,
})
} else {
- if cc.state < csSudo {
+ if cc.state() < csSudo {
cc.log("-->", "unauthorized sudo action")
break
}
- cc.state--
-
+ cc.setState(cc.state() - 1)
if err := authIface.SetPasswd(cc.name, cmd.Salt, cmd.Verifier); err != nil {
cc.log("<--", "change password fail")
cc.SendCmd(&mt.ToCltChatMsg{
@@ -301,9 +320,9 @@ func handleClt(cc *clientConn) {
})
}
case *mt.ToSrvSRPBytesA:
- wantSudo := cc.state == csActive
+ wantSudo := cc.state() == csActive
- if cc.state != csInit && cc.state != csActive {
+ if cc.state() != csInit && cc.state() != csActive {
cc.log("-->", "unexpected authentication")
break
}
@@ -368,9 +387,9 @@ func handleClt(cc *clientConn) {
B: cc.auth.srpB,
})
case *mt.ToSrvSRPBytesM:
- wantSudo := cc.state == csActive
+ wantSudo := cc.state() == csActive
- if cc.state != csInit && cc.state != csActive {
+ if cc.state() != csInit && cc.state() != csActive {
cc.log("-->", "unexpected authentication")
break
}
@@ -401,7 +420,7 @@ func handleClt(cc *clientConn) {
}{}
if wantSudo {
- cc.state++
+ cc.setState(cc.state() + 1)
cc.SendCmd(&mt.ToCltAcceptSudoMode{})
} else {
cc.SendCmd(&mt.ToCltAcceptAuth{
@@ -485,7 +504,7 @@ func handleClt(cc *clientConn) {
cc.versionStr = cmd.Version
cc.formspecVer = cmd.Formspec
- cc.state++
+ cc.setState(cc.state() + 1)
close(cc.initCh)
case *mt.ToSrvInteract:
if cc.server() == nil {
diff --git a/connect.go b/connect.go
index 9ad3a1e..03f9243 100644
--- a/connect.go
+++ b/connect.go
@@ -19,7 +19,10 @@ func connect(conn net.Conn, name string, cc *clientConn) *serverConn {
playerList: make(map[string]struct{}),
}
sc.log("-->", "connect")
+
+ cc.mu.Lock()
cc.srv = sc
+ cc.mu.Unlock()
go handleSrv(sc)
return sc
diff --git a/content.go b/content.go
index a8a9381..079739e 100644
--- a/content.go
+++ b/content.go
@@ -24,7 +24,8 @@ type mediaFile struct {
type contentConn struct {
mt.Peer
- state clientState
+ cstate clientState
+ cstateMu sync.RWMutex
name, userName string
doneCh chan struct{}
@@ -41,6 +42,20 @@ type contentConn struct {
media []mediaFile
}
+func (cc *contentConn) state() clientState {
+ cc.cstateMu.RLock()
+ defer cc.cstateMu.RUnlock()
+
+ return cc.cstate
+}
+
+func (cc *contentConn) setState(state clientState) {
+ cc.cstateMu.Lock()
+ defer cc.cstateMu.Unlock()
+
+ cc.cstate = state
+}
+
func (cc *contentConn) done() <-chan struct{} { return cc.doneCh }
func (cc *contentConn) log(dir, msg string) {
@@ -51,7 +66,7 @@ func handleContent(cc *contentConn) {
defer close(cc.doneCh)
go func() {
- for cc.state == csCreated {
+ for cc.state() == csCreated {
cc.SendCmd(&mt.ToSrvInit{
SerializeVer: latestSerializeVer,
MinProtoVer: latestProtoVer,
@@ -84,8 +99,7 @@ func handleContent(cc *contentConn) {
break
}
- cc.state++
-
+ cc.setState(cc.state() + 1)
if cmd.AuthMethods&mt.FirstSRP != 0 {
cc.auth.method = mt.FirstSRP
} else {
diff --git a/server_conn.go b/server_conn.go
index 45e11fa..3e0bac5 100644
--- a/server_conn.go
+++ b/server_conn.go
@@ -6,6 +6,7 @@ import (
"log"
"net"
"strings"
+ "sync"
"time"
"github.com/HimbeerserverDE/srp"
@@ -17,9 +18,10 @@ type serverConn struct {
mt.Peer
clt *clientConn
- state clientState
- name string
- initCh chan struct{}
+ cstate clientState
+ cstateMu sync.RWMutex
+ name string
+ initCh chan struct{}
auth struct {
method mt.AuthMethods
@@ -41,6 +43,20 @@ type serverConn struct {
func (sc *serverConn) client() *clientConn { return sc.clt }
+func (sc *serverConn) state() clientState {
+ sc.cstateMu.RLock()
+ defer sc.cstateMu.RUnlock()
+
+ return sc.cstate
+}
+
+func (sc *serverConn) setState(state clientState) {
+ sc.cstateMu.Lock()
+ defer sc.cstateMu.Unlock()
+
+ sc.cstate = state
+}
+
func (sc *serverConn) init() <-chan struct{} { return sc.initCh }
func (sc *serverConn) log(dir, msg string) {
@@ -57,7 +73,7 @@ func handleSrv(sc *serverConn) {
}
go func() {
- for sc.state == csCreated && sc.client() != nil {
+ for sc.state() == csCreated && sc.client() != nil {
sc.SendCmd(&mt.ToSrvInit{
SerializeVer: latestSerializeVer,
MinProtoVer: latestProtoVer,
@@ -108,8 +124,7 @@ func handleSrv(sc *serverConn) {
break
}
- sc.state++
-
+ sc.setState(sc.state() + 1)
if cmd.AuthMethods&mt.FirstSRP != 0 {
sc.auth.method = mt.FirstSRP
} else {
@@ -190,7 +205,7 @@ func handleSrv(sc *serverConn) {
sc.log("<--", "deny sudo")
case *mt.ToCltAcceptSudoMode:
sc.log("<--", "accept sudo")
- sc.state++
+ sc.setState(sc.state() + 1)
case *mt.ToCltAnnounceMedia:
sc.SendCmd(&mt.ToSrvReqMedia{})
@@ -204,7 +219,7 @@ func handleSrv(sc *serverConn) {
})
sc.log("<->", "handshake completed")
- sc.state++
+ sc.setState(sc.state() + 1)
close(sc.initCh)
case *mt.ToCltInv:
var oldInv mt.Inv