qnote/server/internal/service/middleware_test.go

605 lines
15 KiB
Go

package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/internal/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
// mockNoteStore is a hand-rolled test double for the note store used by the
// middleware under test. Every method simply forwards to the function field of
// the same name, so each test case only wires up the calls it expects. Calling a
// method whose function field was left nil panics (nil func call), which makes
// unexpected store access visible immediately.
type mockNoteStore struct {
	CreateNoteFunc        func(context.Context, uuid.UUID) (data.Note, error)
	DeleteNoteFunc        func(context.Context, data.DeleteNoteParams) error
	GetFullNoteFunc       func(context.Context, uuid.UUID) (data.GetFullNoteRow, error)
	ListNotesFunc         func(context.Context, data.ListNotesParams) ([]data.ListNotesRow, error)
	CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) error
	GetVersionFunc        func(context.Context, data.GetVersionParams) (data.GetVersionRow, error)
	GetVersionHistoryFunc func(context.Context, data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error)
}

// CreateNote forwards to CreateNoteFunc.
func (s *mockNoteStore) CreateNote(ctx context.Context, noteID uuid.UUID) (data.Note, error) {
	fn := s.CreateNoteFunc
	return fn(ctx, noteID)
}

// DeleteNote forwards to DeleteNoteFunc.
func (s *mockNoteStore) DeleteNote(ctx context.Context, params data.DeleteNoteParams) error {
	fn := s.DeleteNoteFunc
	return fn(ctx, params)
}

// GetFullNote forwards to GetFullNoteFunc.
func (s *mockNoteStore) GetFullNote(ctx context.Context, noteID uuid.UUID) (data.GetFullNoteRow, error) {
	fn := s.GetFullNoteFunc
	return fn(ctx, noteID)
}

// ListNotes forwards to ListNotesFunc.
func (s *mockNoteStore) ListNotes(ctx context.Context, params data.ListNotesParams) ([]data.ListNotesRow, error) {
	fn := s.ListNotesFunc
	return fn(ctx, params)
}

// CreateNoteVersion forwards to CreateNoteVersionFunc.
func (s *mockNoteStore) CreateNoteVersion(ctx context.Context, params data.CreateNoteVersionParams) error {
	fn := s.CreateNoteVersionFunc
	return fn(ctx, params)
}

// GetVersion forwards to GetVersionFunc.
func (s *mockNoteStore) GetVersion(ctx context.Context, params data.GetVersionParams) (data.GetVersionRow, error) {
	fn := s.GetVersionFunc
	return fn(ctx, params)
}

// GetVersionHistory forwards to GetVersionHistoryFunc.
func (s *mockNoteStore) GetVersionHistory(ctx context.Context, params data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) {
	fn := s.GetVersionHistoryFunc
	return fn(ctx, params)
}
// TestRequireRTMiddleware checks that requireRefreshToken lets a request reach
// the protected handler only when the "notatest.refresh_token" cookie holds a
// valid, unexpired refresh token signed with the configured secret. For the
// accepted case it also verifies that the decoded claims are stored in the
// request context under userCtxKey{}.
func TestRequireRTMiddleware(t *testing.T) {
	secret := "test-jwt-secret"
	testUserID := uuid.New().String()
	validRT := generateTestToken(t, secret, "refresh", testUserID, true)
	validAT := generateTestToken(t, secret, "access", testUserID, true)
	expiredRT := generateTestToken(t, secret, "refresh", testUserID, true, func(claims *userClaims) {
		claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
	})

	tests := []struct {
		name        string
		token       string
		expectedErr string
		statusCode  int
	}{
		{"no token", "", "Unauthorized", http.StatusUnauthorized},
		{"invalid token", "invalid", "Invalid token", http.StatusUnauthorized},
		{"expired token", expiredRT, "Invalid token", http.StatusUnauthorized},
		{"wrong token type", validAT, "Invalid token type", http.StatusUnauthorized},
		{"valid token", validRT, "", http.StatusOK},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			rtAuthMiddleware := requireRefreshToken(secret)

			// Mock request with cookie
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.token != "" {
				req.AddCookie(&http.Cookie{
					Name:     "notatest.refresh_token",
					Value:    tc.token,
					HttpOnly: true,
				})
			}
			w := httptest.NewRecorder()

			called := false
			// Mock endpoint that the middleware protects
			handler := rtAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
				// assert.* does not stop the test, so bail out before dereferencing:
				// a missing or nil claims value must be a test failure, not a panic.
				if !assert.True(t, ok) || !assert.NotNil(t, claims) {
					return
				}
				assert.Equal(t, "refresh", claims.TokenType)
				assert.Equal(t, testUserID, claims.Subject)
			}))

			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			// The protected handler must run exactly when the request is accepted.
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
			if tc.expectedErr != "" {
				assert.Contains(t, w.Body.String(), tc.expectedErr)
			}
		})
	}
}
// TestRequireATMiddleware checks that requireAccessToken lets a request reach
// the protected handler only when its "Authorization: Bearer <token>" header
// carries a valid, unexpired access token signed with the configured secret.
// For the accepted case it also verifies that the decoded claims (token type
// and subject) are stored in the request context under userCtxKey{}.
func TestRequireATMiddleware(t *testing.T) {
	secret := "test-jwt-secret"
	testUserID := uuid.New().String()
	validRT := generateTestToken(t, secret, "refresh", testUserID, true)
	validAT := generateTestToken(t, secret, "access", testUserID, true)
	expiredAT := generateTestToken(t, secret, "access", testUserID, true, func(claims *userClaims) {
		claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
	})

	tests := []struct {
		name        string
		token       string
		expectedErr string
		statusCode  int
	}{
		{"no token", "", "Unauthorized", http.StatusUnauthorized},
		{"invalid token", "invalid", "Invalid token", http.StatusUnauthorized},
		{"expired token", expiredAT, "Invalid token", http.StatusUnauthorized},
		{"wrong token type", validRT, "Invalid token type", http.StatusUnauthorized},
		{"valid token", validAT, "", http.StatusOK},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			atAuthMiddleware := requireAccessToken(secret)

			// Mock request
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.token != "" {
				req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
			}
			w := httptest.NewRecorder()

			called := false
			// Mock endpoint that the middleware protects
			handler := atAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
				// Bail out before dereferencing so a missing/nil value fails cleanly.
				if !assert.True(t, ok) || !assert.NotNil(t, claims) {
					return
				}
				// Mirror TestRequireRTMiddleware: the stored claims must be the ones
				// decoded from the presented access token.
				assert.Equal(t, "access", claims.TokenType)
				assert.Equal(t, testUserID, claims.Subject)
			}))

			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			// The protected handler must run exactly when the request is accepted.
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
			if tc.expectedErr != "" {
				assert.Contains(t, w.Body.String(), tc.expectedErr)
			}
		})
	}
}
// TestAdminOnlyMiddleware verifies that adminOnlyMiddleware answers 403
// Forbidden unless the request context carries user claims with Admin set,
// and that the wrapped handler is reached only in the admin case.
func TestAdminOnlyMiddleware(t *testing.T) {
	cases := []struct {
		name       string
		user       *userClaims
		statusCode int
	}{
		{name: "no user", user: nil, statusCode: http.StatusForbidden},
		{name: "non admin user", user: &userClaims{Admin: false}, statusCode: http.StatusForbidden},
		{name: "admin user", user: &userClaims{Admin: true}, statusCode: http.StatusOK},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			reached := false
			// Stand-in for the protected endpoint.
			protected := adminOnlyMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
				reached = true
				rw.WriteHeader(http.StatusOK)
			}))

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if c.user != nil {
				// Simulate an upstream auth middleware having stored the claims.
				req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, c.user))
			}

			rec := httptest.NewRecorder()
			protected.ServeHTTP(rec, req)

			wantReached := c.statusCode == http.StatusOK
			assert.Equal(t, c.statusCode, rec.Code)
			assert.Equal(t, wantReached, reached)
		})
	}
}
// TestUUIDCtxMiddleware checks that uuidCtx(name) rejects a missing or
// malformed chi URL parameter with 400 "Invalid resource ID". For a valid
// parameter it checks that the parsed UUID is stored in the request context
// under uuidCtxKey{Name: name} and that the stored value matches the parameter.
func TestUUIDCtxMiddleware(t *testing.T) {
	testKeyName := "testKey"

	tests := []struct {
		name        string
		parameter   string
		expectedErr string
		statusCode  int
	}{
		{"missing uuid", "", "Invalid resource ID", http.StatusBadRequest},
		{"invalid uuid", "invalid", "Invalid resource ID", http.StatusBadRequest},
		{"valid uuid", uuid.New().String(), "", http.StatusOK},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			uuidCtxMiddleware := uuidCtx(testKeyName)

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			// We need to mock the URL parameter as we don't set up an actual router in this test env.
			rctx := chi.NewRouteContext()
			rctx.URLParams.Add(testKeyName, tc.parameter)
			req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
			w := httptest.NewRecorder()

			called := false
			// Mock endpoint that the middleware protects
			handler := uuidCtxMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				id, ok := r.Context().Value(uuidCtxKey{Name: testKeyName}).(uuid.UUID)
				if assert.True(t, ok) {
					// The stored value must be the parsed URL parameter, not merely some UUID.
					assert.Equal(t, tc.parameter, id.String())
				}
			}))

			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
			if tc.expectedErr != "" {
				assert.Contains(t, w.Body.String(), tc.expectedErr)
			}
		})
	}
}
// TestNoteCtxMiddleware exercises noteCtx using the expected outcome of each
// case below:
//   - missing note ID in the context: 400
//   - missing user claims: 401
//   - store lookup error: 404
//   - note owned by someone else: 403
//   - otherwise the note is attached to the request context under
//     noteCtxKey{} and the protected handler runs.
func TestNoteCtxMiddleware(t *testing.T) {
	testTitle := "Test title"
	testContent := "## Test content\nData 123"
	testVersion := int32(3)
	noteID := uuid.New()

	ownerUserID := uuid.New()
	testOwnerClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
		Subject: ownerUserID.String(),
	}}
	otherUserID := uuid.New()
	testOtherClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
		Subject: otherUserID.String(),
	}}

	// stubOwnedNote makes GetFullNote return the fixture note (owned by
	// ownerUserID) after checking that the middleware asked for noteID.
	stubOwnedNote := func(t *testing.T, m *mockNoteStore) {
		m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
			assert.Equal(t, noteID, id)
			testTs := time.Now()
			return data.GetFullNoteRow{
				NoteID:           id,
				OwnerID:          ownerUserID,
				Title:            testTitle,
				Content:          testContent,
				VersionNumber:    testVersion,
				VersionCreatedAt: &testTs,
				NoteCreatedAt:    &testTs,
				NoteUpdatedAt:    &testTs,
			}, nil
		}
	}

	tests := []struct {
		name       string
		resourceID *uuid.UUID
		user       *userClaims
		// mock configures the store for the case. It receives the subtest's t
		// so assertion failures inside stubs are reported against that subtest.
		mock        func(*testing.T, *mockNoteStore)
		statusCode  int
		expectedErr string
	}{
		{
			"no resource id",
			nil,
			nil,
			func(*testing.T, *mockNoteStore) {},
			http.StatusBadRequest,
			"Resource ID missing",
		},
		{
			"unauthorized",
			&noteID,
			nil,
			func(*testing.T, *mockNoteStore) {},
			http.StatusUnauthorized,
			"Unauthorized",
		},
		{
			"note not found",
			&noteID,
			&testOwnerClaims,
			func(t *testing.T, m *mockNoteStore) {
				m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
					assert.Equal(t, noteID, id)
					return data.GetFullNoteRow{}, errors.New("not found")
				}
			},
			http.StatusNotFound,
			"Note not found",
		},
		{
			"not owner",
			&noteID,
			&testOtherClaims,
			stubOwnedNote,
			http.StatusForbidden,
			"Forbidden",
		},
		{
			"success",
			&noteID,
			&testOwnerClaims,
			stubOwnedNote,
			http.StatusOK,
			"",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// UUID ctx. (mock) -> note ctx. (tested here)
			mockStore := &mockNoteStore{}
			tc.mock(t, mockStore)

			called := false
			// Mock endpoint that the middleware protects (where the attached note data is actually utilized)
			handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
				// Bail out before dereferencing so a missing/nil note fails cleanly.
				if !assert.True(t, ok) || !assert.NotNil(t, fullNote) {
					return
				}
				assert.Equal(t, noteID, fullNote.NoteID)
				assert.Equal(t, testTitle, fullNote.Title)
				assert.Equal(t, testContent, fullNote.Content)
				assert.Equal(t, testVersion, fullNote.VersionNumber)
				w.WriteHeader(http.StatusOK)
			}))

			// Request parameters don't need to be mocked, as parsing of them isn't handled
			// by this middleware, and thus that portion shouldn't be tested here.
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.resourceID != nil {
				req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: noteUUIDCtxParameter}, *tc.resourceID))
			}
			if tc.user != nil {
				req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
			}

			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			// The protected handler must run only when the note is attached.
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
			if tc.expectedErr != "" {
				assert.Contains(t, w.Body.String(), tc.expectedErr)
			}
		})
	}
}
// TestVersionCtxMiddleware exercises versionCtx using the expected outcome of
// each case below:
//   - no note in the context: 404
//   - missing version ID in the context: 400
//   - store lookup error: 404
//   - otherwise the version is attached to the request context under
//     versionCtxKey{} and the protected handler runs.
func TestVersionCtxMiddleware(t *testing.T) {
	testTitle := "Test title"
	testContent := "## Test content\nData 123"
	testVersion := int32(3)
	versionID := uuid.New()
	noteID := uuid.New()
	testNote := data.GetFullNoteRow{
		NoteID: noteID,
	}

	tests := []struct {
		name       string
		resourceID *uuid.UUID
		note       *data.GetFullNoteRow
		// mock configures the store for the case. It receives the subtest's t
		// so assertion failures inside stubs are reported against that subtest.
		mock        func(*testing.T, *mockNoteStore)
		statusCode  int
		expectedErr string
	}{
		{
			"no note",
			nil,
			nil,
			func(*testing.T, *mockNoteStore) {},
			http.StatusNotFound,
			"Note not found",
		},
		{
			"no resource id",
			nil,
			&testNote,
			func(*testing.T, *mockNoteStore) {},
			http.StatusBadRequest,
			"Resource ID missing",
		},
		{
			"version not found",
			&versionID,
			&testNote,
			func(t *testing.T, m *mockNoteStore) {
				m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
					assert.Equal(t, versionID, gvp.ID)
					assert.Equal(t, noteID, gvp.NoteID)
					return data.GetVersionRow{}, errors.New("not found")
				}
			},
			http.StatusNotFound,
			"Version not found",
		},
		{
			"success",
			&versionID,
			&testNote,
			func(t *testing.T, m *mockNoteStore) {
				m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
					assert.Equal(t, versionID, gvp.ID)
					assert.Equal(t, noteID, gvp.NoteID)
					testTs := time.Now()
					return data.GetVersionRow{
						VersionID:     gvp.ID,
						Title:         testTitle,
						Content:       testContent,
						VersionNumber: testVersion,
						CreatedAt:     &testTs,
					}, nil
				}
			},
			http.StatusOK,
			"",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// note ctx. (mock) -> UUID ctx. (mock) -> version ctx. (tested here)
			mockStore := &mockNoteStore{}
			tc.mock(t, mockStore)

			called := false
			// Mock endpoint that the middleware protects (where the attached version data is actually utilized)
			handler := versionCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				fullVersion, ok := r.Context().Value(versionCtxKey{}).(*data.GetVersionRow)
				// Bail out before dereferencing so a missing/nil version fails cleanly.
				if !assert.True(t, ok) || !assert.NotNil(t, fullVersion) {
					return
				}
				assert.Equal(t, versionID, fullVersion.VersionID)
				assert.Equal(t, testTitle, fullVersion.Title)
				assert.Equal(t, testContent, fullVersion.Content)
				assert.Equal(t, testVersion, fullVersion.VersionNumber)
				w.WriteHeader(http.StatusOK)
			}))

			// Request parameters don't need to be mocked, as parsing of them isn't handled
			// by this middleware, and thus that portion shouldn't be tested here.
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.note != nil {
				req = req.WithContext(context.WithValue(req.Context(), noteCtxKey{}, tc.note))
			}
			if tc.resourceID != nil {
				req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: versionUUIDCtxParameter}, *tc.resourceID))
			}

			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			// The protected handler must run only when the version is attached.
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
			if tc.expectedErr != "" {
				assert.Contains(t, w.Body.String(), tc.expectedErr)
			}
		})
	}
}
// generateTestToken returns an HS256-signed JWT for use in tests.
//
// The claims start out with the given admin flag, token type and subject, and
// expire one hour from now. Each function in opts is then applied in order and
// may mutate the claims (for example, to back-date ExpiresAt) before signing.
// If signing fails, the calling test is aborted via t.Fatalf.
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
	t.Helper()

	expiry := jwt.NewNumericDate(time.Now().Add(time.Hour))
	claims := &userClaims{
		Admin:     isAdmin,
		TokenType: tokenType,
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   userID,
			ExpiresAt: expiry,
		},
	}
	for _, mutate := range opts {
		mutate(claims)
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
	if err != nil {
		t.Fatalf("Failed to generate test token: %v", err)
	}
	return signed
}