Compare commits
2 Commits
b393f1a47c
...
7646df76df
Author | SHA1 | Date | |
---|---|---|---|
7646df76df | |||
15c4666ace |
@ -5,6 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@ -16,6 +18,15 @@ import (
|
|||||||
|
|
||||||
type userCtxKey struct{}
|
type userCtxKey struct{}
|
||||||
|
|
||||||
|
// Stripped object that only contains non-critical data
|
||||||
|
type userResponse struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
IsAdmin bool `json:"is_admin"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
UpdatedAt *time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
// Mockable database operations interface
|
// Mockable database operations interface
|
||||||
type UserStore interface {
|
type UserStore interface {
|
||||||
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
CreateUser(ctx context.Context, arg data.CreateUserParams) (data.User, error)
|
||||||
@ -43,6 +54,8 @@ func (rs usersResource) Routes() chi.Router {
|
|||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(requireAccessToken(rs.JWTSecret))
|
r.Use(requireAccessToken(rs.JWTSecret))
|
||||||
|
|
||||||
|
r.Get("/me", rs.Get) // GET /users/me - get current user data
|
||||||
|
|
||||||
// Admin only general routes
|
// Admin only general routes
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(adminOnlyMiddleware)
|
r.Use(adminOnlyMiddleware)
|
||||||
@ -56,7 +69,6 @@ func (rs usersResource) Routes() chi.Router {
|
|||||||
// Admin routes
|
// Admin routes
|
||||||
r.Route("/admin", func(r chi.Router) {
|
r.Route("/admin", func(r chi.Router) {
|
||||||
r.Use(adminOnlyMiddleware)
|
r.Use(adminOnlyMiddleware)
|
||||||
r.Get("/", rs.Get) // GET /users/admin/{id} - get single user
|
|
||||||
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
|
r.Delete("/", rs.AdminDelete) // DELETE /users/admin/{id} - delete user
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -160,10 +172,23 @@ func (rs usersResource) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
SameSite: http.SameSiteStrictMode,
|
SameSite: http.SameSiteStrictMode,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Return the access token in the response body (it should be stored in browser's memory client-side)
|
// Build response
|
||||||
respondJSON(w, http.StatusOK, map[string]string{
|
response := map[string]any{
|
||||||
"access_token": tokenPair.AccessToken,
|
"access_token": tokenPair.AccessToken,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Include user data if the client has requested it (`?includeUser=true`)
|
||||||
|
if includeUser, _ := strconv.ParseBool(r.URL.Query().Get("includeUser")); includeUser {
|
||||||
|
response["user"] = userResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
IsAdmin: user.IsAdmin,
|
||||||
|
CreatedAt: user.CreatedAt,
|
||||||
|
UpdatedAt: user.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respondJSON(w, http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -188,17 +213,30 @@ func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
respondError(w, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
respondError(w, http.StatusInternalServerError, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := rs.Users.GetUserByID(r.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
respondError(w, http.StatusNotFound, "User not found")
|
respondError(w, http.StatusNotFound, "User not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
respondJSON(w, http.StatusOK, map[string]any{
|
respondJSON(w, http.StatusOK, userResponse{
|
||||||
"id": user.ID,
|
ID: user.ID,
|
||||||
"username": user.Username,
|
Username: user.Username,
|
||||||
"created_at": user.CreatedAt,
|
IsAdmin: user.IsAdmin,
|
||||||
"updated_at": user.UpdatedAt,
|
CreatedAt: user.CreatedAt,
|
||||||
|
UpdatedAt: user.UpdatedAt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,8 +8,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@ -243,17 +245,116 @@ func TestOwnerDelete_InvalidCredentials(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetUser_NotFound(t *testing.T) {
|
func TestUsersGetCurrentUser(t *testing.T) {
|
||||||
mockStore := &mockUserStore{}
|
validUserID := uuid.New()
|
||||||
rs := usersResource{Users: mockStore}
|
testTime := time.Now().UTC().Truncate(time.Second)
|
||||||
|
testUser := data.User{
|
||||||
|
ID: validUserID,
|
||||||
|
Username: "testuser",
|
||||||
|
CreatedAt: &testTime,
|
||||||
|
UpdatedAt: &testTime,
|
||||||
|
IsAdmin: false,
|
||||||
|
}
|
||||||
|
|
||||||
// No user in context
|
tests := []struct {
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
name string
|
||||||
w := httptest.NewRecorder()
|
setupContext func(context.Context) context.Context
|
||||||
|
mockSetup func(*mockUserStore)
|
||||||
|
wantStatus int
|
||||||
|
wantResponse string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
setupContext: func(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: validUserID.String(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
mockSetup: func(m *mockUserStore) {
|
||||||
|
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
|
||||||
|
assert.Equal(t, validUserID, id)
|
||||||
|
return testUser, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantResponse: fmt.Sprintf(
|
||||||
|
`{"created_at":"%s","id":"%s","is_admin":false,"updated_at":"%s","username":"testuser"}`,
|
||||||
|
testUser.CreatedAt.Format(time.RFC3339Nano),
|
||||||
|
validUserID.String(),
|
||||||
|
testUser.UpdatedAt.Format(time.RFC3339Nano),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user not found",
|
||||||
|
setupContext: func(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: validUserID.String(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
mockSetup: func(m *mockUserStore) {
|
||||||
|
m.GetUserByIDFunc = func(_ context.Context, id uuid.UUID) (data.User, error) {
|
||||||
|
return data.User{}, errors.New("not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusNotFound,
|
||||||
|
wantResponse: `{"error":"User not found"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized",
|
||||||
|
setupContext: func(ctx context.Context) context.Context {
|
||||||
|
return ctx // No user claims in context
|
||||||
|
},
|
||||||
|
mockSetup: func(m *mockUserStore) {},
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
wantResponse: `{"error":"Unauthorized"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid user ID",
|
||||||
|
setupContext: func(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, userCtxKey{}, &userClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: "invalid",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
mockSetup: func(m *mockUserStore) {},
|
||||||
|
wantStatus: http.StatusInternalServerError,
|
||||||
|
wantResponse: `{"error":"Invalid user ID"}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
rs.Get(w, req)
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockStore := &mockUserStore{}
|
||||||
|
tt.mockSetup(mockStore)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
rs := usersResource{Users: mockStore}
|
||||||
|
req := httptest.NewRequest("GET", "/me", nil)
|
||||||
|
req = req.WithContext(tt.setupContext(req.Context()))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rs.Get(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantStatus, w.Code)
|
||||||
|
|
||||||
|
if tt.wantResponse != "" {
|
||||||
|
actual := strings.TrimSpace(w.Body.String())
|
||||||
|
assert.JSONEq(t, tt.wantResponse, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify sensitive fields are never exposed
|
||||||
|
if w.Code == http.StatusOK {
|
||||||
|
var response map[string]any
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &response)
|
||||||
|
_, exists := response["password_hash"]
|
||||||
|
assert.False(t, exists, "password_hash should not be exposed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdatePassword_DatabaseError(t *testing.T) {
|
func TestUpdatePassword_DatabaseError(t *testing.T) {
|
||||||
@ -295,7 +396,9 @@ func TestUsersLogin(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
requestBody interface{}
|
includeUser string
|
||||||
|
wantUserData bool
|
||||||
|
requestBody any
|
||||||
mockSetup func(*mockUserStore)
|
mockSetup func(*mockUserStore)
|
||||||
wantStatus int
|
wantStatus int
|
||||||
wantResponse string
|
wantResponse string
|
||||||
@ -337,7 +440,25 @@ func TestUsersLogin(t *testing.T) {
|
|||||||
wantResponse: `{"error":"Invalid credentials"}`,
|
wantResponse: `{"error":"Invalid credentials"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "successful login",
|
name: "successful login with user data",
|
||||||
|
includeUser: "true",
|
||||||
|
wantUserData: true,
|
||||||
|
requestBody: map[string]string{
|
||||||
|
"username": testUser.Username,
|
||||||
|
"password": validPassword,
|
||||||
|
},
|
||||||
|
mockSetup: func(m *mockUserStore) {
|
||||||
|
m.GetUserByUsernameFunc = func(_ context.Context, username string) (data.User, error) {
|
||||||
|
return testUser, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
checkCookie: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful login without user data",
|
||||||
|
includeUser: "false",
|
||||||
|
wantUserData: false,
|
||||||
requestBody: map[string]string{
|
requestBody: map[string]string{
|
||||||
"username": testUser.Username,
|
"username": testUser.Username,
|
||||||
"password": validPassword,
|
"password": validPassword,
|
||||||
@ -366,6 +487,11 @@ func TestUsersLogin(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/login", bytes.NewReader(body))
|
req := httptest.NewRequest("POST", "/login", bytes.NewReader(body))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Add the necessary query parameters
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("includeUser", tt.includeUser)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
rs.Login(w, req)
|
rs.Login(w, req)
|
||||||
|
|
||||||
@ -377,6 +503,18 @@ func TestUsersLogin(t *testing.T) {
|
|||||||
t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String())
|
t.Errorf("expected response %q, got %q", tt.wantResponse, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tt.wantUserData {
|
||||||
|
var response struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
User data.User `json:"user"` // Cast to the "raw" type to allow checking for sensitive data fields
|
||||||
|
}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &response)
|
||||||
|
|
||||||
|
assert.Equal(t, testUser.ID, response.User.ID)
|
||||||
|
assert.Equal(t, testUser.Username, response.User.Username)
|
||||||
|
assert.Empty(t, response.User.PasswordHash) // Ensure sensitive data excluded
|
||||||
|
}
|
||||||
|
|
||||||
if tt.checkCookie {
|
if tt.checkCookie {
|
||||||
cookies := w.Result().Cookies()
|
cookies := w.Result().Cookies()
|
||||||
var refreshCookie *http.Cookie
|
var refreshCookie *http.Cookie
|
||||||
@ -407,7 +545,7 @@ func TestUsersLogin(t *testing.T) {
|
|||||||
token, err := jwt.ParseWithClaims(
|
token, err := jwt.ParseWithClaims(
|
||||||
response["access_token"],
|
response["access_token"],
|
||||||
&userClaims{},
|
&userClaims{},
|
||||||
func(token *jwt.Token) (interface{}, error) {
|
func(token *jwt.Token) (any, error) {
|
||||||
return []byte(jwtSecret), nil
|
return []byte(jwtSecret), nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user