From a5e6fa72ffb7a43048c996e175e12116fe4c9fc1 Mon Sep 17 00:00:00 2001 From: ChronosX88 Date: Mon, 10 Feb 2020 13:23:10 +0400 Subject: [PATCH] Implement basic module system, auth manager and connection handler --- cmd/zr/main.go | 29 ++--------- go.mod | 2 + go.sum | 6 +++ internal/auth_manager.go | 79 ++++++++++++++++++++++++++++ internal/connection_handler.go | 39 ++++++++++++++ internal/globals.go | 36 +++++++++++++ internal/handler.go | 1 - internal/models/base_message.go | 27 ++++++++++ internal/models/entity_id.go | 50 ++++++++++++++++++ internal/models/protocol_error.go | 7 +++ internal/module_manager.go | 86 +++++++++++++++++++++++++++++++ internal/origin_c2s.go | 17 ++++++ internal/router.go | 69 +++++++++++++++++++++++++ internal/utils.go | 50 ++++++++++++++++++ plugins/.gitkeep | 0 shared/module.go | 5 ++ shared/plugin_map.go | 5 ++ 17 files changed, 481 insertions(+), 27 deletions(-) create mode 100644 internal/auth_manager.go create mode 100644 internal/connection_handler.go create mode 100644 internal/globals.go delete mode 100644 internal/handler.go create mode 100644 internal/models/base_message.go create mode 100644 internal/models/entity_id.go create mode 100644 internal/models/protocol_error.go create mode 100644 internal/module_manager.go create mode 100644 internal/origin_c2s.go create mode 100644 internal/router.go create mode 100644 internal/utils.go create mode 100644 plugins/.gitkeep create mode 100644 shared/module.go create mode 100644 shared/plugin_map.go diff --git a/cmd/zr/main.go b/cmd/zr/main.go index 8d30f05..2285e87 100644 --- a/cmd/zr/main.go +++ b/cmd/zr/main.go @@ -4,13 +4,12 @@ import ( "log" "net/http" - "github.com/google/uuid" + "github.com/ChronosX88/zirconium/internal" "github.com/gorilla/mux" "github.com/gorilla/websocket" ) -var clients = make(map[*websocket.Conn]string) -var clientsReverse = make(map[string]*websocket.Conn) +var connectionHandler = internal.NewConnectionHandler() var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -33,27 +32,5 @@ func wsHandler(w http.ResponseWriter, r *http.Request) { if err != nil { log.Fatal(err) } - - // register client - clients[ws] = uuid.New().String() - clientsReverse[clients[ws]] = ws - go readLoop(ws) - log.Printf("Connection %s created!", clients[ws]) -} - -func readLoop(c *websocket.Conn) { - for { - if _, _, err := c.NextReader(); err != nil { - connectionID := clients[c] - if connectionID != "" { - delete(clients, c) - delete(clientsReverse, connectionID) - log.Printf("Connection %s closed!", connectionID) - } else { - log.Println("connection wasn't found") - } - c.Close() - break - } - } + connectionHandler.HandleNewConnection(ws) } diff --git a/go.mod b/go.mod index 3817981..a42a13d 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/ChronosX88/zirconium go 1.13 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/google/logger v1.0.1 github.com/google/uuid v1.1.1 github.com/gorilla/mux v1.7.3 github.com/gorilla/websocket v1.4.1 diff --git a/go.sum b/go.sum index dbe64b7..5c6bb36 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/google/logger v1.0.1 h1:Jtq7/44yDwUXMaLTYgXFC31zpm6Oku7OI/k4//yVANQ= +github.com/google/logger v1.0.1/go.mod h1:w7O8nrRr0xufejBlQMI83MXqRusvREoJdaAxV+CoAB4= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/auth_manager.go b/internal/auth_manager.go new file mode 100644 index 0000000..91a8282 --- /dev/null +++ b/internal/auth_manager.go @@ -0,0 +1,79 @@ +package internal + +import ( + "encoding/base64" + "fmt" + "time" + + "github.com/dgrijalva/jwt-go" +) + +const ( + SigningKeyBytesAmount = 4096 + TokenExpireTimeDuration = 24 * time.Hour +) + +type AuthManager struct { + signingKey string // For now it is random bytes string represented in Base64 +} + +type JWTCustomClaims struct { + EntityID string `json:"entityID"` + DeviceID string `json:"deviceID"` + jwt.StandardClaims +} + +func NewAuthManager() (*AuthManager, error) { + am := &AuthManager{} + bytes, err := GenRandomBytes(SigningKeyBytesAmount) + if err != nil { + return nil, err + } + am.signingKey = base64.RawStdEncoding.EncodeToString(bytes) + return am, nil +} + +func (am *AuthManager) CreateNewToken(entityID, deviceID string, tokenExpireTimeDuration time.Duration) (string, error) { + timeNow := time.Now() + expiringTime := timeNow.Add(tokenExpireTimeDuration) + claims := JWTCustomClaims{ + entityID, + deviceID, + jwt.StandardClaims{ + ExpiresAt: time.Date( + expiringTime.Year(), + expiringTime.Month(), + expiringTime.Day(), + expiringTime.Hour(), + expiringTime.Minute(), + expiringTime.Second(), + expiringTime.Nanosecond(), + time.UTC, + ).Unix(), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(am.signingKey) + if err != nil { + return "", err + } + return tokenString, nil +} + +func (am *AuthManager) ValidateToken(tokenString string) (bool, string, string, error) { + token, err := jwt.ParseWithClaims(tokenString, &JWTCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + // hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key") + return []byte(am.signingKey), nil + }) + + if claims, ok := token.Claims.(*JWTCustomClaims); ok && token.Valid { + return true, claims.EntityID, claims.DeviceID, nil + } + return false, "", "", err +} diff --git a/internal/connection_handler.go b/internal/connection_handler.go new file mode 100644 index 0000000..b5f1c42 --- /dev/null +++ b/internal/connection_handler.go @@ -0,0 +1,39 @@ +package internal + +import ( + "github.com/ChronosX88/zirconium/internal/models" + "github.com/google/logger" + "github.com/google/uuid" + "github.com/gorilla/websocket" +) + +type ConnectionHandler struct { + connections map[string]*OriginC2S +} + +func NewConnectionHandler() *ConnectionHandler { + return &ConnectionHandler{} +} + +func (ch *ConnectionHandler) HandleNewConnection(wsocket *websocket.Conn) { + uuid, _ := uuid.NewRandom() + uuidStr := uuid.String() + o := &OriginC2S{ + wsConn: wsocket, + connID: uuidStr, + } + ch.connections[o.connID] = o + go func() { + for { + var msg models.BaseMessage + err := o.wsConn.ReadJSON(&msg) + if err != nil { + delete(ch.connections, o.connID) + o.wsConn.Close() + logger.Infof("Connection %s was closed. (Reason: %s)", o.connID, err.Error()) + break + } + router.RouteMessage(o, msg) + } + }() +} diff --git a/internal/globals.go b/internal/globals.go new file mode 100644 index 0000000..c7f266c --- /dev/null +++ b/internal/globals.go @@ -0,0 +1,36 @@ +package internal + +import ( + "github.com/ChronosX88/zirconium/shared" + "github.com/google/logger" +) + +var ( + ModuleMgr *ModuleManager + router *Router + authManager *AuthManager + serverDomain string +) + +func InitializeContext(sDomain string) { + var err error + ModuleMgr, err = NewModuleManager() + if err != nil { + logger.Fatalf("Unable to initialize module manager: %s", err.Error()) + } + + router, err = NewRouter() + if err != nil { + logger.Fatalf("Unable to initialize router: %s", err.Error()) + } + + authManager, err = NewAuthManager() + if err != nil { + logger.Fatalf("Unable to initialize authentication manager: %s", err.Error()) + } + serverDomain = sDomain + + for _, v := range shared.Plugins { + go v.Initialize() // Initialize provided plugins + } +} diff --git a/internal/handler.go b/internal/handler.go deleted file mode 100644 index 5bf0569..0000000 --- a/internal/handler.go +++ /dev/null @@ -1 +0,0 @@ -package internal diff --git a/internal/models/base_message.go b/internal/models/base_message.go new file mode 100644 index 0000000..2225433 --- /dev/null +++ b/internal/models/base_message.go @@ -0,0 +1,27 @@ +package models + +import "github.com/google/uuid" + +// BaseMessage is a basic message model, basis of the whole protocol. It is used for a very easy protocol extension process. +type BaseMessage struct { + ID string `json:"id"` + MessageType string `json:"type"` + From string `json:"from"` + To string `json:"to"` + Ok bool `json:"ok"` + AuthToken string `json:"authToken"` + Payload map[string]interface{} `json:"payload"` +} + +func NewBaseMessage(messageType string, from string, to string, ok bool, payload map[string]interface{}) BaseMessage { + uuid, _ := uuid.NewRandom() + uuidStr := uuid.String() + return BaseMessage{ + ID: uuidStr, + MessageType: messageType, + From: from, + To: to, + Ok: ok, + Payload: payload, + } +} diff --git a/internal/models/entity_id.go b/internal/models/entity_id.go new file mode 100644 index 0000000..9688f80 --- /dev/null +++ b/internal/models/entity_id.go @@ -0,0 +1,50 @@ +package models + +import ( + "fmt" + "strings" +) + +type EntityIDType string + +const ( + UsernameType EntityIDType = "@" + RoomAliasType EntityIDType = "#" + RoomIDType EntityIDType = "!" +) + +type EntityID struct { + EntityIDType EntityIDType + LocalPart string + ServerPart string +} + +func NewEntityID(entityID string) *EntityID { + eID := &EntityID{} + switch EntityIDType(string(entityID[0])) { + case UsernameType: + { + eID.EntityIDType = UsernameType + } + case RoomAliasType: + { + eID.EntityIDType = RoomAliasType + } + case RoomIDType: + { + eID.EntityIDType = RoomIDType + } + } + localAndServerPart := strings.Split(entityID, "@") + if len(localAndServerPart) == 3 { + localAndServerPart = localAndServerPart[1:] + } + eID.LocalPart = localAndServerPart[0] + eID.ServerPart = localAndServerPart[1] + + return eID +} + +func (eID *EntityID) String() string { + return fmt.Sprintf("%s%s@%s", eID.EntityIDType, eID.LocalPart, eID.ServerPart) +} diff --git a/internal/models/protocol_error.go b/internal/models/protocol_error.go new file mode 100644 index 0000000..604fd28 --- /dev/null +++ b/internal/models/protocol_error.go @@ -0,0 +1,7 @@ +package models + +type ProtocolError struct { + ErrCode string `json:"errCode"` + ErrText string `json:"errText"` + ErrPayload map[string]interface{} `json:"errPayload"` +} diff --git a/internal/module_manager.go b/internal/module_manager.go new file mode 100644 index 0000000..63d3a11 --- /dev/null +++ b/internal/module_manager.go @@ -0,0 +1,86 @@ +package internal + +import ( + "reflect" + "sync" + "time" + + "github.com/ChronosX88/zirconium/internal/models" +) + +type C2SMessageHandler struct { + HandlerFunc func(origin *OriginC2S, message models.BaseMessage) + AnonymousAllowed bool +} + +type ModuleManager struct { + moduleMutex sync.Mutex + c2sMessageHandlers map[string][]*C2SMessageHandler + internalEventHandlers map[string][]func(sourceModuleName string, event map[string]interface{}) +} + +func NewModuleManager() (*ModuleManager, error) { + var mm = &ModuleManager{ + c2sMessageHandlers: make(map[string][]*C2SMessageHandler), + internalEventHandlers: make(map[string][]func(sourceModuleName string, event map[string]interface{})), + } + return mm, nil +} + +func (mm *ModuleManager) Hook(messageType string, anonymousAllowed bool, handlerFunc func(origin *OriginC2S, message models.BaseMessage)) { + mm.moduleMutex.Lock() + mm.c2sMessageHandlers[messageType] = append(mm.c2sMessageHandlers[messageType], &C2SMessageHandler{ + HandlerFunc: handlerFunc, + AnonymousAllowed: anonymousAllowed, + }) + mm.moduleMutex.Unlock() +} + +func (mm *ModuleManager) HookInternalEvent(eventName string, handlerFunc func(sourceModuleName string, event map[string]interface{})) { + mm.moduleMutex.Lock() + mm.internalEventHandlers[eventName] = append(mm.internalEventHandlers[eventName], handlerFunc) + mm.moduleMutex.Unlock() +} + +func (mm *ModuleManager) Unhook(messageType string, handlerFunc func(origin *OriginC2S, message models.BaseMessage)) { + mm.moduleMutex.Lock() + defer mm.moduleMutex.Unlock() + var handlers = mm.c2sMessageHandlers[messageType] + if handlers != nil { + for i, v := range handlers { + if reflect.ValueOf(v.HandlerFunc).Pointer() == reflect.ValueOf(handlerFunc).Pointer() { + handlers[i] = handlers[len(handlers)-1] + handlers[len(handlers)-1] = nil + handlers = handlers[:len(handlers)-1] + mm.c2sMessageHandlers[messageType] = handlers + break + } + } + } +} + +func (mm *ModuleManager) UnhookInternalEvent(eventName string, handlerFunc func(sourceModuleName string, event map[string]interface{})) { + mm.moduleMutex.Lock() + defer mm.moduleMutex.Unlock() + var handlers = mm.internalEventHandlers[eventName] + if handlers != nil { + for i, v := range handlers { + if reflect.ValueOf(v).Pointer() == reflect.ValueOf(handlerFunc).Pointer() { + handlers[i] = handlers[len(handlers)-1] + handlers[len(handlers)-1] = nil + handlers = handlers[:len(handlers)-1] + mm.internalEventHandlers[eventName] = handlers + break + } + } + } +} + +func (mm *ModuleManager) FireEvent(sourceModuleName string, eventName string, eventPayload map[string]interface{}) { + router.RouteInternalEvent(sourceModuleName, eventName, eventPayload) +} + +func (mm *ModuleManager) GenerateToken(entityID, deviceID string, tokenExpireTimeDuration time.Duration) (string, error) { + token, err := authManager.CreateNewToken(entityID, deviceID, tokenExpireTimeDuration) + return token, err +} diff --git a/internal/origin_c2s.go b/internal/origin_c2s.go new file mode 100644 index 0000000..7e32764 --- /dev/null +++ b/internal/origin_c2s.go @@ -0,0 +1,17 @@ +package internal + +import ( + "github.com/ChronosX88/zirconium/internal/models" + "github.com/gorilla/websocket" +) + +type OriginC2S struct { + wsConn *websocket.Conn + connID string + entityID *models.EntityID + deviceID *string +} + +func (o *OriginC2S) Send(message models.BaseMessage) error { + return o.wsConn.WriteJSON(message) +} diff --git a/internal/router.go b/internal/router.go new file mode 100644 index 0000000..e6f88ae --- /dev/null +++ b/internal/router.go @@ -0,0 +1,69 @@ +package internal + +import ( + "github.com/ChronosX88/zirconium/internal/models" + "github.com/google/logger" +) + +type Router struct { + moduleManager *ModuleManager + connections []*OriginC2S +} + +func NewRouter() (*Router, error) { + mm, err := NewModuleManager() + if err != nil { + return nil, err + } + r := &Router{ + moduleManager: mm, + } + return r, nil +} + +func (r *Router) RouteMessage(origin *OriginC2S, message models.BaseMessage) { + handlers := r.moduleManager.c2sMessageHandlers[message.MessageType] + if handlers != nil { + for _, v := range handlers { + if !v.AnonymousAllowed { + var entityID, deviceID string + var isValid bool + var err error + if message.AuthToken != "" { + isValid, entityID, deviceID, err = authManager.ValidateToken(message.AuthToken) + if err != nil || !isValid { + logger.Warningf("Connection %s isn't authorized", origin.connID) + msg := PrepareMessageUnauthorized(message) + origin.Send(msg) + } + } else { + logger.Warningf("Connection %s isn't authorized", origin.connID) + + msg := PrepareMessageUnauthorized(message) + origin.Send(msg) + } + + if origin.entityID == nil { + origin.entityID = models.NewEntityID(entityID) + } + if origin.deviceID == nil { + origin.deviceID = &deviceID + } + } + go v.HandlerFunc(origin, message) + } + } else { + logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType) + } +} + +func (r *Router) RouteInternalEvent(sourceModuleName string, eventName string, eventPayload map[string]interface{}) { + handlers := r.moduleManager.internalEventHandlers[eventName] + if handlers != nil { + for _, v := range handlers { + go v(sourceModuleName, eventPayload) + } + } else { + logger.Infof("Drop event %s because server hasn't proper handlers", eventName) + } +} diff --git a/internal/utils.go b/internal/utils.go new file mode 100644 index 0000000..98270f4 --- /dev/null +++ b/internal/utils.go @@ -0,0 +1,50 @@ +package internal + +import ( + "crypto/rand" + "reflect" + + "github.com/ChronosX88/zirconium/internal/models" +) + +func GenRandomBytes(size int) (blk []byte, err error) { + blk = make([]byte, size) + _, err = rand.Read(blk) + return +} + +func StructToMap(item interface{}) map[string]interface{} { + res := map[string]interface{}{} + if item == nil { + return res + } + v := reflect.TypeOf(item) + reflectValue := reflect.ValueOf(item) + reflectValue = reflect.Indirect(reflectValue) + + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + for i := 0; i < v.NumField(); i++ { + tag := v.Field(i).Tag.Get("json") + field := reflectValue.Field(i).Interface() + if tag != "" && tag != "-" { + if v.Field(i).Type.Kind() == reflect.Struct { + res[tag] = StructToMap(field) + } else { + res[tag] = field + } + } + } + return res +} + +func PrepareMessageUnauthorized(msg models.BaseMessage) models.BaseMessage { + protocolError := models.ProtocolError{ + ErrCode: "unauthorized", + ErrText: "Unauthorized access", + ErrPayload: make(map[string]interface{}), + } + errMsg := models.NewBaseMessage(msg.MessageType, serverDomain, msg.From, false, StructToMap(protocolError)) + return errMsg +} diff --git a/plugins/.gitkeep b/plugins/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/shared/module.go b/shared/module.go new file mode 100644 index 0000000..9dd939d --- /dev/null +++ b/shared/module.go @@ -0,0 +1,5 @@ +package shared + +type Module interface { + Initialize() +} diff --git a/shared/plugin_map.go b/shared/plugin_map.go new file mode 100644 index 0000000..3572fa4 --- /dev/null +++ b/shared/plugin_map.go @@ -0,0 +1,5 @@ +package shared + +var Plugins = map[string]Module{ + // Add plugins here +}