diff options
-rw-r--r-- | client_conn.go | 5 | ||||
-rw-r--r-- | content.go | 12 | ||||
-rw-r--r-- | listen.go | 17 | ||||
-rw-r--r-- | main.go | 27 | ||||
-rw-r--r-- | server_conn.go | 22 |
5 files changed, 58 insertions, 25 deletions
diff --git a/client_conn.go b/client_conn.go index cec49df..29db611 100644 --- a/client_conn.go +++ b/client_conn.go @@ -122,8 +122,11 @@ func handleClt(cc *clientConn) { if cc.server() != nil { cc.server().Close() - cc.mu.Lock() + cc.server().mu.Lock() cc.server().clt = nil + cc.server().mu.Unlock() + + cc.mu.Lock() cc.srv = nil cc.mu.Unlock() } @@ -77,6 +77,18 @@ func handleContent(cc *contentConn) { defer close(cc.doneCh) go func() { + init := make(chan struct{}) + defer close(init) + + go func(init <-chan struct{}) { + select { + case <-init: + case <-time.After(10 * time.Second): + cc.log("-->", "timeout") + cc.Close() + } + }(init) + for cc.state() == csCreated { cc.SendCmd(&mt.ToSrvInit{ SerializeVer: latestSerializeVer, @@ -3,17 +3,22 @@ package main import ( "fmt" "net" + "sync" "github.com/anon55555/mt" ) type listener struct { mtListener mt.Listener + mu sync.Mutex + + clts map[*clientConn]struct{} } func listen(pc net.PacketConn) *listener { return &listener{ mtListener: mt.Listen(pc), + clts: make(map[*clientConn]struct{}), } } @@ -35,6 +40,18 @@ func (l *listener) accept() (*clientConn, error) { modChs: make(map[string]struct{}), } + l.mu.Lock() + l.clts[cc] = struct{}{} + l.mu.Unlock() + + go func() { + <-cc.Closed() + l.mu.Lock() + defer l.mu.Unlock() + + delete(l.clts, cc) + }() + cc.log("-->", "connect") go handleClt(cc) @@ -40,21 +40,18 @@ func main() { log.Print("{←|⇶} listen ", l.addr()) - clts := make(map[*clientConn]struct{}) - var mu sync.Mutex - go func() { sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) <-sig - mu.Lock() - defer mu.Unlock() + l.mu.Lock() + defer l.mu.Unlock() var wg sync.WaitGroup - wg.Add(len(clts)) + wg.Add(len(l.clts)) - for cc := range clts { + for cc := range l.clts { go func(cc *clientConn) { ack, _ := cc.SendCmd(&mt.ToCltDisco{Reason: mt.Shutdown}) select { @@ -64,7 +61,10 @@ func main() { } <-cc.server().Closed() + cc.mu.Lock() cc.srv = nil + cc.mu.Unlock() + wg.Done() }(cc) } @@ -85,19 +85,6 @@ func main() { continue } - mu.Lock() - clts[cc] = struct{}{} - mu.Unlock() - - go func() { - <-cc.Closed() - - mu.Lock() - defer mu.Unlock() - - delete(clts, cc) - }() - go func() { <-cc.init() cc.log("<->", "handshake completed") diff --git a/server_conn.go b/server_conn.go index ec53b6a..673429f 100644 --- a/server_conn.go +++ b/server_conn.go @@ -83,11 +83,19 @@ func (sc *serverConn) log(dir string, v ...interface{}) { } func handleSrv(sc *serverConn) { - if sc.client() == nil { - sc.log("-->", "no associated client") - } - go func() { + init := make(chan struct{}) + defer close(init) + + go func(init <-chan struct{}) { + select { + case <-init: + case <-time.After(10 * time.Second): + sc.log("-->", "timeout") + sc.Close() + } + }(init) + for sc.state() == csCreated && sc.client() != nil { sc.SendCmd(&mt.ToSrvInit{ SerializeVer: latestSerializeVer, @@ -119,8 +127,14 @@ func handleSrv(sc *serverConn) { case <-sc.client().Closed(): case <-ack: sc.client().Close() + + sc.client().mu.Lock() sc.client().srv = nil + sc.client().mu.Unlock() + + sc.mu.Lock() sc.clt = nil + sc.mu.Unlock() } } |