558 lines
16 KiB
Go
558 lines
16 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/stretchr/testify/assert"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
type mockUserStore struct {
|
|
CreateUserFunc func(context.Context, data.CreateUserParams) (data.User, error)
|
|
ListUsersFunc func(context.Context) ([]data.User, error)
|
|
GetUserByIDFunc func(context.Context, uuid.UUID) (data.User, error)
|
|
GetUserByUsernameFunc func(context.Context, string) (data.User, error)
|
|
UpdatePasswordFunc func(context.Context, data.UpdatePasswordParams) error
|
|
DeleteUserFunc func(context.Context, uuid.UUID) error
|
|
RevokeAllUserRefreshTokensFunc func(context.Context, uuid.UUID) error
|
|
}
|
|
|
|
func (m *mockUserStore) CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
|
return m.CreateUserFunc(ctx, arg)
|
|
}
|
|
|
|
func (m *mockUserStore) ListUsers(ctx context.Context) ([]data.User, error) {
|
|
return m.ListUsersFunc(ctx)
|
|
}
|
|
|
|
func (m *mockUserStore) GetUserByID(ctx context.Context, id uuid.UUID) (data.User, error) {
|
|
return m.GetUserByIDFunc(ctx, id)
|
|
}
|
|
|
|
func (m *mockUserStore) GetUserByUsername(ctx context.Context, username string) (data.User, error) {
|
|
return m.GetUserByUsernameFunc(ctx, username)
|
|
}
|
|
|
|
func (m *mockUserStore) UpdatePassword(ctx context.Context, arg data.UpdatePasswordParams) error {
|
|
return m.UpdatePasswordFunc(ctx, arg)
|
|
}
|
|
|
|
func (m *mockUserStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
|
return m.DeleteUserFunc(ctx, id)
|
|
}
|
|
|
|
func (m *mockUserStore) RevokeAllUserRefreshTokens(ctx context.Context, id uuid.UUID) error {
|
|
return m.RevokeAllUserRefreshTokensFunc(ctx, id)
|
|
}
|
|
|
|
func TestCreateUser_Duplicate(t *testing.T) {
|
|
mockStore := &mockUserStore{
|
|
CreateUserFunc: func(ctx context.Context, arg data.CreateUserParams) (data.User, error) {
|
|
return data.User{}, &pgconn.PgError{Code: "23505"}
|
|
},
|
|
}
|
|
|
|
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
|
|
|
reqBody := `{"username": "existing", "password": "validPass123!"}`
|
|
req := httptest.NewRequest("POST", "/", strings.NewReader(reqBody))
|
|
w := httptest.NewRecorder()
|
|
|
|
rs.Create(w, req)
|
|
|
|
assert.Equal(t, http.StatusConflict, w.Code)
|
|
assert.Contains(t, w.Body.String(), "Username is already in use")
|
|
}
|
|
|
|
func TestCreateUser_InvalidUsername(t *testing.T) {
|
|
mockStore := &mockUserStore{} // No DB calls expected
|
|
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
|
|
|
// Test various invalid usernames
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
}{
|
|
{
|
|
"too short",
|
|
`{"username": "a", "password": "validPass123!"}`,
|
|
},
|
|
{
|
|
"invalid chars",
|
|
`{"username": "user@name", "password": "validPass123!"}`,
|
|
},
|
|
{
|
|
"empty",
|
|
`{"username": "", "password": "validPass123!"}`,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
|
w := httptest.NewRecorder()
|
|
rs.Create(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCreateUser_InvalidPassword(t *testing.T) {
|
|
mockStore := &mockUserStore{}
|
|
rs := usersResource{Users: mockStore, JWTSecret: "test-secret"}
|
|
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
}{
|
|
{
|
|
"too short",
|
|
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1)),
|
|
},
|
|
{
|
|
"too long",
|
|
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1)),
|
|
},
|
|
{
|
|
"low entropy",
|
|
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength)),
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "/", strings.NewReader(tc.body))
|
|
w := httptest.NewRecorder()
|
|
rs.Create(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestListUsers_Success(t *testing.T) {
|
|
testUsers := []data.User{
|
|
{ID: uuid.New(), Username: "user1"},
|
|
{ID: uuid.New(), Username: "user2"},
|
|
}
|
|
|
|
mockStore := &mockUserStore{
|
|
ListUsersFunc: func(ctx context.Context) ([]data.User, error) {
|
|
return testUsers, nil
|
|
},
|
|
}
|
|
|
|
rs := usersResource{Users: mockStore}
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
rs.List(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var response []map[string]any
|
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.Len(t, response, 2)
|
|
assert.Equal(t, "user1", response[0]["username"])
|
|
assert.NotContains(t, response[0], "password_hash")
|
|
}
|
|
|
|
func TestUpdatePassword_InvalidOldPassword(t *testing.T) {
|
|
// User with password hash that won't match "wrongpassword"
|
|
user := data.User{
|
|
ID: uuid.New(),
|
|
PasswordHash: "$2a$10$PHhno.bZBF8IEINdFRZAPujMxIN65msElATgJG6FIxZdeWYVLSfFi", // Hash of "correctpassword"
|
|
}
|
|
|
|
mockStore := &mockUserStore{}
|
|
rs := usersResource{Users: mockStore}
|
|
|
|
reqBody := `{"old_password": "wrongpassword", "new_password": "NewValidPass321!"}`
|
|
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
w := httptest.NewRecorder()
|
|
|
|
rs.UpdatePassword(w, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
}
|
|
|
|
func TestAdminDelete_Success(t *testing.T) {
|
|
user := data.User{ID: uuid.New()}
|
|
deleteCalled := false
|
|
revokeCalled := false
|
|
|
|
mockStore := &mockUserStore{
|
|
DeleteUserFunc: func(ctx context.Context, id uuid.UUID) error {
|
|
deleteCalled = true
|
|
assert.Equal(t, user.ID, id)
|
|
return nil
|
|
},
|
|
RevokeAllUserRefreshTokensFunc: func(ctx context.Context, id uuid.UUID) error {
|
|
revokeCalled = true
|
|
assert.Equal(t, user.ID, id)
|
|
return nil
|
|
},
|
|
}
|
|
|
|
rs := usersResource{Users: mockStore}
|
|
req := httptest.NewRequest("DELETE", "/", nil)
|
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
w := httptest.NewRecorder()
|
|
|
|
rs.AdminDelete(w, req)
|
|
|
|
assert.Equal(t, http.StatusNoContent, w.Code)
|
|
assert.True(t, deleteCalled)
|
|
assert.True(t, revokeCalled)
|
|
}
|
|
|
|
func TestOwnerDelete_InvalidCredentials(t *testing.T) {
|
|
// Create user with known password hash
|
|
correctPassword := "CorrectPass123!"
|
|
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost)
|
|
user := data.User{
|
|
ID: uuid.New(),
|
|
PasswordHash: string(hashedPassword),
|
|
}
|
|
|
|
mockStore := &mockUserStore{}
|
|
rs := usersResource{Users: mockStore}
|
|
|
|
reqBody := `{"password": "wrongpassword"}`
|
|
req := httptest.NewRequest("DELETE", "/", strings.NewReader(reqBody))
|
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
w := httptest.NewRecorder()
|
|
|
|
rs.OwnerDelete(w, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
}
|
|
|
|
func TestUsersGetCurrentUser(t *testing.T) {
|
|
validUserID := uuid.New()
|
|
testTime := time.Now().UTC().Truncate(time.Second)
|
|
testUser := data.User{
|
|
ID: validUserID,
|
|
Username: "testuser",
|
|
CreatedAt: &testTime,
|
|
UpdatedAt: &testTime,
|
|
IsAdmin: false,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupContext func(context.Context) context.Context
|
|
mockSetup func(*mockUserStore)
|
|
wantStatus int
|
|
wantResponse string
|
|
}{
|
|
{
|
|
name: "success",
|
|
setupContext: func(ctx context.Context) context.Context {
|
|
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: validUserID.String(),
|
|
},
|
|
})
|
|
},
|
|
mockSetup: func(m *mockUserStore) {
|
|
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
|
|
assert.Equal(t, validUserID, id)
|
|
return testUser, nil
|
|
}
|
|
},
|
|
wantStatus: http.StatusOK,
|
|
wantResponse: fmt.Sprintf(
|
|
`{"created_at":"%s","id":"%s","is_admin":false,"updated_at":"%s","username":"testuser"}`,
|
|
testUser.CreatedAt.Format(time.RFC3339Nano),
|
|
validUserID.String(),
|
|
testUser.UpdatedAt.Format(time.RFC3339Nano),
|
|
),
|
|
},
|
|
{
|
|
name: "user not found",
|
|
setupContext: func(ctx context.Context) context.Context {
|
|
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: validUserID.String(),
|
|
},
|
|
})
|
|
},
|
|
mockSetup: func(m *mockUserStore) {
|
|
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
|
|
return data.User{}, errors.New("not found")
|
|
}
|
|
},
|
|
wantStatus: http.StatusNotFound,
|
|
wantResponse: `{"error":"User not found"}`,
|
|
},
|
|
{
|
|
name: "unauthorized",
|
|
setupContext: func(ctx context.Context) context.Context {
|
|
return ctx // No user claims in context
|
|
},
|
|
mockSetup: func(m *mockUserStore) {},
|
|
wantStatus: http.StatusUnauthorized,
|
|
wantResponse: `{"error":"Unauthorized"}`,
|
|
},
|
|
{
|
|
name: "invalid user ID",
|
|
setupContext: func(ctx context.Context) context.Context {
|
|
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: "invalid",
|
|
},
|
|
})
|
|
},
|
|
mockSetup: func(m *mockUserStore) {},
|
|
wantStatus: http.StatusInternalServerError,
|
|
wantResponse: `{"error":"Invalid user ID"}`,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockStore := &mockUserStore{}
|
|
tt.mockSetup(mockStore)
|
|
|
|
rs := usersResource{Users: mockStore}
|
|
req := httptest.NewRequest("GET", "/me", nil)
|
|
req = req.WithContext(tt.setupContext(req.Context()))
|
|
|
|
w := httptest.NewRecorder()
|
|
rs.Get(w, req)
|
|
|
|
assert.Equal(t, tt.wantStatus, w.Code)
|
|
|
|
if tt.wantResponse != "" {
|
|
actual := strings.TrimSpace(w.Body.String())
|
|
assert.JSONEq(t, tt.wantResponse, actual)
|
|
}
|
|
|
|
// Verify sensitive fields are never exposed
|
|
if w.Code == http.StatusOK {
|
|
var response map[string]any
|
|
json.Unmarshal(w.Body.Bytes(), &response)
|
|
_, exists := response["password_hash"]
|
|
assert.False(t, exists, "password_hash should not be exposed")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdatePassword_DatabaseError(t *testing.T) {
|
|
// Add user with a valid password to the context
|
|
oldPassword := "OldValidPass321!"
|
|
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(oldPassword), bcrypt.DefaultCost)
|
|
user := data.User{
|
|
ID: uuid.New(),
|
|
PasswordHash: string(hashedPassword),
|
|
}
|
|
mockStore := &mockUserStore{
|
|
UpdatePasswordFunc: func(ctx context.Context, arg data.UpdatePasswordParams) error {
|
|
return errors.New("database error")
|
|
},
|
|
}
|
|
|
|
rs := usersResource{Users: mockStore}
|
|
|
|
reqBody := fmt.Sprintf(`{"old_password": "%s", "new_password": "NewValidPass123!"}`, oldPassword)
|
|
req := httptest.NewRequest("PUT", "/", strings.NewReader(reqBody))
|
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, user))
|
|
w := httptest.NewRecorder()
|
|
|
|
rs.UpdatePassword(w, req)
|
|
|
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
|
assert.Contains(t, w.Body.String(), "Failed to update password")
|
|
}
|
|
|
|
func TestUsersLogin(t *testing.T) {
|
|
validPassword := "validPass123!"
|
|
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(validPassword), bcrypt.DefaultCost)
|
|
testUser := data.User{
|
|
ID: uuid.New(),
|
|
Username: "test_username",
|
|
PasswordHash: string(hashedPassword),
|
|
}
|
|
jwtSecret := "test-secret"
|
|
|
|
tests := []struct {
|
|
name string
|
|
includeUser string
|
|
wantUserData bool
|
|
requestBody any
|
|
mockSetup func(*mockUserStore)
|
|
wantStatus int
|
|
wantResponse string
|
|
checkCookie bool
|
|
}{
|
|
{
|
|
name: "invalid request body",
|
|
requestBody: "invalid",
|
|
mockSetup: func(m *mockUserStore) {},
|
|
wantStatus: http.StatusBadRequest,
|
|
wantResponse: `{"error":"Invalid request body"}`,
|
|
},
|
|
{
|
|
name: "user not found",
|
|
requestBody: map[string]string{
|
|
"username": "nouser",
|
|
"password": validPassword,
|
|
},
|
|
mockSetup: func(m *mockUserStore) {
|
|
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
return data.User{}, errors.New("not found")
|
|
}
|
|
},
|
|
wantStatus: http.StatusUnauthorized,
|
|
wantResponse: `{"error":"Invalid credentials"}`,
|
|
},
|
|
{
|
|
name: "invalid password",
|
|
requestBody: map[string]string{
|
|
"username": testUser.Username,
|
|
"password": "wrongpassword",
|
|
},
|
|
mockSetup: func(m *mockUserStore) {
|
|
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
return testUser, nil
|
|
}
|
|
},
|
|
wantStatus: http.StatusUnauthorized,
|
|
wantResponse: `{"error":"Invalid credentials"}`,
|
|
},
|
|
{
|
|
name: "successful login with user data",
|
|
includeUser: "true",
|
|
wantUserData: true,
|
|
requestBody: map[string]string{
|
|
"username": testUser.Username,
|
|
"password": validPassword,
|
|
},
|
|
mockSetup: func(m *mockUserStore) {
|
|
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
return testUser, nil
|
|
}
|
|
},
|
|
wantStatus: http.StatusOK,
|
|
checkCookie: true,
|
|
},
|
|
{
|
|
name: "successful login without user data",
|
|
includeUser: "false",
|
|
wantUserData: false,
|
|
requestBody: map[string]string{
|
|
"username": testUser.Username,
|
|
"password": validPassword,
|
|
},
|
|
mockSetup: func(m *mockUserStore) {
|
|
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
|
return testUser, nil
|
|
}
|
|
},
|
|
wantStatus: http.StatusOK,
|
|
checkCookie: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockStore := &mockUserStore{}
|
|
tt.mockSetup(mockStore)
|
|
|
|
rs := usersResource{
|
|
Users: mockStore,
|
|
JWTSecret: jwtSecret,
|
|
}
|
|
|
|
body, _ := json.Marshal(tt.requestBody)
|
|
req := httptest.NewRequest("POST", "/login", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// Add the necessary query parameters
|
|
q := url.Values{}
|
|
q.Add("includeUser", tt.includeUser)
|
|
req.URL.RawQuery = q.Encode()
|
|
|
|
w := httptest.NewRecorder()
|
|
rs.Login(w, req)
|
|
|
|
if w.Code != tt.wantStatus {
|
|
t.Errorf("expected status %d, got %d", tt.wantStatus, w.Code)
|
|
}
|
|
|
|
if tt.wantResponse != "" && strings.TrimSpace(w.Body.String()) != tt.wantResponse {
|
|
t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String())
|
|
}
|
|
|
|
if tt.wantUserData {
|
|
var response struct {
|
|
AccessToken string `json:"access_token"`
|
|
User data.User `json:"user"` // Cast to the "raw" type to allow checking for sensitive data fields
|
|
}
|
|
json.Unmarshal(w.Body.Bytes(), &response)
|
|
|
|
assert.Equal(t, testUser.ID, response.User.ID)
|
|
assert.Equal(t, testUser.Username, response.User.Username)
|
|
assert.Empty(t, response.User.PasswordHash) // Ensure sensitive data excluded
|
|
}
|
|
|
|
if tt.checkCookie {
|
|
cookies := w.Result().Cookies()
|
|
var refreshCookie *http.Cookie
|
|
for _, cookie := range cookies {
|
|
if cookie.Name == "refresh_token" {
|
|
refreshCookie = cookie
|
|
break
|
|
}
|
|
}
|
|
|
|
if refreshCookie == nil {
|
|
t.Fatal("refresh token cookie not set")
|
|
}
|
|
|
|
assert.True(t, refreshCookie.HttpOnly, "cookie should be HttpOnly")
|
|
assert.Equal(t, http.SameSiteStrictMode, refreshCookie.SameSite, "invalid SameSite mode")
|
|
assert.Equal(t, "/", refreshCookie.Path, "invalid cookie path")
|
|
assert.Greater(t, refreshCookie.MaxAge, 0, "cookie should have expiration")
|
|
|
|
// Validate access token in response
|
|
var response map[string]string
|
|
json.Unmarshal(w.Body.Bytes(), &response)
|
|
if response["access_token"] == "" {
|
|
t.Error("access token not in response")
|
|
}
|
|
|
|
// Verify JWT validity
|
|
token, err := jwt.ParseWithClaims(
|
|
response["access_token"],
|
|
&userClaims{},
|
|
func(token *jwt.Token) (any, error) {
|
|
return []byte(jwtSecret), nil
|
|
},
|
|
)
|
|
assert.NoError(t, err, "invalid JWT")
|
|
assert.True(t, token.Valid, "invalid JWT")
|
|
}
|
|
})
|
|
}
|
|
}
|