Implement connection negotiation and parsing serverpart for EntityID

This commit is contained in:
ChronosX88 2021-04-25 00:56:20 +03:00
parent 38d83b44ff
commit 39e08edbc9
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A
10 changed files with 125 additions and 40 deletions

View File

@ -44,7 +44,7 @@ func NewAppContext(cfg *Config) *AppContext {
}
appContext.authManager = authManager
sessionManager := NewSessionManager(router)
sessionManager := NewSessionManager(router, cfg.ServerDomains)
appContext.sessionManager = sessionManager
wss := NewWebsocketServer(cfg, sessionManager)

View File

@ -36,7 +36,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
token, claims, err := ah.authManager.HandleSimpleAuth(authRequest.Fields["username"].(string), authRequest.Fields["password"].(string))
if err != nil {
if mongo.ErrNoDocuments == err {
msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", ah.serverID)
msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", message.To)
_ = s.Send(msg)
return
}
@ -50,7 +50,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
DeviceID: claims.DeviceID,
}
payload := structs.Map(ar)
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, payload)
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, payload)
_ = s.Send(msg)
s.Claims = claims
return
@ -65,7 +65,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
return
}
s.Claims = claims
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, nil)
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, nil)
_ = s.Send(msg)
return
}

View File

@ -4,13 +4,13 @@ package models
type BaseMessage struct {
ID string `json:"id"`
MessageType string `json:"type"`
From string `json:"from"`
To []string `json:"to,omitempty"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Ok bool `json:"ok"`
Payload map[string]interface{} `json:"payload"`
}
func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload map[string]interface{}) BaseMessage {
func NewBaseMessage(id, messageType, from string, to string, ok bool, payload map[string]interface{}) BaseMessage {
return BaseMessage{
ID: id,
MessageType: messageType,
@ -20,3 +20,14 @@ func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload
Payload: payload,
}
}
func NewEmptyBaseMessage() BaseMessage {
return BaseMessage{
ID: "",
MessageType: "",
From: "",
To: "",
Ok: true,
Payload: map[string]interface{}{},
}
}

View File

@ -16,16 +16,26 @@ const (
)
type EntityID struct {
Type EntityIDType
LocalPart string
ServerPart string
Attr string
Type EntityIDType
LocalPart string
ServerPart string
Attr string
OnlyServerPart bool
}
func NewEntityIDFromString(entityID string) (*EntityID, error) {
eid := &EntityID{}
typ := string(entityID[0])
withAttr := false
localAndServerPart := strings.Split(entityID, "@")
if localAndServerPart[0] == "" {
localAndServerPart = localAndServerPart[1:]
}
if len(localAndServerPart) == 1 {
eid.ServerPart = localAndServerPart[0]
eid.OnlyServerPart = true
return eid, nil
}
switch EntityIDType(typ) {
case UsernameType:
@ -47,10 +57,6 @@ func NewEntityIDFromString(entityID string) (*EntityID, error) {
return nil, fmt.Errorf("invalid entity id type: %s", typ)
}
localAndServerPart := strings.Split(entityID, "@")
if len(localAndServerPart) == 3 && localAndServerPart[0] == "" {
localAndServerPart = localAndServerPart[0:]
}
if !withAttr {
eid.LocalPart = localAndServerPart[0]
eid.ServerPart = localAndServerPart[1]

View File

@ -37,3 +37,15 @@ func TestNewEntityIDFromStringWithEmailAttr(t *testing.T) {
t.Fatal(eid.String())
}
}
func TestNewEntityIDFromStringWithOnlyServerPart(t *testing.T) {
str := "cadmium.org"
eid, err := NewEntityIDFromString(str)
if err != nil {
t.Fatal("error must be null")
}
if !eid.OnlyServerPart && eid.ServerPart != "cadmium.org" {
t.Fatal(eid.String())
}
}

View File

@ -2,6 +2,6 @@ package models
type ProtocolError struct {
ErrCode string `structs:"code"`
ErrText string `structs:"text"`
ErrText string `structs:"text,omitempty"`
ErrPayload map[string]interface{} `structs:"payload,omitempty"`
}

View File

@ -34,7 +34,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) {
if origin.Claims == nil {
logger.Warningf("Connection %s isn't authorized", origin.connID)
msg := utils.PrepareMessageUnauthorized(message, r.appContext.cfg.ServerDomains[0]) // fixme: domain
msg := utils.PrepareMessageUnauthorized(message, message.To) // fixme: domain
_ = origin.Send(msg)
}
}
@ -46,7 +46,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) {
ErrText: "Server doesn't implement message type " + message.MessageType,
ErrPayload: make(map[string]interface{}),
}
errMsg := models.NewBaseMessage(message.ID, message.MessageType, r.appContext.cfg.ServerID, []string{message.From}, false, structs.Map(protocolError))
errMsg := models.NewBaseMessage(message.ID, message.MessageType, message.To, message.From, false, structs.Map(protocolError))
logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType)
_ = origin.Send(errMsg)
}

View File

@ -15,6 +15,12 @@ func (s *Session) Send(message models.BaseMessage) error {
return s.wsConn.WriteJSON(message)
}
func (s *Session) Receive() (models.BaseMessage, error) {
var msg models.BaseMessage
err := s.wsConn.ReadJSON(&msg)
return msg, err
}
func (s *Session) Close() error {
return s.wsConn.Close()
}

View File

@ -2,42 +2,91 @@ package core
import (
"github.com/cadmium-im/zirconium-go/core/models"
"github.com/cadmium-im/zirconium-go/core/utils"
"github.com/google/logger"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
type SessionManager struct {
domains []string
router *Router
connections map[string]*Session
}
func NewSessionManager(r *Router) *SessionManager {
func NewSessionManager(r *Router, domains []string) *SessionManager {
return &SessionManager{
domains: domains,
router: r,
connections: make(map[string]*Session),
}
}
func (ch *SessionManager) HandleNewConnection(wsocket *websocket.Conn) {
func (sm *SessionManager) HandleNewConnection(wsocket *websocket.Conn) {
randomUUID := uuid.New().String()
o := &Session{
s := &Session{
wsConn: wsocket,
connID: randomUUID,
}
ch.connections[o.connID] = o
msg, err := s.Receive()
if err != nil {
logger.Infof("Error occurred when tried to receive first message! Closing connection... (%s)", err.Error())
_ = s.wsConn.Close()
return
}
if msg.MessageType != "urn:cadmium:connection:open" {
msg.MessageType = "urn:cadmium:connection"
emsg := utils.PrepareErrorMessage(msg, "invalid-conn-negotiation", "", "")
s.Send(emsg)
s.wsConn.Close()
return
}
eid, err := models.NewEntityIDFromString(msg.To)
if err != nil {
emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "")
s.Send(emsg)
s.wsConn.Close()
return
}
if eid.OnlyServerPart && !utils.InStringArray(eid.ServerPart, sm.domains) {
emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "")
s.Send(emsg)
s.wsConn.Close()
return
}
replyMessage := models.NewEmptyBaseMessage()
replyMessage.ID = msg.ID
replyMessage.From = msg.To
replyMessage.MessageType = "urn:cadmium:connection:open"
replyMessage.Payload["id"] = randomUUID
s.Send(replyMessage)
sm.connections[s.connID] = s
go func() {
for {
var msg models.BaseMessage
err := o.wsConn.ReadJSON(&msg)
msg, err := s.Receive()
if err != nil {
delete(ch.connections, o.connID)
_ = o.wsConn.Close()
logger.Infof("Connection %s was closed. (Reason: %s)", o.connID, err.Error())
delete(sm.connections, s.connID)
_ = s.wsConn.Close()
logger.Infof("Connection %s was closed. (Reason: %s)", s.connID, err.Error())
break
}
ch.router.RouteMessage(o, msg)
eid, err := models.NewEntityIDFromString(msg.To)
if err != nil {
emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "")
s.Send(emsg)
continue
}
if eid.LocalPart == "" && utils.InStringArray(eid.ServerPart, sm.domains) {
emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "")
s.Send(emsg)
continue
}
sm.router.RouteMessage(s, msg)
}
}()
logger.Infof("Connection %s was created", o.connID)
logger.Infof("Connection %s was created", s.connID)
}

View File

@ -28,7 +28,7 @@ func PrepareMessageUnauthorized(msg models.BaseMessage, serverDomain string) mod
ErrText: "Unauthorized access",
ErrPayload: make(map[string]interface{}),
}
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, []string{msg.From}, false, structs.Map(protocolError))
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, msg.From, false, structs.Map(protocolError))
return errMsg
}
@ -38,11 +38,7 @@ func PrepareMessageInternalServerError(msg models.BaseMessage, err error, server
ErrText: err.Error(),
ErrPayload: nil,
}
var to []string
if msg.From != "" {
to = append(to, msg.From)
}
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError))
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError))
return errMsg
}
@ -52,11 +48,7 @@ func PrepareErrorMessage(msg models.BaseMessage, errorType string, errorText str
ErrText: errorText,
ErrPayload: nil,
}
var to []string
if msg.From != "" {
to = append(to, msg.From)
}
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError))
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError))
return errMsg
}
@ -89,3 +81,12 @@ func IsCollectionExists(ctx context.Context, db *mongo.Database, collectionName
}
return isExists, nil
}
func InStringArray(val string, array []string) (ok bool) {
for i := range array {
if ok = array[i] == val; ok {
return
}
}
return
}