Fully refactor authorization architecture (make it more modular)

This commit is contained in:
ChronosX88 2021-03-08 16:06:12 +03:00
parent 8b49ea7cd3
commit 898d603b8d
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A
21 changed files with 203 additions and 158 deletions

View File

@ -13,7 +13,7 @@ namespace Zirconium.Core
public Router Router { get; }
public PluginManager PluginManager { get; }
public IPluginHostAPI PluginHostAPI { get; }
public AuthManager AuthManager { get; }
public AuthProviderManager AuthProviderManager { get; }
private WebSocketServer _websocketServer;
public DatabaseConnector Database { get; private set; }
@ -25,11 +25,11 @@ namespace Zirconium.Core
SessionManager = new SessionManager();
Router = new Router(this);
PluginHostAPI = new PluginHostAPI(this, Router);
AuthManager = new AuthManager(this);
AuthProviderManager = new AuthProviderManager(this);
Database = new DatabaseConnector(this);
PluginManager = new PluginManager(PluginHostAPI);
PluginManager.LoadPlugins(config.PluginsDirPath, config.EnabledPlugins);
AuthManager.SetDefaultAuthProvider();
AuthProviderManager.SetDefaultAuthProvider();
Log.Info("Zirconium is initialized successfully");
}

View File

@ -1,50 +0,0 @@
using System.Collections.Generic;
using System;
using Zirconium.Core.Plugins.Interfaces;
using System.Linq;
namespace Zirconium.Core
{
public class AuthManager
{
private App _app;
private string _secretString;
private const long DEFAULT_TOKEN_EXPIRATION_TIME_HOURS = 24 * 3600000;
private IList<IAuthProvider> _authProviders;
public IAuthProvider DefaultAuthProvider { get; private set; }
public AuthManager(App app)
{
_app = app;
_authProviders = new List<IAuthProvider>();
DefaultAuthProvider = null;
_secretString = Guid.NewGuid().ToString();
}
public string CreateToken(string entityID, string deviceID, long tokenExpirationMillis)
{
if (DefaultAuthProvider == null) throw new Exception("Default auth provider isn't specified");
return DefaultAuthProvider.CreateAuthToken(entityID, deviceID, tokenExpirationMillis);
}
public string CreateToken(string entityID, string deviceID)
{
return CreateToken(entityID, deviceID, DEFAULT_TOKEN_EXPIRATION_TIME_HOURS);
}
public SessionAuthData ValidateToken(string token)
{
if (DefaultAuthProvider == null) throw new Exception("Default auth provider isn't specified");
return DefaultAuthProvider.TestToken(token);
}
public void AddAuthProvider(IAuthProvider provider)
{
_authProviders.Add(provider);
}
public void SetDefaultAuthProvider() {
DefaultAuthProvider = _authProviders.Where(x => x.GetAuthProviderName() == _app.Config.AuthenticationProvider).FirstOrDefault();
}
}
}

View File

@ -0,0 +1,30 @@
using System.Collections.Generic;
using System;
using Zirconium.Core.Plugins.Interfaces;
using System.Linq;
namespace Zirconium.Core
{
public class AuthProviderManager
{
private App _app;
private IList<IAuthProvider> _authProviders;
public IAuthProvider DefaultAuthProvider { get; private set; }
public AuthProviderManager(App app)
{
_app = app;
_authProviders = new List<IAuthProvider>();
DefaultAuthProvider = null;
}
public void AddAuthProvider(IAuthProvider provider)
{
_authProviders.Add(provider);
}
public void SetDefaultAuthProvider() {
DefaultAuthProvider = _authProviders.FirstOrDefault(x => x.GetAuthProviderName() == _app.Config.AuthenticationProvider);
}
}
}

View File

@ -23,7 +23,6 @@ namespace Zirconium.Core
_app.SessionManager.DeleteSession(ID);
Log.Info($"Connection {ID} was closed (reason: {e.Reason})");
// TODO implement closing connection
}
protected override void OnError(ErrorEventArgs e)
@ -45,7 +44,7 @@ namespace Zirconium.Core
var errMsg = OtherUtils.GenerateProtocolError(
null,
"parseError",
$"Server cannot parse this message yet because it is not JSON",
$"Server cannot parse this message because it is not JSON",
new Dictionary<string, object>()
);
errMsg.From = _app.Config.ServerID;
@ -73,8 +72,8 @@ namespace Zirconium.Core
var session = new Session();
session.ClientAddress = ip;
session.ConnectionHandler = this;
_app.SessionManager.AddSession(this.ID, session);
Log.Info($"Connection {this.ID} was created");
_app.SessionManager.AddSession(ID, session);
Log.Info($"Connection {ID} was created");
}
public void SendMessage(string message)

View File

@ -0,0 +1,14 @@
using System.Collections.Generic;
using Newtonsoft.Json;
namespace Zirconium.Core.Models.Authorization
{
public class AuthorizationRequest
{
[JsonProperty("type")]
public string Type { get; set; }
[JsonProperty("fields")]
public IDictionary<string, dynamic> Fields { get; set; }
}
}

View File

@ -0,0 +1,13 @@
using Newtonsoft.Json;
namespace Zirconium.Core.Models.Authorization
{
public class AuthorizationResponse
{
[JsonProperty("token")]
public string Token;
[JsonProperty("deviceID")]
public string DeviceID { get; set; }
}
}

View File

@ -22,10 +22,7 @@ namespace Zirconium.Core.Models
[JsonProperty("ok")]
public bool Ok { get; set; }
[JsonProperty("authToken", NullValueHandling = NullValueHandling.Ignore)]
public string AuthToken { get; set; }
[JsonProperty("payload")]
[JsonProperty("payload", NullValueHandling = NullValueHandling.Ignore)]
public IDictionary<string, object> Payload { get; set; }
public BaseMessage() {
@ -33,6 +30,11 @@ namespace Zirconium.Core.Models
ID = Guid.NewGuid().ToString();
}
public BaseMessage(string type) : this()
{
MessageType = type;
}
public BaseMessage(BaseMessage message, bool reply) : this()
{
if (message != null)
@ -43,13 +45,12 @@ namespace Zirconium.Core.Models
{
// TODO probably need to fix it
From = message.To.First();
To = new string[] { message.From };
To = new[] { message.From };
}
else
{
From = message.From;
To = message.To;
AuthToken = message.AuthToken;
}
Ok = message.Ok;

View File

@ -4,7 +4,7 @@ namespace Zirconium.Core.Models
{
public class Session
{
public SessionAuthData LastTokenPayload { get; set; }
public SessionAuthData AuthData { get; set; }
public IPAddress ClientAddress { get; set; }
public ConnectionHandler ConnectionHandler { get; set; }
}

View File

@ -1,17 +1,18 @@
using System.Collections.Generic;
using Zirconium.Core.Models;
using Zirconium.Core.Models.Authorization;
namespace Zirconium.Core.Plugins.Interfaces
{
public interface IAuthProvider
{
// Method for checking validity of access token in each message
SessionAuthData TestToken(string token);
(SessionAuthData, AuthorizationResponse) TestAuthFields(IDictionary<string, dynamic> fields);
// Method for testing password when logging in
bool TestPassword(string username, string pass);
EntityID GetEntityID(IDictionary<string, dynamic> fields);
// User registration logic
void CreateUser(string username, string pass);
string CreateAuthToken(string entityID, string deviceID, long tokenExpirationMillis);
string GetAuthProviderName();
string[] GetAuthSupportedMethods();
}
}

View File

@ -11,8 +11,6 @@ namespace Zirconium.Core.Plugins.Interfaces
void Unhook(IC2SMessageHandler handler);
void UnhookCoreEvent(ICoreEventHandler handler);
void FireEvent(CoreEvent coreEvent);
string GenerateAuthToken(string entityID, string deviceID, int tokenExpirationMillis);
string GenerateAuthToken(string entityID, string deviceID);
string[] GetServerDomains();
string GetServerID();
void SendMessage(Session session, BaseMessage message);

View File

@ -1,4 +1,3 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MongoDB.Driver;
@ -27,7 +26,7 @@ namespace Zirconium.Core.Plugins
}
public void ProvideAuth(IAuthProvider provider) {
_app.AuthManager.AddAuthProvider(provider);
_app.AuthProviderManager.AddAuthProvider(provider);
}
public void FireEvent(CoreEvent coreEvent)
@ -35,11 +34,6 @@ namespace Zirconium.Core.Plugins
_router.RouteCoreEvent(coreEvent);
}
public string GenerateAuthToken(string entityID, string deviceID, int tokenExpirationMillis)
{
return _app.AuthManager.CreateToken(entityID, deviceID, tokenExpirationMillis);
}
public string[] GetServerDomains()
{
return _app.Config.ServerDomains;
@ -105,12 +99,7 @@ namespace Zirconium.Core.Plugins
public IAuthProvider GetAuthProvider()
{
return _app.AuthManager.DefaultAuthProvider;
}
public string GenerateAuthToken(string entityID, string deviceID)
{
return _app.AuthManager.CreateToken(entityID, deviceID);
return _app.AuthProviderManager.DefaultAuthProvider;
}
}
}

View File

@ -7,6 +7,7 @@ using Log4Sharp;
using McMaster.NETCore.Plugins;
using MongoDB.Driver;
using Zirconium.Core.Models;
using Zirconium.Core.Models.Authorization;
using Zirconium.Core.Plugins.Interfaces;
using Zirconium.Core.Plugins.IPC;
using Zirconium.Utils;
@ -76,7 +77,9 @@ namespace Zirconium.Core.Plugins
typeof(BaseMessage),
typeof(CoreEvent),
typeof(ExportedIPCMethodAttribute),
typeof(IMongoDatabase)
typeof(IMongoDatabase),
typeof(AuthorizationRequest),
typeof(AuthorizationResponse)
},
config => config.PreferSharedTypes = true
);

View File

@ -36,45 +36,23 @@ namespace Zirconium.Core
new Dictionary<string, object>()
);
msg.From = _app.Config.ServerID;
var serializedMsg = JsonConvert.SerializeObject(msg);
session.ConnectionHandler.SendMessage(serializedMsg);
session.ConnectionHandler.SendMessage(msg);
return;
}
var handlerTasks = new List<Task>();
foreach (var h in handlers)
{
if (h.IsAuthorizationRequired())
if (h.IsAuthorizationRequired() && session.AuthData == null)
{
SessionAuthData tokenPayload;
try
{
tokenPayload = _app.AuthManager.ValidateToken(message.AuthToken);
}
catch (Exception e)
{
Log.Warning(e.Message);
var errorMsg = OtherUtils.GenerateProtocolError(
message,
"unauthorized",
"Unauthorized access",
new Dictionary<string, object>()
);
errorMsg.From = _app.Config.ServerID;
var serializedMsg = JsonConvert.SerializeObject(errorMsg);
session.LastTokenPayload = null;
session.ConnectionHandler.SendMessage(serializedMsg);
session.ConnectionHandler.SendMessage(OtherUtils.GenerateUnauthorizedError(message, _app.Config.ServerID));
return;
}
session.LastTokenPayload = tokenPayload;
}
var task = Task.Run(() =>
handlerTasks.Add(Task.Run(() =>
{
// probably need to wrap whole foreach body, not only HandleMessage call - need to investigate
h.HandleMessage(session, message);
});
handlerTasks.Add(task);
}));
}
try
{

View File

@ -81,7 +81,7 @@ namespace Zirconium.Utils
public static TValue GetValueOrDefault<TKey, TValue>(this IDictionary<TKey, TValue> dictionary,
TKey key,
TValue defaultValue)
TValue defaultValue = default)
{
TValue value;
return dictionary.TryGetValue(key, out value) ? value : defaultValue;

View File

@ -6,7 +6,7 @@ namespace Zirconium.Utils
{
public static class OtherUtils
{
public static BaseMessage GenerateProtocolError(BaseMessage parentMessage, string errCode, string errText, IDictionary<string, object> errPayload)
public static BaseMessage GenerateProtocolError(BaseMessage parentMessage, string errCode, string errText, IDictionary<string, object> errPayload = null)
{
ProtocolError err = new ProtocolError();
err.ErrCode = errCode;
@ -21,5 +21,17 @@ namespace Zirconium.Utils
msg.Payload = err.ToDictionary();
return msg;
}
public static BaseMessage GenerateUnauthorizedError(BaseMessage replyTo, string serverID)
{
var msg = GenerateProtocolError(
replyTo,
"unauthorized",
"Unauthorized access",
new Dictionary<string, object>()
);
msg.From = serverID;
return msg;
}
}
}

View File

@ -83,7 +83,7 @@ namespace ChatSubsystem
var recipientSession = pluginHostAPI.GetSessionManager()
.GetSessions()
.Select(x => x.Value)
.Where(x => x.LastTokenPayload.EntityID.Where(x => x == message.To.First()).FirstOrDefault() != null)
.Where(x => x.AuthData.EntityID.Where(x => x == message.To.First()).FirstOrDefault() != null)
.FirstOrDefault();
if (recipientSession == null)
@ -93,7 +93,7 @@ namespace ChatSubsystem
}
var msgForRecipient = new BaseMessage();
msgForRecipient.From = session.LastTokenPayload.EntityID.First();
msgForRecipient.From = session.AuthData.EntityID.First();
msgForRecipient.MessageType = "urn:cadmium:chats:message";
respPayload.Type = receivedMessage.Type;
respPayload.Content = receivedMessage.Content;

View File

@ -0,0 +1,10 @@
using Zirconium.Core.Models.Authorization;
namespace DefaultAuthProvider
{
public class LoginPassFields
{
public string Username { get; set; }
public string Password { get; set; }
}
}

View File

@ -8,6 +8,9 @@ using Newtonsoft.Json;
using JWT.Algorithms;
using JWT.Builder;
using System;
using System.Collections.Generic;
using Zirconium.Core.Models;
using Zirconium.Core.Models.Authorization;
using Zirconium.Utils;
namespace DefaultAuthProvider
@ -23,12 +26,14 @@ namespace DefaultAuthProvider
{
var db = pluginHost.GetRawDatabase();
var jwtSecret = pluginHost.GetSettings(this)["JWTSecret"];
pluginHost.ProvideAuth(new DefaultAuthProvider(db, jwtSecret));
pluginHost.ProvideAuth(new DefaultAuthProvider(db, jwtSecret, pluginHost.GetServerID()));
}
}
public class DefaultAuthProvider : IAuthProvider
{
private const long DEFAULT_TOKEN_EXPIRATION_TIME_HOURS = 24 * 3600000;
public class User
{
[BsonId]
@ -41,11 +46,13 @@ namespace DefaultAuthProvider
private IMongoDatabase _db;
private IMongoCollection<User> usersCol;
private string jwtSecret;
private string _serverID;
public DefaultAuthProvider(IMongoDatabase db, string jwtSecret)
public DefaultAuthProvider(IMongoDatabase db, string jwtSecret, string serverID)
{
this._db = db;
this.jwtSecret = jwtSecret;
this._serverID = serverID;
this.usersCol = db.GetCollection<User>("default_auth_data");
_createUsernameUniqueIndex();
}
@ -59,22 +66,50 @@ namespace DefaultAuthProvider
usersCol.Indexes.CreateOne(createIndexModel);
}
public EntityID GetEntityID(IDictionary<string, dynamic> fieldsDict)
{
var fields = fieldsDict.ToObject<LoginPassFields>();
return new EntityID('@', fields.Username, _serverID);
}
public void CreateUser(string username, string pass)
{
var user = new User();
user.Username = username; // TODO add check on bad chars
var hashed = PasswordHasher.CreatePasswordHash(pass);
System.GC.Collect();
GC.Collect();
user.Password = hashed.Item1;
user.Salt = hashed.Item2;
_db.GetCollection<User>("default_auth_data").InsertOne(user);
}
public string GetAuthProviderName()
public (SessionAuthData, AuthorizationResponse) TestAuthFields(IDictionary<string, dynamic> fieldsDict)
{
return "default";
if (fieldsDict.GetValueOrDefault("token") == null)
{
var fields = fieldsDict.ToObject<LoginPassFields>();
bool valid = TestPassword(fields.Username, fields.Password);
if (valid)
{
// TODO Fix device system
var tokenData = CreateAuthToken(GetEntityID(fieldsDict).ToString(), "ABCDEF", DEFAULT_TOKEN_EXPIRATION_TIME_HOURS);
var res = new AuthorizationResponse()
{
Token = tokenData.Item1
};
return (tokenData.Item2, res);
}
return (null, null);
}
return (TestToken(fieldsDict["token"]), null);
}
public string GetAuthProviderName() => "default";
public string[] GetAuthSupportedMethods() => new[] { "urn:cadmium:auth:login_password", "urn:cadmium:auth:token" };
public bool TestPassword(string username, string pass)
{
var filter = Builders<User>.Filter.Eq("Username", username);
@ -84,32 +119,40 @@ namespace DefaultAuthProvider
return false;
}
var valid = PasswordHasher.VerifyHash(pass, user.Salt, user.Password);
System.GC.Collect();
GC.Collect();
return valid;
}
public SessionAuthData TestToken(string token)
{
try
{
var jsonPayload = new JwtBuilder()
.WithAlgorithm(new HMACSHA256Algorithm()) // symmetric
.WithSecret(this.jwtSecret)
.WithSecret(jwtSecret)
.MustVerifySignature()
.Decode(token);
var payload = JsonConvert.DeserializeObject<SessionAuthData>(jsonPayload);
if (payload == null) payload = new SessionAuthData();
return payload; // TODO add enchanced token validation
}
catch
{
return null;
}
}
public string CreateAuthToken(string entityID, string deviceID, long tokenExpirationMillis)
private (string, SessionAuthData) CreateAuthToken(string entityID, string deviceID, long tokenExpirationMillis)
{
SessionAuthData payload = new SessionAuthData();
payload.DeviceID = deviceID;
payload.EntityID = new string[] { entityID };
return new JwtBuilder()
return (new JwtBuilder()
.WithAlgorithm(new HMACSHA256Algorithm()) // symmetric
.WithSecret(this.jwtSecret)
.AddClaim("exp", DateTimeOffset.UtcNow.AddMilliseconds(tokenExpirationMillis).ToUnixTimeSeconds())
.AddClaims(payload.ToDictionary())
.Encode();
.Encode(), payload);
}
}
}

View File

@ -1,5 +1,7 @@
using System.Collections.Generic;
using System.Linq;
using Zirconium.Core.Models;
using Zirconium.Core.Models.Authorization;
using Zirconium.Core.Plugins.Interfaces;
using Zirconium.Utils;
@ -7,7 +9,8 @@ namespace InBandLogin.Handlers
{
public class LoginC2SHandler : IC2SMessageHandler
{
private const string errID = "invalid_creds";
private const string errID = "urn:cadmium:auth:invalid";
private const string invalidAuthType = "urn:cadmium:auth:invalid_type";
private readonly IPluginHostAPI _pluginHost;
public LoginC2SHandler(IPluginHostAPI pluginHostApi)
@ -22,31 +25,41 @@ namespace InBandLogin.Handlers
public string GetHandlingMessageType()
{
return "profile:login";
return "urn:cadmium:auth";
}
public void HandleMessage(Session session, BaseMessage message)
{
var pObj = message.Payload.ToObject<LoginRequestPayload>();
var pObj = message.Payload.ToObject<AuthorizationRequest>();
var authProvider = _pluginHost.GetAuthProvider();
if (authProvider.TestPassword(pObj.Username, pObj.Password))
if (!authProvider.GetAuthSupportedMethods().Contains(pObj.Type))
{
var reply = OtherUtils.GenerateProtocolError(
message,
invalidAuthType,
"auth type is invalid"
);
session.ConnectionHandler.SendMessage(reply);
return;
}
var authData = authProvider.TestAuthFields(pObj.Fields);
if (authData.Item1 != null)
{
BaseMessage reply = new BaseMessage(message, true);
var p = new LoginResponsePayload();
string deviceID = "ABCDEF"; // TODO fix device id system
p.AuthToken = _pluginHost.GenerateAuthToken($"@{pObj.Username}@{_pluginHost.GetServerID()}", deviceID);
p.DeviceID = deviceID;
reply.Payload = p.ToDictionary();
if (authData.Item2 != null)
{
reply.Payload = authData.Item2.ToDictionary();
}
reply.Ok = true;
session.ConnectionHandler.SendMessage(reply);
session.AuthData = authData.Item1;
}
else
{
var reply = OtherUtils.GenerateProtocolError(
message,
errID,
"Username/password isn't valid",
new Dictionary<string, object>()
"auth credentials isn't valid"
);
session.ConnectionHandler.SendMessage(reply);
}

View File

@ -22,7 +22,7 @@ namespace InBandLogin.Handlers
public string GetHandlingMessageType()
{
return "profile:register";
return "urn:cadmium:register";
}
public void HandleMessage(Session session, BaseMessage message)
@ -62,12 +62,6 @@ namespace InBandLogin.Handlers
BaseMessage reply = new BaseMessage(message, true);
var p = new RegisterResponsePayload();
p.UserID = $"@{pObj.Username}@{_pluginHost.GetServerID()}";
if (pObj.LoginOnSuccess)
{
string deviceID = "ABCDEF"; // TODO fix device id system
p.AuthToken = _pluginHost.GenerateAuthToken($"@{pObj.Username}@{_pluginHost.GetServerID()}", deviceID);
p.DeviceID = deviceID;
}
reply.Payload = p.ToDictionary();
reply.Ok = true;

View File

@ -27,9 +27,6 @@ namespace InBandLogin
[JsonProperty("password")]
public string Password { get; set; }
[JsonProperty("loginOnSuccess")]
public bool LoginOnSuccess { get; set; }
}
class RegisterResponsePayload