Implement basic module system, auth manager and connection handler

This commit is contained in:
ChronosX88 2020-02-10 13:23:10 +04:00
parent 67fbfee9e3
commit a5e6fa72ff
17 changed files with 481 additions and 27 deletions

View File

@ -4,13 +4,12 @@ import (
"log"
"net/http"
"github.com/google/uuid"
"github.com/ChronosX88/zirconium/internal"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
var clients = make(map[*websocket.Conn]string)
var clientsReverse = make(map[string]*websocket.Conn)
var connectionHandler = internal.NewConnectionHandler()
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
@ -33,27 +32,5 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
log.Fatal(err)
}
// register client
clients[ws] = uuid.New().String()
clientsReverse[clients[ws]] = ws
go readLoop(ws)
log.Printf("Connection %s created!", clients[ws])
}
func readLoop(c *websocket.Conn) {
for {
if _, _, err := c.NextReader(); err != nil {
connectionID := clients[c]
if connectionID != "" {
delete(clients, c)
delete(clientsReverse, connectionID)
log.Printf("Connection %s closed!", connectionID)
} else {
log.Println("connection wasn't found")
}
c.Close()
break
}
}
connectionHandler.HandleNewConnection(ws)
}

2
go.mod
View File

@ -3,6 +3,8 @@ module github.com/ChronosX88/zirconium
go 1.13
require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/google/logger v1.0.1
github.com/google/uuid v1.1.1
github.com/gorilla/mux v1.7.3
github.com/gorilla/websocket v1.4.1

6
go.sum
View File

@ -1,6 +1,12 @@
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/google/logger v1.0.1 h1:Jtq7/44yDwUXMaLTYgXFC31zpm6Oku7OI/k4//yVANQ=
github.com/google/logger v1.0.1/go.mod h1:w7O8nrRr0xufejBlQMI83MXqRusvREoJdaAxV+CoAB4=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

79
internal/auth_manager.go Normal file
View File

@ -0,0 +1,79 @@
package internal
import (
"encoding/base64"
"fmt"
"time"
"github.com/dgrijalva/jwt-go"
)
const (
SigningKeyBytesAmount = 4096
TokenExpireTimeDuration = 24 * time.Hour
)
type AuthManager struct {
signingKey string // For now it is random bytes string represented in Base64
}
type JWTCustomClaims struct {
EntityID string `json:"entityID"`
DeviceID string `json:"deviceID"`
jwt.StandardClaims
}
func NewAuthManager() (*AuthManager, error) {
am := &AuthManager{}
bytes, err := GenRandomBytes(SigningKeyBytesAmount)
if err != nil {
return nil, err
}
am.signingKey = base64.RawStdEncoding.EncodeToString(bytes)
return am, nil
}
func (am *AuthManager) CreateNewToken(entityID, deviceID string, tokenExpireTimeDuration time.Duration) (string, error) {
timeNow := time.Now()
expiringTime := timeNow.Add(tokenExpireTimeDuration)
claims := JWTCustomClaims{
entityID,
deviceID,
jwt.StandardClaims{
ExpiresAt: time.Date(
expiringTime.Year(),
expiringTime.Month(),
expiringTime.Day(),
expiringTime.Hour(),
expiringTime.Minute(),
expiringTime.Second(),
expiringTime.Nanosecond(),
time.UTC,
).Unix(),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(am.signingKey)
if err != nil {
return "", err
}
return tokenString, nil
}
func (am *AuthManager) ValidateToken(tokenString string) (bool, string, string, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
return []byte(am.signingKey), nil
})
if claims, ok := token.Claims.(*JWTCustomClaims); ok && token.Valid {
return true, claims.EntityID, claims.DeviceID, nil
}
return false, "", "", err
}

View File

@ -0,0 +1,39 @@
package internal
import (
"github.com/ChronosX88/zirconium/internal/models"
"github.com/google/logger"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
type ConnectionHandler struct {
connections map[string]*OriginC2S
}
func NewConnectionHandler() *ConnectionHandler {
return &ConnectionHandler{}
}
func (ch *ConnectionHandler) HandleNewConnection(wsocket *websocket.Conn) {
uuid, _ := uuid.NewRandom()
uuidStr := uuid.String()
o := &OriginC2S{
wsConn: wsocket,
connID: uuidStr,
}
ch.connections[o.connID] = o
go func() {
for {
var msg models.BaseMessage
err := o.wsConn.ReadJSON(&msg)
if err != nil {
delete(ch.connections, o.connID)
o.wsConn.Close()
logger.Infof("Connection %s was closed. (Reason: %s)", o.connID, err.Error())
break
}
router.RouteMessage(o, msg)
}
}()
}

36
internal/globals.go Normal file
View File

@ -0,0 +1,36 @@
package internal
import (
"github.com/ChronosX88/zirconium/shared"
"github.com/google/logger"
)
var (
ModuleMgr *ModuleManager
router *Router
authManager *AuthManager
serverDomain string
)
func InitializeContext(sDomain string) {
var err error
ModuleMgr, err = NewModuleManager()
if err != nil {
logger.Fatalf("Unable to initialize module manager: %s", err.Error())
}
router, err = NewRouter()
if err != nil {
logger.Fatalf("Unable to initialize router: %s", err.Error())
}
authManager, err = NewAuthManager()
if err != nil {
logger.Fatalf("Unable to initialize authentication manager: %s", err.Error())
}
serverDomain = sDomain
for _, v := range shared.Plugins {
go v.Initialize() // Initialize provided plugins
}
}

View File

@ -1 +0,0 @@
package internal

View File

@ -0,0 +1,27 @@
package models
import "github.com/google/uuid"
// BaseMessage is a basic message model, basis of the whole protocol. It is used for a very easy protocol extension process.
type BaseMessage struct {
ID string `json:"id"`
MessageType string `json:"type"`
From string `json:"from"`
To string `json:"to"`
Ok bool `json:"ok"`
AuthToken string `json:"authToken"`
Payload map[string]interface{} `json:"payload"`
}
func NewBaseMessage(messageType string, from string, to string, ok bool, payload map[string]interface{}) BaseMessage {
uuid, _ := uuid.NewRandom()
uuidStr := uuid.String()
return BaseMessage{
ID: uuidStr,
MessageType: messageType,
From: from,
To: to,
Ok: ok,
Payload: payload,
}
}

View File

@ -0,0 +1,50 @@
package models
import (
"fmt"
"strings"
)
type EntityIDType string
const (
UsernameType EntityIDType = "@"
RoomAliasType EntityIDType = "#"
RoomIDType EntityIDType = "!"
)
type EntityID struct {
EntityIDType EntityIDType
LocalPart string
ServerPart string
}
func NewEntityID(entityID string) *EntityID {
eID := &EntityID{}
switch EntityIDType(string(entityID[0])) {
case UsernameType:
{
eID.EntityIDType = UsernameType
}
case RoomAliasType:
{
eID.EntityIDType = RoomAliasType
}
case RoomIDType:
{
eID.EntityIDType = RoomIDType
}
}
localAndServerPart := strings.Split(entityID, "@")
if len(localAndServerPart) == 3 {
localAndServerPart = localAndServerPart[1:]
}
eID.LocalPart = localAndServerPart[0]
eID.ServerPart = localAndServerPart[1]
return eID
}
func (eID *EntityID) String() string {
return fmt.Sprintf("%s%s@%s", eID.EntityIDType, eID.LocalPart, eID.ServerPart)
}

View File

@ -0,0 +1,7 @@
package models
type ProtocolError struct {
ErrCode string `json:"errCode"`
ErrText string `json:"errText"`
ErrPayload map[string]interface{} `json:"errPayload"`
}

View File

@ -0,0 +1,86 @@
package internal
import (
"reflect"
"sync"
"time"
"github.com/ChronosX88/zirconium/internal/models"
)
type C2SMessageHandler struct {
HandlerFunc func(origin *OriginC2S, message models.BaseMessage)
AnonymousAllowed bool
}
type ModuleManager struct {
moduleMutex sync.Mutex
c2sMessageHandlers map[string][]*C2SMessageHandler
internalEventHandlers map[string][]func(sourceModuleName string, event map[string]interface{})
}
func NewModuleManager() (*ModuleManager, error) {
var mm = &ModuleManager{
c2sMessageHandlers: make(map[string][]*C2SMessageHandler),
internalEventHandlers: make(map[string][]func(sourceModuleName string, event map[string]interface{})),
}
return mm, nil
}
func (mm *ModuleManager) Hook(messageType string, anonymousAllowed bool, handlerFunc func(origin *OriginC2S, message models.BaseMessage)) {
mm.moduleMutex.Lock()
mm.c2sMessageHandlers[messageType] = append(mm.c2sMessageHandlers[messageType], &C2SMessageHandler{
HandlerFunc: handlerFunc,
AnonymousAllowed: anonymousAllowed,
})
mm.moduleMutex.Unlock()
}
func (mm *ModuleManager) HookInternalEvent(eventName string, handlerFunc func(sourceModuleName string, event map[string]interface{})) {
mm.moduleMutex.Lock()
mm.internalEventHandlers[eventName] = append(mm.internalEventHandlers[eventName], handlerFunc)
mm.moduleMutex.Unlock()
}
func (mm *ModuleManager) Unhook(messageType string, handlerFunc func(origin *OriginC2S, message models.BaseMessage)) {
mm.moduleMutex.Lock()
defer mm.moduleMutex.Unlock()
var handlers = mm.c2sMessageHandlers[messageType]
if handlers != nil {
for i, v := range handlers {
if reflect.ValueOf(v.HandlerFunc).Pointer() == reflect.ValueOf(handlerFunc).Pointer() {
handlers[i] = handlers[len(handlers)-1]
handlers[len(handlers)-1] = nil
handlers = handlers[:len(handlers)-1]
mm.c2sMessageHandlers[messageType] = handlers
break
}
}
}
}
func (mm *ModuleManager) UnhookInternalEvent(eventName string, handlerFunc func(sourceModuleName string, event map[string]interface{})) {
mm.moduleMutex.Lock()
defer mm.moduleMutex.Unlock()
var handlers = mm.internalEventHandlers[eventName]
if handlers != nil {
for i, v := range handlers {
if reflect.ValueOf(v).Pointer() == reflect.ValueOf(handlerFunc).Pointer() {
handlers[i] = handlers[len(handlers)-1]
handlers[len(handlers)-1] = nil
handlers = handlers[:len(handlers)-1]
mm.internalEventHandlers[eventName] = handlers
break
}
}
}
}
func (mm *ModuleManager) FireEvent(sourceModuleName string, eventName string, eventPayload map[string]interface{}) {
router.RouteInternalEvent(sourceModuleName, eventName, eventPayload)
}
func (mm *ModuleManager) GenerateToken(entityID, deviceID string, tokenExpireTimeDuration time.Duration) (string, error) {
token, err := authManager.CreateNewToken(entityID, deviceID, tokenExpireTimeDuration)
return token, err
}

17
internal/origin_c2s.go Normal file
View File

@ -0,0 +1,17 @@
package internal
import (
"github.com/ChronosX88/zirconium/internal/models"
"github.com/gorilla/websocket"
)
type OriginC2S struct {
wsConn *websocket.Conn
connID string
entityID *models.EntityID
deviceID *string
}
func (o *OriginC2S) Send(message models.BaseMessage) error {
return o.wsConn.WriteJSON(message)
}

69
internal/router.go Normal file
View File

@ -0,0 +1,69 @@
package internal
import (
"github.com/ChronosX88/zirconium/internal/models"
"github.com/google/logger"
)
type Router struct {
moduleManager *ModuleManager
connections []*OriginC2S
}
func NewRouter() (*Router, error) {
mm, err := NewModuleManager()
if err != nil {
return nil, err
}
r := &Router{
moduleManager: mm,
}
return r, nil
}
func (r *Router) RouteMessage(origin *OriginC2S, message models.BaseMessage) {
handlers := r.moduleManager.c2sMessageHandlers[message.MessageType]
if handlers != nil {
for _, v := range handlers {
if !v.AnonymousAllowed {
var entityID, deviceID string
var isValid bool
var err error
if message.AuthToken != "" {
isValid, entityID, deviceID, err = authManager.ValidateToken(message.AuthToken)
if err != nil || !isValid {
logger.Warningf("Connection %s isn't authorized", origin.connID)
msg := PrepareMessageUnauthorized(message)
origin.Send(msg)
}
} else {
logger.Warningf("Connection %s isn't authorized", origin.connID)
msg := PrepareMessageUnauthorized(message)
origin.Send(msg)
}
if origin.entityID == nil {
origin.entityID = models.NewEntityID(entityID)
}
if origin.deviceID == nil {
origin.deviceID = &deviceID
}
}
go v.HandlerFunc(origin, message)
}
} else {
logger.Infof("Drop message with type %s because server hasn't proper handlers", message.MessageType)
}
}
func (r *Router) RouteInternalEvent(sourceModuleName string, eventName string, eventPayload map[string]interface{}) {
handlers := r.moduleManager.internalEventHandlers[eventName]
if handlers != nil {
for _, v := range handlers {
go v(sourceModuleName, eventPayload)
}
} else {
logger.Infof("Drop event %s because server hasn't proper handlers", eventName)
}
}

50
internal/utils.go Normal file
View File

@ -0,0 +1,50 @@
package internal
import (
"crypto/rand"
"reflect"
"github.com/ChronosX88/zirconium/internal/models"
)
func GenRandomBytes(size int) (blk []byte, err error) {
blk = make([]byte, size)
_, err = rand.Read(blk)
return
}
func StructToMap(item interface{}) map[string]interface{} {
res := map[string]interface{}{}
if item == nil {
return res
}
v := reflect.TypeOf(item)
reflectValue := reflect.ValueOf(item)
reflectValue = reflect.Indirect(reflectValue)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
for i := 0; i < v.NumField(); i++ {
tag := v.Field(i).Tag.Get("json")
field := reflectValue.Field(i).Interface()
if tag != "" && tag != "-" {
if v.Field(i).Type.Kind() == reflect.Struct {
res[tag] = StructToMap(field)
} else {
res[tag] = field
}
}
}
return res
}
func PrepareMessageUnauthorized(msg models.BaseMessage) models.BaseMessage {
protocolError := models.ProtocolError{
ErrCode: "unauthorized",
ErrText: "Unauthorized access",
ErrPayload: make(map[string]interface{}),
}
errMsg := models.NewBaseMessage(msg.MessageType, serverDomain, msg.From, false, StructToMap(protocolError))
return errMsg
}

0
plugins/.gitkeep Normal file
View File

5
shared/module.go Normal file
View File

@ -0,0 +1,5 @@
package shared
type Module interface {
Initialize()
}

5
shared/plugin_map.go Normal file
View File

@ -0,0 +1,5 @@
package shared
var Plugins = map[string]Module{
// Add plugins here
}