mirror of
https://github.com/cadmium-im/zirconium-go.git
synced 2024-11-23 10:52:24 +00:00
Implement connection negotiation and parsing serverpart for EntityID
This commit is contained in:
parent
38d83b44ff
commit
39e08edbc9
@ -44,7 +44,7 @@ func NewAppContext(cfg *Config) *AppContext {
|
||||
}
|
||||
appContext.authManager = authManager
|
||||
|
||||
sessionManager := NewSessionManager(router)
|
||||
sessionManager := NewSessionManager(router, cfg.ServerDomains)
|
||||
appContext.sessionManager = sessionManager
|
||||
|
||||
wss := NewWebsocketServer(cfg, sessionManager)
|
||||
|
@ -36,7 +36,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
|
||||
token, claims, err := ah.authManager.HandleSimpleAuth(authRequest.Fields["username"].(string), authRequest.Fields["password"].(string))
|
||||
if err != nil {
|
||||
if mongo.ErrNoDocuments == err {
|
||||
msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", ah.serverID)
|
||||
msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", message.To)
|
||||
_ = s.Send(msg)
|
||||
return
|
||||
}
|
||||
@ -50,7 +50,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
|
||||
DeviceID: claims.DeviceID,
|
||||
}
|
||||
payload := structs.Map(ar)
|
||||
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, payload)
|
||||
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, payload)
|
||||
_ = s.Send(msg)
|
||||
s.Claims = claims
|
||||
return
|
||||
@ -65,7 +65,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
|
||||
return
|
||||
}
|
||||
s.Claims = claims
|
||||
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, nil)
|
||||
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, nil)
|
||||
_ = s.Send(msg)
|
||||
return
|
||||
}
|
||||
|
@ -4,13 +4,13 @@ package models
|
||||
type BaseMessage struct {
|
||||
ID string `json:"id"`
|
||||
MessageType string `json:"type"`
|
||||
From string `json:"from"`
|
||||
To []string `json:"to,omitempty"`
|
||||
From string `json:"from,omitempty"`
|
||||
To string `json:"to,omitempty"`
|
||||
Ok bool `json:"ok"`
|
||||
Payload map[string]interface{} `json:"payload"`
|
||||
}
|
||||
|
||||
func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload map[string]interface{}) BaseMessage {
|
||||
func NewBaseMessage(id, messageType, from string, to string, ok bool, payload map[string]interface{}) BaseMessage {
|
||||
return BaseMessage{
|
||||
ID: id,
|
||||
MessageType: messageType,
|
||||
@ -20,3 +20,14 @@ func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload
|
||||
Payload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
func NewEmptyBaseMessage() BaseMessage {
|
||||
return BaseMessage{
|
||||
ID: "",
|
||||
MessageType: "",
|
||||
From: "",
|
||||
To: "",
|
||||
Ok: true,
|
||||
Payload: map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
@ -20,12 +20,22 @@ type EntityID struct {
|
||||
LocalPart string
|
||||
ServerPart string
|
||||
Attr string
|
||||
OnlyServerPart bool
|
||||
}
|
||||
|
||||
func NewEntityIDFromString(entityID string) (*EntityID, error) {
|
||||
eid := &EntityID{}
|
||||
typ := string(entityID[0])
|
||||
withAttr := false
|
||||
localAndServerPart := strings.Split(entityID, "@")
|
||||
if localAndServerPart[0] == "" {
|
||||
localAndServerPart = localAndServerPart[1:]
|
||||
}
|
||||
if len(localAndServerPart) == 1 {
|
||||
eid.ServerPart = localAndServerPart[0]
|
||||
eid.OnlyServerPart = true
|
||||
return eid, nil
|
||||
}
|
||||
|
||||
switch EntityIDType(typ) {
|
||||
case UsernameType:
|
||||
@ -47,10 +57,6 @@ func NewEntityIDFromString(entityID string) (*EntityID, error) {
|
||||
return nil, fmt.Errorf("invalid entity id type: %s", typ)
|
||||
}
|
||||
|
||||
localAndServerPart := strings.Split(entityID, "@")
|
||||
if len(localAndServerPart) == 3 && localAndServerPart[0] == "" {
|
||||
localAndServerPart = localAndServerPart[0:]
|
||||
}
|
||||
if !withAttr {
|
||||
eid.LocalPart = localAndServerPart[0]
|
||||
eid.ServerPart = localAndServerPart[1]
|
||||
|
@ -37,3 +37,15 @@ func TestNewEntityIDFromStringWithEmailAttr(t *testing.T) {
|
||||
t.Fatal(eid.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEntityIDFromStringWithOnlyServerPart(t *testing.T) {
|
||||
str := "cadmium.org"
|
||||
|
||||
eid, err := NewEntityIDFromString(str)
|
||||
if err != nil {
|
||||
t.Fatal("error must be null")
|
||||
}
|
||||
if !eid.OnlyServerPart && eid.ServerPart != "cadmium.org" {
|
||||
t.Fatal(eid.String())
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,6 @@ package models
|
||||
|
||||
type ProtocolError struct {
|
||||
ErrCode string `structs:"code"`
|
||||
ErrText string `structs:"text"`
|
||||
ErrText string `structs:"text,omitempty"`
|
||||
ErrPayload map[string]interface{} `structs:"payload,omitempty"`
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) {
|
||||
if origin.Claims == nil {
|
||||
logger.Warningf("Connection %s isn't authorized", origin.connID)
|
||||
|
||||
msg := utils.PrepareMessageUnauthorized(message, r.appContext.cfg.ServerDomains[0]) // fixme: domain
|
||||
msg := utils.PrepareMessageUnauthorized(message, message.To) // fixme: domain
|
||||
_ = origin.Send(msg)
|
||||
}
|
||||
}
|
||||
@ -46,7 +46,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) {
|
||||
ErrText: "Server doesn't implement message type " + message.MessageType,
|
||||
ErrPayload: make(map[string]interface{}),
|
||||
}
|
||||
errMsg := models.NewBaseMessage(message.ID, message.MessageType, r.appContext.cfg.ServerID, []string{message.From}, false, structs.Map(protocolError))
|
||||
errMsg := models.NewBaseMessage(message.ID, message.MessageType, message.To, message.From, false, structs.Map(protocolError))
|
||||
logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType)
|
||||
_ = origin.Send(errMsg)
|
||||
}
|
||||
|
@ -15,6 +15,12 @@ func (s *Session) Send(message models.BaseMessage) error {
|
||||
return s.wsConn.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (s *Session) Receive() (models.BaseMessage, error) {
|
||||
var msg models.BaseMessage
|
||||
err := s.wsConn.ReadJSON(&msg)
|
||||
return msg, err
|
||||
}
|
||||
|
||||
func (s *Session) Close() error {
|
||||
return s.wsConn.Close()
|
||||
}
|
||||
|
@ -2,42 +2,91 @@ package core
|
||||
|
||||
import (
|
||||
"github.com/cadmium-im/zirconium-go/core/models"
|
||||
"github.com/cadmium-im/zirconium-go/core/utils"
|
||||
"github.com/google/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type SessionManager struct {
|
||||
domains []string
|
||||
router *Router
|
||||
connections map[string]*Session
|
||||
}
|
||||
|
||||
func NewSessionManager(r *Router) *SessionManager {
|
||||
func NewSessionManager(r *Router, domains []string) *SessionManager {
|
||||
return &SessionManager{
|
||||
domains: domains,
|
||||
router: r,
|
||||
connections: make(map[string]*Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *SessionManager) HandleNewConnection(wsocket *websocket.Conn) {
|
||||
func (sm *SessionManager) HandleNewConnection(wsocket *websocket.Conn) {
|
||||
randomUUID := uuid.New().String()
|
||||
o := &Session{
|
||||
s := &Session{
|
||||
wsConn: wsocket,
|
||||
connID: randomUUID,
|
||||
}
|
||||
ch.connections[o.connID] = o
|
||||
|
||||
msg, err := s.Receive()
|
||||
if err != nil {
|
||||
logger.Infof("Error occurred when tried to receive first message! Closing connection... (%s)", err.Error())
|
||||
_ = s.wsConn.Close()
|
||||
return
|
||||
}
|
||||
if msg.MessageType != "urn:cadmium:connection:open" {
|
||||
msg.MessageType = "urn:cadmium:connection"
|
||||
emsg := utils.PrepareErrorMessage(msg, "invalid-conn-negotiation", "", "")
|
||||
s.Send(emsg)
|
||||
s.wsConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
eid, err := models.NewEntityIDFromString(msg.To)
|
||||
if err != nil {
|
||||
emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "")
|
||||
s.Send(emsg)
|
||||
s.wsConn.Close()
|
||||
return
|
||||
}
|
||||
if eid.OnlyServerPart && !utils.InStringArray(eid.ServerPart, sm.domains) {
|
||||
emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "")
|
||||
s.Send(emsg)
|
||||
s.wsConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
replyMessage := models.NewEmptyBaseMessage()
|
||||
replyMessage.ID = msg.ID
|
||||
replyMessage.From = msg.To
|
||||
replyMessage.MessageType = "urn:cadmium:connection:open"
|
||||
replyMessage.Payload["id"] = randomUUID
|
||||
s.Send(replyMessage)
|
||||
|
||||
sm.connections[s.connID] = s
|
||||
go func() {
|
||||
for {
|
||||
var msg models.BaseMessage
|
||||
err := o.wsConn.ReadJSON(&msg)
|
||||
msg, err := s.Receive()
|
||||
if err != nil {
|
||||
delete(ch.connections, o.connID)
|
||||
_ = o.wsConn.Close()
|
||||
logger.Infof("Connection %s was closed. (Reason: %s)", o.connID, err.Error())
|
||||
delete(sm.connections, s.connID)
|
||||
_ = s.wsConn.Close()
|
||||
logger.Infof("Connection %s was closed. (Reason: %s)", s.connID, err.Error())
|
||||
break
|
||||
}
|
||||
ch.router.RouteMessage(o, msg)
|
||||
eid, err := models.NewEntityIDFromString(msg.To)
|
||||
if err != nil {
|
||||
emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "")
|
||||
s.Send(emsg)
|
||||
continue
|
||||
}
|
||||
if eid.LocalPart == "" && utils.InStringArray(eid.ServerPart, sm.domains) {
|
||||
emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "")
|
||||
s.Send(emsg)
|
||||
continue
|
||||
}
|
||||
sm.router.RouteMessage(s, msg)
|
||||
}
|
||||
}()
|
||||
logger.Infof("Connection %s was created", o.connID)
|
||||
logger.Infof("Connection %s was created", s.connID)
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func PrepareMessageUnauthorized(msg models.BaseMessage, serverDomain string) mod
|
||||
ErrText: "Unauthorized access",
|
||||
ErrPayload: make(map[string]interface{}),
|
||||
}
|
||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, []string{msg.From}, false, structs.Map(protocolError))
|
||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, msg.From, false, structs.Map(protocolError))
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@ -38,11 +38,7 @@ func PrepareMessageInternalServerError(msg models.BaseMessage, err error, server
|
||||
ErrText: err.Error(),
|
||||
ErrPayload: nil,
|
||||
}
|
||||
var to []string
|
||||
if msg.From != "" {
|
||||
to = append(to, msg.From)
|
||||
}
|
||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError))
|
||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError))
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@ -52,11 +48,7 @@ func PrepareErrorMessage(msg models.BaseMessage, errorType string, errorText str
|
||||
ErrText: errorText,
|
||||
ErrPayload: nil,
|
||||
}
|
||||
var to []string
|
||||
if msg.From != "" {
|
||||
to = append(to, msg.From)
|
||||
}
|
||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError))
|
||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError))
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@ -89,3 +81,12 @@ func IsCollectionExists(ctx context.Context, db *mongo.Database, collectionName
|
||||
}
|
||||
return isExists, nil
|
||||
}
|
||||
|
||||
func InStringArray(val string, array []string) (ok bool) {
|
||||
for i := range array {
|
||||
if ok = array[i] == val; ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user