mirror of
https://github.com/cadmium-im/zirconium-go.git
synced 2024-11-23 10:52:24 +00:00
Implement connection negotiation and parsing serverpart for EntityID
This commit is contained in:
parent
38d83b44ff
commit
39e08edbc9
@ -44,7 +44,7 @@ func NewAppContext(cfg *Config) *AppContext {
|
|||||||
}
|
}
|
||||||
appContext.authManager = authManager
|
appContext.authManager = authManager
|
||||||
|
|
||||||
sessionManager := NewSessionManager(router)
|
sessionManager := NewSessionManager(router, cfg.ServerDomains)
|
||||||
appContext.sessionManager = sessionManager
|
appContext.sessionManager = sessionManager
|
||||||
|
|
||||||
wss := NewWebsocketServer(cfg, sessionManager)
|
wss := NewWebsocketServer(cfg, sessionManager)
|
||||||
|
@ -36,7 +36,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
|
|||||||
token, claims, err := ah.authManager.HandleSimpleAuth(authRequest.Fields["username"].(string), authRequest.Fields["password"].(string))
|
token, claims, err := ah.authManager.HandleSimpleAuth(authRequest.Fields["username"].(string), authRequest.Fields["password"].(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if mongo.ErrNoDocuments == err {
|
if mongo.ErrNoDocuments == err {
|
||||||
msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", ah.serverID)
|
msg := utils.PrepareErrorMessage(message, "auth-failed", "invalid username", message.To)
|
||||||
_ = s.Send(msg)
|
_ = s.Send(msg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -50,7 +50,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
|
|||||||
DeviceID: claims.DeviceID,
|
DeviceID: claims.DeviceID,
|
||||||
}
|
}
|
||||||
payload := structs.Map(ar)
|
payload := structs.Map(ar)
|
||||||
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, payload)
|
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, payload)
|
||||||
_ = s.Send(msg)
|
_ = s.Send(msg)
|
||||||
s.Claims = claims
|
s.Claims = claims
|
||||||
return
|
return
|
||||||
@ -65,7 +65,7 @@ func (ah *AuthHandler) HandleMessage(s *Session, message models.BaseMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.Claims = claims
|
s.Claims = claims
|
||||||
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, nil, true, nil)
|
msg := models.NewBaseMessage(message.ID, message.MessageType, ah.serverID, "", true, nil)
|
||||||
_ = s.Send(msg)
|
_ = s.Send(msg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -4,13 +4,13 @@ package models
|
|||||||
type BaseMessage struct {
|
type BaseMessage struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
MessageType string `json:"type"`
|
MessageType string `json:"type"`
|
||||||
From string `json:"from"`
|
From string `json:"from,omitempty"`
|
||||||
To []string `json:"to,omitempty"`
|
To string `json:"to,omitempty"`
|
||||||
Ok bool `json:"ok"`
|
Ok bool `json:"ok"`
|
||||||
Payload map[string]interface{} `json:"payload"`
|
Payload map[string]interface{} `json:"payload"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload map[string]interface{}) BaseMessage {
|
func NewBaseMessage(id, messageType, from string, to string, ok bool, payload map[string]interface{}) BaseMessage {
|
||||||
return BaseMessage{
|
return BaseMessage{
|
||||||
ID: id,
|
ID: id,
|
||||||
MessageType: messageType,
|
MessageType: messageType,
|
||||||
@ -20,3 +20,14 @@ func NewBaseMessage(id, messageType, from string, to []string, ok bool, payload
|
|||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewEmptyBaseMessage() BaseMessage {
|
||||||
|
return BaseMessage{
|
||||||
|
ID: "",
|
||||||
|
MessageType: "",
|
||||||
|
From: "",
|
||||||
|
To: "",
|
||||||
|
Ok: true,
|
||||||
|
Payload: map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -16,16 +16,26 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type EntityID struct {
|
type EntityID struct {
|
||||||
Type EntityIDType
|
Type EntityIDType
|
||||||
LocalPart string
|
LocalPart string
|
||||||
ServerPart string
|
ServerPart string
|
||||||
Attr string
|
Attr string
|
||||||
|
OnlyServerPart bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEntityIDFromString(entityID string) (*EntityID, error) {
|
func NewEntityIDFromString(entityID string) (*EntityID, error) {
|
||||||
eid := &EntityID{}
|
eid := &EntityID{}
|
||||||
typ := string(entityID[0])
|
typ := string(entityID[0])
|
||||||
withAttr := false
|
withAttr := false
|
||||||
|
localAndServerPart := strings.Split(entityID, "@")
|
||||||
|
if localAndServerPart[0] == "" {
|
||||||
|
localAndServerPart = localAndServerPart[1:]
|
||||||
|
}
|
||||||
|
if len(localAndServerPart) == 1 {
|
||||||
|
eid.ServerPart = localAndServerPart[0]
|
||||||
|
eid.OnlyServerPart = true
|
||||||
|
return eid, nil
|
||||||
|
}
|
||||||
|
|
||||||
switch EntityIDType(typ) {
|
switch EntityIDType(typ) {
|
||||||
case UsernameType:
|
case UsernameType:
|
||||||
@ -47,10 +57,6 @@ func NewEntityIDFromString(entityID string) (*EntityID, error) {
|
|||||||
return nil, fmt.Errorf("invalid entity id type: %s", typ)
|
return nil, fmt.Errorf("invalid entity id type: %s", typ)
|
||||||
}
|
}
|
||||||
|
|
||||||
localAndServerPart := strings.Split(entityID, "@")
|
|
||||||
if len(localAndServerPart) == 3 && localAndServerPart[0] == "" {
|
|
||||||
localAndServerPart = localAndServerPart[0:]
|
|
||||||
}
|
|
||||||
if !withAttr {
|
if !withAttr {
|
||||||
eid.LocalPart = localAndServerPart[0]
|
eid.LocalPart = localAndServerPart[0]
|
||||||
eid.ServerPart = localAndServerPart[1]
|
eid.ServerPart = localAndServerPart[1]
|
||||||
|
@ -37,3 +37,15 @@ func TestNewEntityIDFromStringWithEmailAttr(t *testing.T) {
|
|||||||
t.Fatal(eid.String())
|
t.Fatal(eid.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewEntityIDFromStringWithOnlyServerPart(t *testing.T) {
|
||||||
|
str := "cadmium.org"
|
||||||
|
|
||||||
|
eid, err := NewEntityIDFromString(str)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error must be null")
|
||||||
|
}
|
||||||
|
if !eid.OnlyServerPart && eid.ServerPart != "cadmium.org" {
|
||||||
|
t.Fatal(eid.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,6 @@ package models
|
|||||||
|
|
||||||
type ProtocolError struct {
|
type ProtocolError struct {
|
||||||
ErrCode string `structs:"code"`
|
ErrCode string `structs:"code"`
|
||||||
ErrText string `structs:"text"`
|
ErrText string `structs:"text,omitempty"`
|
||||||
ErrPayload map[string]interface{} `structs:"payload,omitempty"`
|
ErrPayload map[string]interface{} `structs:"payload,omitempty"`
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) {
|
|||||||
if origin.Claims == nil {
|
if origin.Claims == nil {
|
||||||
logger.Warningf("Connection %s isn't authorized", origin.connID)
|
logger.Warningf("Connection %s isn't authorized", origin.connID)
|
||||||
|
|
||||||
msg := utils.PrepareMessageUnauthorized(message, r.appContext.cfg.ServerDomains[0]) // fixme: domain
|
msg := utils.PrepareMessageUnauthorized(message, message.To) // fixme: domain
|
||||||
_ = origin.Send(msg)
|
_ = origin.Send(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -46,7 +46,7 @@ func (r *Router) RouteMessage(origin *Session, message models.BaseMessage) {
|
|||||||
ErrText: "Server doesn't implement message type " + message.MessageType,
|
ErrText: "Server doesn't implement message type " + message.MessageType,
|
||||||
ErrPayload: make(map[string]interface{}),
|
ErrPayload: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
errMsg := models.NewBaseMessage(message.ID, message.MessageType, r.appContext.cfg.ServerID, []string{message.From}, false, structs.Map(protocolError))
|
errMsg := models.NewBaseMessage(message.ID, message.MessageType, message.To, message.From, false, structs.Map(protocolError))
|
||||||
logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType)
|
logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType)
|
||||||
_ = origin.Send(errMsg)
|
_ = origin.Send(errMsg)
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,12 @@ func (s *Session) Send(message models.BaseMessage) error {
|
|||||||
return s.wsConn.WriteJSON(message)
|
return s.wsConn.WriteJSON(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Session) Receive() (models.BaseMessage, error) {
|
||||||
|
var msg models.BaseMessage
|
||||||
|
err := s.wsConn.ReadJSON(&msg)
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Session) Close() error {
|
func (s *Session) Close() error {
|
||||||
return s.wsConn.Close()
|
return s.wsConn.Close()
|
||||||
}
|
}
|
||||||
|
@ -2,42 +2,91 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/cadmium-im/zirconium-go/core/models"
|
"github.com/cadmium-im/zirconium-go/core/models"
|
||||||
|
"github.com/cadmium-im/zirconium-go/core/utils"
|
||||||
"github.com/google/logger"
|
"github.com/google/logger"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SessionManager struct {
|
type SessionManager struct {
|
||||||
|
domains []string
|
||||||
router *Router
|
router *Router
|
||||||
connections map[string]*Session
|
connections map[string]*Session
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionManager(r *Router) *SessionManager {
|
func NewSessionManager(r *Router, domains []string) *SessionManager {
|
||||||
return &SessionManager{
|
return &SessionManager{
|
||||||
|
domains: domains,
|
||||||
router: r,
|
router: r,
|
||||||
connections: make(map[string]*Session),
|
connections: make(map[string]*Session),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *SessionManager) HandleNewConnection(wsocket *websocket.Conn) {
|
func (sm *SessionManager) HandleNewConnection(wsocket *websocket.Conn) {
|
||||||
randomUUID := uuid.New().String()
|
randomUUID := uuid.New().String()
|
||||||
o := &Session{
|
s := &Session{
|
||||||
wsConn: wsocket,
|
wsConn: wsocket,
|
||||||
connID: randomUUID,
|
connID: randomUUID,
|
||||||
}
|
}
|
||||||
ch.connections[o.connID] = o
|
|
||||||
|
msg, err := s.Receive()
|
||||||
|
if err != nil {
|
||||||
|
logger.Infof("Error occurred when tried to receive first message! Closing connection... (%s)", err.Error())
|
||||||
|
_ = s.wsConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.MessageType != "urn:cadmium:connection:open" {
|
||||||
|
msg.MessageType = "urn:cadmium:connection"
|
||||||
|
emsg := utils.PrepareErrorMessage(msg, "invalid-conn-negotiation", "", "")
|
||||||
|
s.Send(emsg)
|
||||||
|
s.wsConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eid, err := models.NewEntityIDFromString(msg.To)
|
||||||
|
if err != nil {
|
||||||
|
emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "")
|
||||||
|
s.Send(emsg)
|
||||||
|
s.wsConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if eid.OnlyServerPart && !utils.InStringArray(eid.ServerPart, sm.domains) {
|
||||||
|
emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "")
|
||||||
|
s.Send(emsg)
|
||||||
|
s.wsConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
replyMessage := models.NewEmptyBaseMessage()
|
||||||
|
replyMessage.ID = msg.ID
|
||||||
|
replyMessage.From = msg.To
|
||||||
|
replyMessage.MessageType = "urn:cadmium:connection:open"
|
||||||
|
replyMessage.Payload["id"] = randomUUID
|
||||||
|
s.Send(replyMessage)
|
||||||
|
|
||||||
|
sm.connections[s.connID] = s
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
var msg models.BaseMessage
|
msg, err := s.Receive()
|
||||||
err := o.wsConn.ReadJSON(&msg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
delete(ch.connections, o.connID)
|
delete(sm.connections, s.connID)
|
||||||
_ = o.wsConn.Close()
|
_ = s.wsConn.Close()
|
||||||
logger.Infof("Connection %s was closed. (Reason: %s)", o.connID, err.Error())
|
logger.Infof("Connection %s was closed. (Reason: %s)", s.connID, err.Error())
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
ch.router.RouteMessage(o, msg)
|
eid, err := models.NewEntityIDFromString(msg.To)
|
||||||
|
if err != nil {
|
||||||
|
emsg := utils.PrepareErrorMessage(msg, "invalid-eid", "", "")
|
||||||
|
s.Send(emsg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if eid.LocalPart == "" && utils.InStringArray(eid.ServerPart, sm.domains) {
|
||||||
|
emsg := utils.PrepareErrorMessage(msg, "unknown-host", "", "")
|
||||||
|
s.Send(emsg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sm.router.RouteMessage(s, msg)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
logger.Infof("Connection %s was created", o.connID)
|
logger.Infof("Connection %s was created", s.connID)
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ func PrepareMessageUnauthorized(msg models.BaseMessage, serverDomain string) mod
|
|||||||
ErrText: "Unauthorized access",
|
ErrText: "Unauthorized access",
|
||||||
ErrPayload: make(map[string]interface{}),
|
ErrPayload: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, []string{msg.From}, false, structs.Map(protocolError))
|
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverDomain, msg.From, false, structs.Map(protocolError))
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,11 +38,7 @@ func PrepareMessageInternalServerError(msg models.BaseMessage, err error, server
|
|||||||
ErrText: err.Error(),
|
ErrText: err.Error(),
|
||||||
ErrPayload: nil,
|
ErrPayload: nil,
|
||||||
}
|
}
|
||||||
var to []string
|
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError))
|
||||||
if msg.From != "" {
|
|
||||||
to = append(to, msg.From)
|
|
||||||
}
|
|
||||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError))
|
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,11 +48,7 @@ func PrepareErrorMessage(msg models.BaseMessage, errorType string, errorText str
|
|||||||
ErrText: errorText,
|
ErrText: errorText,
|
||||||
ErrPayload: nil,
|
ErrPayload: nil,
|
||||||
}
|
}
|
||||||
var to []string
|
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, msg.From, false, structs.Map(protocolError))
|
||||||
if msg.From != "" {
|
|
||||||
to = append(to, msg.From)
|
|
||||||
}
|
|
||||||
errMsg := models.NewBaseMessage(msg.ID, msg.MessageType, serverID, to, false, structs.Map(protocolError))
|
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,3 +81,12 @@ func IsCollectionExists(ctx context.Context, db *mongo.Database, collectionName
|
|||||||
}
|
}
|
||||||
return isExists, nil
|
return isExists, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InStringArray(val string, array []string) (ok bool) {
|
||||||
|
for i := range array {
|
||||||
|
if ok = array[i] == val; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user