diff --git a/internal/backend.go b/internal/backend.go index 2783828..d32c59a 100644 --- a/internal/backend.go +++ b/internal/backend.go @@ -10,7 +10,7 @@ import ( type Backend interface { Register(username, password, device string) (user User, token string, err *models.ApiError) - Login(username, password, device string) (token string, err *models.ApiError) + Login(username, password, device string) (user User, token string, err *models.ApiError) GetUserByToken(token string) (user User) GetRoomByID(id string) Room Sync(token string, request sync.SyncRequest) (response *sync.SyncReply, err *models.ApiError) diff --git a/internal/backends/memory/backend.go b/internal/backends/memory/backend.go index 73a4427..2da8351 100644 --- a/internal/backends/memory/backend.go +++ b/internal/backends/memory/backend.go @@ -28,45 +28,42 @@ func NewBackend(hostname string) *Backend { func (backend *Backend) Register(username, password, device string) (user internal.User, token string, err *models.ApiError) { backend.mutex.Lock() - defer backend.mutex.Unlock() if _, ok := backend.data[username]; ok { + backend.mutex.Unlock() return nil, "", internal.NewError(models.M_USER_IN_USE, "trying to register a user ID which has been taken") } - token = newToken(defaultTokenSize) - user = &User{ name: username, password: password, - Tokens: map[string]Token{ - token: { - Device: device}}, - backend: backend} + Tokens: make(map[string]Token), + backend: backend} backend.data[username] = user - return user, token, nil + backend.mutex.Unlock() + return backend.Login(username, password, device) } -func (backend *Backend) Login(username, password, device string) (token string, err *models.ApiError) { +func (backend *Backend) Login(username, password, device string) (user internal.User, token string, err *models.ApiError) { backend.mutex.Lock() defer backend.mutex.Unlock() user, ok := backend.data[username] if !ok { - return "", internal.NewError(models.M_FORBIDDEN, "wrong username") + return nil, "", internal.NewError(models.M_FORBIDDEN, "wrong username") } if user.Password() != password { - return "", internal.NewError(models.M_FORBIDDEN, "wrong password") + return nil, "", internal.NewError(models.M_FORBIDDEN, "wrong password") } token = newToken(defaultTokenSize) backend.data[username].(*User).Tokens[token] = Token{Device: device} - return token, nil + return user, token, nil } func (backend *Backend) Sync(token string, request mSync.SyncRequest) (response *mSync.SyncReply, err *models.ApiError) { diff --git a/internal/backends/memory/backend_test.go b/internal/backends/memory/backend_test.go index 2def696..cb988b3 100644 --- a/internal/backends/memory/backend_test.go +++ b/internal/backends/memory/backend_test.go @@ -49,7 +49,7 @@ func TestLogin(t *testing.T) { _, _, err := backend.Register(userName, password, "") assert.Nil(t, err) - token, err := backend.Login(userName, password, "") + _, token, err := backend.Login(userName, password, "") assert.Nil(t, err) assert.NotZero(t, token) } @@ -65,10 +65,10 @@ func TestLoginWithWrongCredentials(t *testing.T) { _, _, err := backend.Register(userName, password, "") assert.Nil(t, err) - _, err = backend.Login(userName, "wrong password", "") + _, _, err = backend.Login(userName, "wrong password", "") assert.NotNil(t, err) - _, err = backend.Login("wrong user name", password, "") + _, _, err = backend.Login("wrong user name", password, "") assert.NotNil(t, err) } @@ -83,7 +83,7 @@ func TestLogout(t *testing.T) { user, _, err := backend.Register(userName, password, "") assert.Nil(t, err) - token, err := backend.Login(userName, password, "") + _, token, err := backend.Login(userName, password, "") assert.Nil(t, err) assert.NotZero(t, token) diff --git a/internal/backends/memory/user_test.go b/internal/backends/memory/user_test.go index 8c7e7c5..53e93cf 100644 --- a/internal/backends/memory/user_test.go +++ b/internal/backends/memory/user_test.go @@ -94,7 +94,7 @@ func TestLogoutWithWrongToken(t *testing.T) { user, _, err := backend.Register(userName, password, "") assert.Nil(t, err) - token, err := backend.Login(userName, password, "") + _, token, err := backend.Login(userName, password, "") assert.Nil(t, err) assert.NotZero(t, token) @@ -185,7 +185,7 @@ func TestLogoutAll(t *testing.T) { assert.Nil(t, err) assert.Len(t, user.Devices(), 1) - _, err = backend.Login(userName, password, "dev2") + _, _, err = backend.Login(userName, password, "dev2") assert.Nil(t, err) assert.Len(t, user.Devices(), 2) diff --git a/internal/handlers.go b/internal/handlers.go index e480fc8..2beacaa 100644 --- a/internal/handlers.go +++ b/internal/handlers.go @@ -66,7 +66,7 @@ func LoginHandler(w http.ResponseWriter, r *http.Request) { request.Identifier.User = strings.TrimPrefix(request.Identifier.User, "@") } - token, apiErr := currServer.Backend.Login(request.Identifier.User, request.Password, request.DeviceID) + _, token, apiErr := currServer.Backend.Login(request.Identifier.User, request.Password, request.DeviceID) if apiErr != nil { errorResponse(w, *apiErr, http.StatusForbidden, "") return