aboutsummaryrefslogtreecommitdiff
path: root/plugin_srvselect.go
diff options
context:
space:
mode:
authorHimbeer <himbeer@disroot.org>2024-11-17 17:37:57 +0100
committerHimbeer <himbeer@disroot.org>2024-11-17 17:43:45 +0100
commit3bff2563fae6af73013e964e7e08109cea6fef4f (patch)
tree0a7232d98862af1ad59f9abdc01c73ef4dc17c94 /plugin_srvselect.go
parent7b69e587118c299d1601680ef70241283f30a009 (diff)
Allow plugins to override server selection when a client joins
Closes #129.
Diffstat (limited to 'plugin_srvselect.go')
-rw-r--r--plugin_srvselect.go66
1 files changed, 66 insertions, 0 deletions
diff --git a/plugin_srvselect.go b/plugin_srvselect.go
new file mode 100644
index 0000000..327efd5
--- /dev/null
+++ b/plugin_srvselect.go
@@ -0,0 +1,66 @@
+package proxy
+
+import (
+ "errors"
+ "sync"
+)
+
+var (
+ ErrEmptySrvSelectorName = errors.New("server selector name is empty")
+)
+
+var (
+ srvSelectors map[string]func(*ClientConn) (string, Server)
+ srvSelectorMu sync.RWMutex
+ srvSelectorOnce sync.Once
+)
+
+// RegisterSrvSelector registers a server selection handler
+// which can be enabled using the `SrvSelector` config option.
+// Empty names are an error.
+// If the handler returns an empty server name,
+// the regular server selection procedure is used.
+// Only one server selector can be active at a time.
+func RegisterSrvSelector(name string, sel func(*ClientConn) (string, Server)) error {
+ initSrvSelectors()
+
+ if name == "" {
+ return ErrEmptySrvSelectorName
+ }
+
+ srvSelectorMu.Lock()
+ defer srvSelectorMu.Unlock()
+
+ srvSelectors[name] = sel
+ return nil
+}
+
+func selectSrv(cc *ClientConn) (string, Server) {
+ sel := Conf().SrvSelector
+
+ if sel == "" {
+ return "", Server{}
+ }
+
+ initSrvSelectors()
+
+ srvSelectorMu.RLock()
+ defer srvSelectorMu.RUnlock()
+
+ handler, ok := srvSelectors[sel]
+ if !ok {
+ cc.Log("<-", "server selector not registered")
+ return "", Server{}
+ }
+
+ return handler(cc)
+}
+
+func initSrvSelectors() {
+ srvSelectorOnce.Do(func() {
+ srvSelectorMu.Lock()
+ defer srvSelectorMu.Unlock()
+
+ srvSelectors = make(map[string]func(*ClientConn) (string, Server))
+ })
+}