mirror of
https://github.com/cadmium-im/zirconium-go.git
synced 2024-12-27 11:21:51 +00:00
Implement basic module system, auth manager and connection handler
This commit is contained in:
parent
67fbfee9e3
commit
a5e6fa72ff
@ -4,13 +4,12 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ChronosX88/zirconium/internal"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var clients = make(map[*websocket.Conn]string)
|
||||
var clientsReverse = make(map[string]*websocket.Conn)
|
||||
var connectionHandler = internal.NewConnectionHandler()
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
@ -33,27 +32,5 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// register client
|
||||
clients[ws] = uuid.New().String()
|
||||
clientsReverse[clients[ws]] = ws
|
||||
go readLoop(ws)
|
||||
log.Printf("Connection %s created!", clients[ws])
|
||||
}
|
||||
|
||||
func readLoop(c *websocket.Conn) {
|
||||
for {
|
||||
if _, _, err := c.NextReader(); err != nil {
|
||||
connectionID := clients[c]
|
||||
if connectionID != "" {
|
||||
delete(clients, c)
|
||||
delete(clientsReverse, connectionID)
|
||||
log.Printf("Connection %s closed!", connectionID)
|
||||
} else {
|
||||
log.Println("connection wasn't found")
|
||||
}
|
||||
c.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
connectionHandler.HandleNewConnection(ws)
|
||||
}
|
||||
|
2
go.mod
2
go.mod
@ -3,6 +3,8 @@ module github.com/ChronosX88/zirconium
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/google/logger v1.0.1
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/gorilla/mux v1.7.3
|
||||
github.com/gorilla/websocket v1.4.1
|
||||
|
6
go.sum
6
go.sum
@ -1,6 +1,12 @@
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/google/logger v1.0.1 h1:Jtq7/44yDwUXMaLTYgXFC31zpm6Oku7OI/k4//yVANQ=
|
||||
github.com/google/logger v1.0.1/go.mod h1:w7O8nrRr0xufejBlQMI83MXqRusvREoJdaAxV+CoAB4=
|
||||
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
|
||||
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
79
internal/auth_manager.go
Normal file
79
internal/auth_manager.go
Normal file
@ -0,0 +1,79 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
const (
|
||||
SigningKeyBytesAmount = 4096
|
||||
TokenExpireTimeDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type AuthManager struct {
|
||||
signingKey string // For now it is random bytes string represented in Base64
|
||||
}
|
||||
|
||||
type JWTCustomClaims struct {
|
||||
EntityID string `json:"entityID"`
|
||||
DeviceID string `json:"deviceID"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
func NewAuthManager() (*AuthManager, error) {
|
||||
am := &AuthManager{}
|
||||
bytes, err := GenRandomBytes(SigningKeyBytesAmount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.signingKey = base64.RawStdEncoding.EncodeToString(bytes)
|
||||
return am, nil
|
||||
}
|
||||
|
||||
func (am *AuthManager) CreateNewToken(entityID, deviceID string, tokenExpireTimeDuration time.Duration) (string, error) {
|
||||
timeNow := time.Now()
|
||||
expiringTime := timeNow.Add(tokenExpireTimeDuration)
|
||||
claims := JWTCustomClaims{
|
||||
entityID,
|
||||
deviceID,
|
||||
jwt.StandardClaims{
|
||||
ExpiresAt: time.Date(
|
||||
expiringTime.Year(),
|
||||
expiringTime.Month(),
|
||||
expiringTime.Day(),
|
||||
expiringTime.Hour(),
|
||||
expiringTime.Minute(),
|
||||
expiringTime.Second(),
|
||||
expiringTime.Nanosecond(),
|
||||
time.UTC,
|
||||
).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(am.signingKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
func (am *AuthManager) ValidateToken(tokenString string) (bool, string, string, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Don't forget to validate the alg is what you expect:
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
|
||||
return []byte(am.signingKey), nil
|
||||
})
|
||||
|
||||
if claims, ok := token.Claims.(*JWTCustomClaims); ok && token.Valid {
|
||||
return true, claims.EntityID, claims.DeviceID, nil
|
||||
}
|
||||
return false, "", "", err
|
||||
}
|
39
internal/connection_handler.go
Normal file
39
internal/connection_handler.go
Normal file
@ -0,0 +1,39 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/ChronosX88/zirconium/internal/models"
|
||||
"github.com/google/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type ConnectionHandler struct {
|
||||
connections map[string]*OriginC2S
|
||||
}
|
||||
|
||||
func NewConnectionHandler() *ConnectionHandler {
|
||||
return &ConnectionHandler{}
|
||||
}
|
||||
|
||||
func (ch *ConnectionHandler) HandleNewConnection(wsocket *websocket.Conn) {
|
||||
uuid, _ := uuid.NewRandom()
|
||||
uuidStr := uuid.String()
|
||||
o := &OriginC2S{
|
||||
wsConn: wsocket,
|
||||
connID: uuidStr,
|
||||
}
|
||||
ch.connections[o.connID] = o
|
||||
go func() {
|
||||
for {
|
||||
var msg models.BaseMessage
|
||||
err := o.wsConn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
delete(ch.connections, o.connID)
|
||||
o.wsConn.Close()
|
||||
logger.Infof("Connection %s was closed. (Reason: %s)", o.connID, err.Error())
|
||||
break
|
||||
}
|
||||
router.RouteMessage(o, msg)
|
||||
}
|
||||
}()
|
||||
}
|
36
internal/globals.go
Normal file
36
internal/globals.go
Normal file
@ -0,0 +1,36 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/ChronosX88/zirconium/shared"
|
||||
"github.com/google/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ModuleMgr *ModuleManager
|
||||
router *Router
|
||||
authManager *AuthManager
|
||||
serverDomain string
|
||||
)
|
||||
|
||||
func InitializeContext(sDomain string) {
|
||||
var err error
|
||||
ModuleMgr, err = NewModuleManager()
|
||||
if err != nil {
|
||||
logger.Fatalf("Unable to initialize module manager: %s", err.Error())
|
||||
}
|
||||
|
||||
router, err = NewRouter()
|
||||
if err != nil {
|
||||
logger.Fatalf("Unable to initialize router: %s", err.Error())
|
||||
}
|
||||
|
||||
authManager, err = NewAuthManager()
|
||||
if err != nil {
|
||||
logger.Fatalf("Unable to initialize authentication manager: %s", err.Error())
|
||||
}
|
||||
serverDomain = sDomain
|
||||
|
||||
for _, v := range shared.Plugins {
|
||||
go v.Initialize() // Initialize provided plugins
|
||||
}
|
||||
}
|
@ -1 +0,0 @@
|
||||
package internal
|
27
internal/models/base_message.go
Normal file
27
internal/models/base_message.go
Normal file
@ -0,0 +1,27 @@
|
||||
package models
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// BaseMessage is a basic message model, basis of the whole protocol. It is used for a very easy protocol extension process.
|
||||
type BaseMessage struct {
|
||||
ID string `json:"id"`
|
||||
MessageType string `json:"type"`
|
||||
From string `json:"from"`
|
||||
To string `json:"to"`
|
||||
Ok bool `json:"ok"`
|
||||
AuthToken string `json:"authToken"`
|
||||
Payload map[string]interface{} `json:"payload"`
|
||||
}
|
||||
|
||||
func NewBaseMessage(messageType string, from string, to string, ok bool, payload map[string]interface{}) BaseMessage {
|
||||
uuid, _ := uuid.NewRandom()
|
||||
uuidStr := uuid.String()
|
||||
return BaseMessage{
|
||||
ID: uuidStr,
|
||||
MessageType: messageType,
|
||||
From: from,
|
||||
To: to,
|
||||
Ok: ok,
|
||||
Payload: payload,
|
||||
}
|
||||
}
|
50
internal/models/entity_id.go
Normal file
50
internal/models/entity_id.go
Normal file
@ -0,0 +1,50 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EntityIDType string
|
||||
|
||||
const (
|
||||
UsernameType EntityIDType = "@"
|
||||
RoomAliasType EntityIDType = "#"
|
||||
RoomIDType EntityIDType = "!"
|
||||
)
|
||||
|
||||
type EntityID struct {
|
||||
EntityIDType EntityIDType
|
||||
LocalPart string
|
||||
ServerPart string
|
||||
}
|
||||
|
||||
func NewEntityID(entityID string) *EntityID {
|
||||
eID := &EntityID{}
|
||||
switch EntityIDType(string(entityID[0])) {
|
||||
case UsernameType:
|
||||
{
|
||||
eID.EntityIDType = UsernameType
|
||||
}
|
||||
case RoomAliasType:
|
||||
{
|
||||
eID.EntityIDType = RoomAliasType
|
||||
}
|
||||
case RoomIDType:
|
||||
{
|
||||
eID.EntityIDType = RoomIDType
|
||||
}
|
||||
}
|
||||
localAndServerPart := strings.Split(entityID, "@")
|
||||
if len(localAndServerPart) == 3 {
|
||||
localAndServerPart = localAndServerPart[1:]
|
||||
}
|
||||
eID.LocalPart = localAndServerPart[0]
|
||||
eID.ServerPart = localAndServerPart[1]
|
||||
|
||||
return eID
|
||||
}
|
||||
|
||||
func (eID *EntityID) String() string {
|
||||
return fmt.Sprintf("%s%s@%s", eID.EntityIDType, eID.LocalPart, eID.ServerPart)
|
||||
}
|
7
internal/models/protocol_error.go
Normal file
7
internal/models/protocol_error.go
Normal file
@ -0,0 +1,7 @@
|
||||
package models
|
||||
|
||||
type ProtocolError struct {
|
||||
ErrCode string `json:"errCode"`
|
||||
ErrText string `json:"errText"`
|
||||
ErrPayload map[string]interface{} `json:"errPayload"`
|
||||
}
|
86
internal/module_manager.go
Normal file
86
internal/module_manager.go
Normal file
@ -0,0 +1,86 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ChronosX88/zirconium/internal/models"
|
||||
)
|
||||
|
||||
type C2SMessageHandler struct {
|
||||
HandlerFunc func(origin *OriginC2S, message models.BaseMessage)
|
||||
AnonymousAllowed bool
|
||||
}
|
||||
|
||||
type ModuleManager struct {
|
||||
moduleMutex sync.Mutex
|
||||
c2sMessageHandlers map[string][]*C2SMessageHandler
|
||||
internalEventHandlers map[string][]func(sourceModuleName string, event map[string]interface{})
|
||||
}
|
||||
|
||||
func NewModuleManager() (*ModuleManager, error) {
|
||||
var mm = &ModuleManager{
|
||||
c2sMessageHandlers: make(map[string][]*C2SMessageHandler),
|
||||
internalEventHandlers: make(map[string][]func(sourceModuleName string, event map[string]interface{})),
|
||||
}
|
||||
return mm, nil
|
||||
}
|
||||
|
||||
func (mm *ModuleManager) Hook(messageType string, anonymousAllowed bool, handlerFunc func(origin *OriginC2S, message models.BaseMessage)) {
|
||||
mm.moduleMutex.Lock()
|
||||
mm.c2sMessageHandlers[messageType] = append(mm.c2sMessageHandlers[messageType], &C2SMessageHandler{
|
||||
HandlerFunc: handlerFunc,
|
||||
AnonymousAllowed: anonymousAllowed,
|
||||
})
|
||||
mm.moduleMutex.Unlock()
|
||||
}
|
||||
|
||||
func (mm *ModuleManager) HookInternalEvent(eventName string, handlerFunc func(sourceModuleName string, event map[string]interface{})) {
|
||||
mm.moduleMutex.Lock()
|
||||
mm.internalEventHandlers[eventName] = append(mm.internalEventHandlers[eventName], handlerFunc)
|
||||
mm.moduleMutex.Unlock()
|
||||
}
|
||||
|
||||
func (mm *ModuleManager) Unhook(messageType string, handlerFunc func(origin *OriginC2S, message models.BaseMessage)) {
|
||||
mm.moduleMutex.Lock()
|
||||
defer mm.moduleMutex.Unlock()
|
||||
var handlers = mm.c2sMessageHandlers[messageType]
|
||||
if handlers != nil {
|
||||
for i, v := range handlers {
|
||||
if reflect.ValueOf(v.HandlerFunc).Pointer() == reflect.ValueOf(handlerFunc).Pointer() {
|
||||
handlers[i] = handlers[len(handlers)-1]
|
||||
handlers[len(handlers)-1] = nil
|
||||
handlers = handlers[:len(handlers)-1]
|
||||
mm.c2sMessageHandlers[messageType] = handlers
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mm *ModuleManager) UnhookInternalEvent(eventName string, handlerFunc func(sourceModuleName string, event map[string]interface{})) {
|
||||
mm.moduleMutex.Lock()
|
||||
defer mm.moduleMutex.Unlock()
|
||||
var handlers = mm.internalEventHandlers[eventName]
|
||||
if handlers != nil {
|
||||
for i, v := range handlers {
|
||||
if reflect.ValueOf(v).Pointer() == reflect.ValueOf(handlerFunc).Pointer() {
|
||||
handlers[i] = handlers[len(handlers)-1]
|
||||
handlers[len(handlers)-1] = nil
|
||||
handlers = handlers[:len(handlers)-1]
|
||||
mm.internalEventHandlers[eventName] = handlers
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mm *ModuleManager) FireEvent(sourceModuleName string, eventName string, eventPayload map[string]interface{}) {
|
||||
router.RouteInternalEvent(sourceModuleName, eventName, eventPayload)
|
||||
}
|
||||
|
||||
func (mm *ModuleManager) GenerateToken(entityID, deviceID string, tokenExpireTimeDuration time.Duration) (string, error) {
|
||||
token, err := authManager.CreateNewToken(entityID, deviceID, tokenExpireTimeDuration)
|
||||
return token, err
|
||||
}
|
17
internal/origin_c2s.go
Normal file
17
internal/origin_c2s.go
Normal file
@ -0,0 +1,17 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/ChronosX88/zirconium/internal/models"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type OriginC2S struct {
|
||||
wsConn *websocket.Conn
|
||||
connID string
|
||||
entityID *models.EntityID
|
||||
deviceID *string
|
||||
}
|
||||
|
||||
func (o *OriginC2S) Send(message models.BaseMessage) error {
|
||||
return o.wsConn.WriteJSON(message)
|
||||
}
|
69
internal/router.go
Normal file
69
internal/router.go
Normal file
@ -0,0 +1,69 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/ChronosX88/zirconium/internal/models"
|
||||
"github.com/google/logger"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
moduleManager *ModuleManager
|
||||
connections []*OriginC2S
|
||||
}
|
||||
|
||||
func NewRouter() (*Router, error) {
|
||||
mm, err := NewModuleManager()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &Router{
|
||||
moduleManager: mm,
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Router) RouteMessage(origin *OriginC2S, message models.BaseMessage) {
|
||||
handlers := r.moduleManager.c2sMessageHandlers[message.MessageType]
|
||||
if handlers != nil {
|
||||
for _, v := range handlers {
|
||||
if !v.AnonymousAllowed {
|
||||
var entityID, deviceID string
|
||||
var isValid bool
|
||||
var err error
|
||||
if message.AuthToken != "" {
|
||||
isValid, entityID, deviceID, err = authManager.ValidateToken(message.AuthToken)
|
||||
if err != nil || !isValid {
|
||||
logger.Warningf("Connection %s isn't authorized", origin.connID)
|
||||
msg := PrepareMessageUnauthorized(message)
|
||||
origin.Send(msg)
|
||||
}
|
||||
} else {
|
||||
logger.Warningf("Connection %s isn't authorized", origin.connID)
|
||||
|
||||
msg := PrepareMessageUnauthorized(message)
|
||||
origin.Send(msg)
|
||||
}
|
||||
|
||||
if origin.entityID == nil {
|
||||
origin.entityID = models.NewEntityID(entityID)
|
||||
}
|
||||
if origin.deviceID == nil {
|
||||
origin.deviceID = &deviceID
|
||||
}
|
||||
}
|
||||
go v.HandlerFunc(origin, message)
|
||||
}
|
||||
} else {
|
||||
logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) RouteInternalEvent(sourceModuleName string, eventName string, eventPayload map[string]interface{}) {
|
||||
handlers := r.moduleManager.internalEventHandlers[eventName]
|
||||
if handlers != nil {
|
||||
for _, v := range handlers {
|
||||
go v(sourceModuleName, eventPayload)
|
||||
}
|
||||
} else {
|
||||
logger.Infof("Drop event %s because server hasn't proper handlers", eventName)
|
||||
}
|
||||
}
|
50
internal/utils.go
Normal file
50
internal/utils.go
Normal file
@ -0,0 +1,50 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"reflect"
|
||||
|
||||
"github.com/ChronosX88/zirconium/internal/models"
|
||||
)
|
||||
|
||||
func GenRandomBytes(size int) (blk []byte, err error) {
|
||||
blk = make([]byte, size)
|
||||
_, err = rand.Read(blk)
|
||||
return
|
||||
}
|
||||
|
||||
func StructToMap(item interface{}) map[string]interface{} {
|
||||
res := map[string]interface{}{}
|
||||
if item == nil {
|
||||
return res
|
||||
}
|
||||
v := reflect.TypeOf(item)
|
||||
reflectValue := reflect.ValueOf(item)
|
||||
reflectValue = reflect.Indirect(reflectValue)
|
||||
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
tag := v.Field(i).Tag.Get("json")
|
||||
field := reflectValue.Field(i).Interface()
|
||||
if tag != "" && tag != "-" {
|
||||
if v.Field(i).Type.Kind() == reflect.Struct {
|
||||
res[tag] = StructToMap(field)
|
||||
} else {
|
||||
res[tag] = field
|
||||
}
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func PrepareMessageUnauthorized(msg models.BaseMessage) models.BaseMessage {
|
||||
protocolError := models.ProtocolError{
|
||||
ErrCode: "unauthorized",
|
||||
ErrText: "Unauthorized access",
|
||||
ErrPayload: make(map[string]interface{}),
|
||||
}
|
||||
errMsg := models.NewBaseMessage(msg.MessageType, serverDomain, msg.From, false, StructToMap(protocolError))
|
||||
return errMsg
|
||||
}
|
0
plugins/.gitkeep
Normal file
0
plugins/.gitkeep
Normal file
5
shared/module.go
Normal file
5
shared/module.go
Normal file
@ -0,0 +1,5 @@
|
||||
package shared
|
||||
|
||||
type Module interface {
|
||||
Initialize()
|
||||
}
|
5
shared/plugin_map.go
Normal file
5
shared/plugin_map.go
Normal file
@ -0,0 +1,5 @@
|
||||
package shared
|
||||
|
||||
var Plugins = map[string]Module{
|
||||
// Add plugins here
|
||||
}
|
Loading…
Reference in New Issue
Block a user