From 39e08edbc98deadd55a49767f336191beb503c19 Mon Sep 17 00:00:00 2001 From: ChronosX88 Date: Sun, 25 Apr 2021 00:56:20 +0300 Subject: [PATCH] Implement connection negotiation and parsing serverpart for EntityID --- core/app_context.go | 2 +- core/auth_handler.go | 6 +-- core/models/base_message.go | 17 +++++++-- core/models/entity_id.go | 22 +++++++---- core/models/entity_id_test.go | 12 ++++++ core/models/protocol_error.go | 2 +- core/router.go | 4 +- core/session.go | 6 +++ core/session_manager.go | 71 +++++++++++++++++++++++++++++------ core/utils/utils.go | 23 ++++++------ 10 files changed, 125 insertions(+), 40 deletions(-) diff --git a/core/app_context.go b/core/app_context.go index 06c1af0..1b52bfc 100644 --- a/core/app_context.go +++ b/core/app_context.go @@ -44,7 +44,7 @@ func NewAppContext(cfg *Config) *AppContext { } appContext.authManager = authManager - sessionManager := NewSessionManager(router) + sessionManager := NewSessionManager(router, cfg.ServerDomains) appContext.sessionManager = sessionManager wss := NewWebsocketServer(cfg, sessionManager) diff --git a/core/auth_handler.go b/core/auth_handler.go index eadcdea..456c7c2 100644 --- a/core/auth_handler.go +++ b/core/auth_handler.go @@ -36,7 +36,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) { token, claims, err := ah.authManager.HandleSimpleAuth(authRequest.Fields["username"].(string), authRequest.Fields["password"].(string)) if err != nil { if mongo.ErrNoDocuments == err { - msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", ah.serverID) + msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", message.To) _ = s.Send(msg) return } @@ -50,7 +50,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) { DeviceID: claims.DeviceID, } payload := structs.Map(ar) - msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, payload) + msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, payload) _ = s.Send(msg) s.Claims = claims return @@ -65,7 +65,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) { return } s.Claims = claims - msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, nil) + msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, nil) _ = s.Send(msg) return } diff --git a/core/models/base_message.go b/core/models/base_message.go index c9fad8d..e98c309 100644 --- a/core/models/base_message.go +++ b/core/models/base_message.go @@ -4,13 +4,13 @@ package models type BaseMessage struct { ID string `json:"id"` MessageType string `json:"type"` - From string `json:"from"` - To []string `json:"to,omitempty"` + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` Ok bool `json:"ok"` Payload map[string]interface{} `json:"payload"` } -func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload map[string]interface{}) BaseMessage { +func NewBaseMessage(id, messageType, from string, to string, ok bool, payload map[string]interface{}) BaseMessage { return BaseMessage{ ID: id, MessageType: messageType, @@ -20,3 +20,14 @@ func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload Payload: payload, } } + +func NewEmptyBaseMessage() BaseMessage { + return BaseMessage{ + ID: "", + MessageType: "", + From: "", + To: "", + Ok: true, + Payload: map[string]interface{}{}, + } +} diff --git a/core/models/entity_id.go b/core/models/entity_id.go index cec2044..5881772 100644 --- a/core/models/entity_id.go +++ b/core/models/entity_id.go @@ -16,16 +16,26 @@ const ( ) type EntityID struct { - Type EntityIDType - LocalPart string - ServerPart string - Attr string + Type EntityIDType + LocalPart string + ServerPart string + Attr string + OnlyServerPart bool } func NewEntityIDFromString(entityID string) (*EntityID, error) { eid := &EntityID{} typ := string(entityID[0]) withAttr := false + localAndServerPart := strings.Split(entityID, "@") + if localAndServerPart[0] == "" { + localAndServerPart = localAndServerPart[1:] + } + if len(localAndServerPart) == 1 { + eid.ServerPart = localAndServerPart[0] + eid.OnlyServerPart = true + return eid, nil + } switch EntityIDType(typ) { case UsernameType: @@ -47,10 +57,6 @@ func NewEntityIDFromString(entityID string) (*EntityID, error) { return nil, fmt.Errorf("invalid entity id type: %s", typ) } - localAndServerPart := strings.Split(entityID, "@") - if len(localAndServerPart) == 3 && localAndServerPart[0] == "" { - localAndServerPart = localAndServerPart[0:] - } if !withAttr { eid.LocalPart = localAndServerPart[0] eid.ServerPart = localAndServerPart[1] diff --git a/core/models/entity_id_test.go b/core/models/entity_id_test.go index d03544c..08e2d36 100644 --- a/core/models/entity_id_test.go +++ b/core/models/entity_id_test.go @@ -37,3 +37,15 @@ func TestNewEntityIDFromStringWithEmailAttr(t *testing.T) { t.Fatal(eid.String()) } } + +func TestNewEntityIDFromStringWithOnlyServerPart(t *testing.T) { + str := "cadmium.org" + + eid, err := NewEntityIDFromString(str) + if err != nil { + t.Fatal("error must be null") + } + if !eid.OnlyServerPart && eid.ServerPart != "cadmium.org" { + t.Fatal(eid.String()) + } +} diff --git a/core/models/protocol_error.go b/core/models/protocol_error.go index 7872daa..ae29bfd 100644 --- a/core/models/protocol_error.go +++ b/core/models/protocol_error.go @@ -2,6 +2,6 @@ package models type ProtocolError struct { ErrCode string `structs:"code"` - ErrText string `structs:"text"` + ErrText string `structs:"text,omitempty"` ErrPayload map[string]interface{} `structs:"payload,omitempty"` } diff --git a/core/router.go b/core/router.go index fc34638..cb1845b 100644 --- a/core/router.go +++ b/core/router.go @@ -34,7 +34,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) { if origin.Claims == nil { logger.Warningf("Connection %s isn't authorized", origin.connID) - msg := utils.PrepareMessageUnauthorized(message, r.appContext.cfg.ServerDomains[0]) // fixme: domain + msg := utils.PrepareMessageUnauthorized(message, message.To) // fixme: domain _ = origin.Send(msg) } } @@ -46,7 +46,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) { ErrText: "Server doesn't implement message type " + message.MessageType, ErrPayload: make(map[string]interface{}), } - errMsg := models.NewBaseMessage(message.ID, message.MessageType, r.appContext.cfg.ServerID, []string{message.From}, false, structs.Map(protocolError)) + errMsg := models.NewBaseMessage(message.ID, message.MessageType, message.To, message.From, false, structs.Map(protocolError)) logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType) _ = origin.Send(errMsg) } diff --git a/core/session.go b/core/session.go index 0063189..156a454 100644 --- a/core/session.go +++ b/core/session.go @@ -15,6 +15,12 @@ func (s *Session) Send(message models.BaseMessage) error { return s.wsConn.WriteJSON(message) } +func (s *Session) Receive() (models.BaseMessage, error) { + var msg models.BaseMessage + err := s.wsConn.ReadJSON(&msg) + return msg, err +} + func (s *Session) Close() error { return s.wsConn.Close() } diff --git a/core/session_manager.go b/core/session_manager.go index 15a109f..b9ac8cd 100644 --- a/core/session_manager.go +++ b/core/session_manager.go @@ -2,42 +2,91 @@ package core import ( "github.com/cadmium-im/zirconium-go/core/models" + "github.com/cadmium-im/zirconium-go/core/utils" "github.com/google/logger" "github.com/google/uuid" "github.com/gorilla/websocket" ) type SessionManager struct { + domains []string router *Router connections map[string]*Session } -func NewSessionManager(r *Router) *SessionManager { +func NewSessionManager(r *Router, domains []string) *SessionManager { return &SessionManager{ + domains: domains, router: r, connections: make(map[string]*Session), } } -func (ch *SessionManager) HandleNewConnection(wsocket *websocket.Conn) { +func (sm *SessionManager) HandleNewConnection(wsocket *websocket.Conn) { randomUUID := uuid.New().String() - o := &Session{ + s := &Session{ wsConn: wsocket, connID: randomUUID, } - ch.connections[o.connID] = o + + msg, err := s.Receive() + if err != nil { + logger.Infof("Error occurred when tried to receive first message! Closing connection... (%s)", err.Error()) + _ = s.wsConn.Close() + return + } + if msg.MessageType != "urn:cadmium:connection:open" { + msg.MessageType = "urn:cadmium:connection" + emsg := utils.PrepareErrorMessage(msg, "invalid-conn-negotiation", "", "") + s.Send(emsg) + s.wsConn.Close() + return + } + + eid, err := models.NewEntityIDFromString(msg.To) + if err != nil { + emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "") + s.Send(emsg) + s.wsConn.Close() + return + } + if eid.OnlyServerPart && !utils.InStringArray(eid.ServerPart, sm.domains) { + emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "") + s.Send(emsg) + s.wsConn.Close() + return + } + + replyMessage := models.NewEmptyBaseMessage() + replyMessage.ID = msg.ID + replyMessage.From = msg.To + replyMessage.MessageType = "urn:cadmium:connection:open" + replyMessage.Payload["id"] = randomUUID + s.Send(replyMessage) + + sm.connections[s.connID] = s go func() { for { - var msg models.BaseMessage - err := o.wsConn.ReadJSON(&msg) + msg, err := s.Receive() if err != nil { - delete(ch.connections, o.connID) - _ = o.wsConn.Close() - logger.Infof("Connection %s was closed. (Reason: %s)", o.connID, err.Error()) + delete(sm.connections, s.connID) + _ = s.wsConn.Close() + logger.Infof("Connection %s was closed. (Reason: %s)", s.connID, err.Error()) break } - ch.router.RouteMessage(o, msg) + eid, err := models.NewEntityIDFromString(msg.To) + if err != nil { + emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "") + s.Send(emsg) + continue + } + if eid.LocalPart == "" && utils.InStringArray(eid.ServerPart, sm.domains) { + emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "") + s.Send(emsg) + continue + } + sm.router.RouteMessage(s, msg) } }() - logger.Infof("Connection %s was created", o.connID) + logger.Infof("Connection %s was created", s.connID) } diff --git a/core/utils/utils.go b/core/utils/utils.go index 4f894f8..c0cf185 100644 --- a/core/utils/utils.go +++ b/core/utils/utils.go @@ -28,7 +28,7 @@ func PrepareMessageUnauthorized(msg models.BaseMessage, serverDomain string) mod ErrText: "Unauthorized access", ErrPayload: make(map[string]interface{}), } - errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, []string{msg.From}, false, structs.Map(protocolError)) + errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, msg.From, false, structs.Map(protocolError)) return errMsg } @@ -38,11 +38,7 @@ func PrepareMessageInternalServerError(msg models.BaseMessage, err error, server ErrText: err.Error(), ErrPayload: nil, } - var to []string - if msg.From != "" { - to = append(to, msg.From) - } - errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError)) + errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError)) return errMsg } @@ -52,11 +48,7 @@ func PrepareErrorMessage(msg models.BaseMessage, errorType string, errorText str ErrText: errorText, ErrPayload: nil, } - var to []string - if msg.From != "" { - to = append(to, msg.From) - } - errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError)) + errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError)) return errMsg } @@ -89,3 +81,12 @@ func IsCollectionExists(ctx context.Context, db *mongo.Database, collectionName } return isExists, nil } + +func InStringArray(val string, array []string) (ok bool) { + for i := range array { + if ok = array[i] == val; ok { + return + } + } + return +}