604 lines
15 KiB
Go
604 lines
15 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.umbrella.haus/ae/notatest/internal/data"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
type mockNoteStore struct {
|
|
CreateNoteFunc func(context.Context, uuid.UUID) (data.Note, error)
|
|
DeleteNoteFunc func(context.Context, data.DeleteNoteParams) error
|
|
GetFullNoteFunc func(context.Context, uuid.UUID) (data.GetFullNoteRow, error)
|
|
ListNotesFunc func(context.Context, data.ListNotesParams) ([]data.ListNotesRow, error)
|
|
CreateNoteVersionFunc func(context.Context, data.CreateNoteVersionParams) error
|
|
GetVersionFunc func(context.Context, data.GetVersionParams) (data.GetVersionRow, error)
|
|
GetVersionHistoryFunc func(context.Context, data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error)
|
|
}
|
|
|
|
func (m *mockNoteStore) CreateNote(ctx context.Context, id uuid.UUID) (data.Note, error) {
|
|
return m.CreateNoteFunc(ctx, id)
|
|
}
|
|
|
|
func (m *mockNoteStore) DeleteNote(ctx context.Context, arg data.DeleteNoteParams) error {
|
|
return m.DeleteNoteFunc(ctx, arg)
|
|
}
|
|
|
|
func (m *mockNoteStore) GetFullNote(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
|
return m.GetFullNoteFunc(ctx, id)
|
|
}
|
|
|
|
func (m *mockNoteStore) ListNotes(ctx context.Context, arg data.ListNotesParams) ([]data.ListNotesRow, error) {
|
|
return m.ListNotesFunc(ctx, arg)
|
|
}
|
|
|
|
func (m *mockNoteStore) CreateNoteVersion(ctx context.Context, arg data.CreateNoteVersionParams) error {
|
|
return m.CreateNoteVersionFunc(ctx, arg)
|
|
}
|
|
|
|
func (m *mockNoteStore) GetVersion(ctx context.Context, arg data.GetVersionParams) (data.GetVersionRow, error) {
|
|
return m.GetVersionFunc(ctx, arg)
|
|
}
|
|
|
|
func (m *mockNoteStore) GetVersionHistory(ctx context.Context, arg data.GetVersionHistoryParams) ([]data.GetVersionHistoryRow, error) {
|
|
return m.GetVersionHistoryFunc(ctx, arg)
|
|
}
|
|
|
|
func TestRequireRTMiddleware(t *testing.T) {
|
|
secret := "test-jwt-secret"
|
|
testUserID := uuid.New().String()
|
|
|
|
validRT := generateTestToken(t, secret, "refresh", testUserID, true)
|
|
validAT := generateTestToken(t, secret, "access", testUserID, true)
|
|
expiredRT := generateTestToken(t, secret, "refresh", testUserID, true, func(claims *userClaims) {
|
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
expectedErr string
|
|
statusCode int
|
|
}{
|
|
{
|
|
"no token",
|
|
"",
|
|
"Unauthorized",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"invalid token",
|
|
"invalid",
|
|
"Invalid token",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"expired token",
|
|
expiredRT,
|
|
"Invalid token",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"wrong token type",
|
|
validAT,
|
|
"Invalid token type",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"valid token",
|
|
validRT,
|
|
"",
|
|
http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
rtAuthMiddleware := requireRefreshToken(secret)
|
|
|
|
// Mock request with cookie
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tc.token != "" {
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "refresh_token",
|
|
Value: tc.token,
|
|
})
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
called := false
|
|
|
|
// Mock endpoint that the middleware protects
|
|
handler := rtAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
assert.True(t, ok)
|
|
assert.Equal(t, "refresh", claims.TokenType)
|
|
assert.Equal(t, testUserID, claims.Subject)
|
|
}))
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRequireATMiddleware(t *testing.T) {
|
|
secret := "test-jwt-secret"
|
|
testUserID := uuid.New().String()
|
|
|
|
validRT := generateTestToken(t, secret, "refresh", testUserID, true)
|
|
validAT := generateTestToken(t, secret, "access", testUserID, true)
|
|
expiredAT := generateTestToken(t, secret, "access", testUserID, true, func(claims *userClaims) {
|
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
expectedErr string
|
|
statusCode int
|
|
}{
|
|
{
|
|
"no token",
|
|
"",
|
|
"Unauthorized",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"invalid token",
|
|
"invalid",
|
|
"Invalid token",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"expired token",
|
|
expiredAT,
|
|
"Invalid token",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"wrong token type",
|
|
validRT,
|
|
"Invalid token type",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"valid token",
|
|
validAT,
|
|
"",
|
|
http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
atAuthMiddleware := requireAccessToken(secret)
|
|
|
|
// Mock request
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tc.token != "" {
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
called := false
|
|
|
|
// Mock endpoint that the middleware protects
|
|
handler := atAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
assert.True(t, ok)
|
|
}))
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAdminOnlyMiddleware(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
user *userClaims
|
|
statusCode int
|
|
}{
|
|
{
|
|
"no user",
|
|
nil,
|
|
http.StatusForbidden,
|
|
},
|
|
{
|
|
"non admin user",
|
|
&userClaims{
|
|
Admin: false,
|
|
},
|
|
http.StatusForbidden,
|
|
},
|
|
{
|
|
"admin user",
|
|
&userClaims{
|
|
Admin: true,
|
|
},
|
|
http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Mock request
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tc.user != nil {
|
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
called := false
|
|
|
|
// Mock endpoint that the middleware protects
|
|
handler := adminOnlyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUUIDCtxMiddleware(t *testing.T) {
|
|
testKeyName := "testKey"
|
|
tests := []struct {
|
|
name string
|
|
parameter string
|
|
expectedErr string
|
|
statusCode int
|
|
}{
|
|
{
|
|
"missing uuid",
|
|
"",
|
|
"Invalid resource ID",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"invalid uuid",
|
|
"invalid",
|
|
"Invalid resource ID",
|
|
http.StatusBadRequest,
|
|
},
|
|
{
|
|
"valid uuid",
|
|
uuid.New().String(),
|
|
"",
|
|
http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
uuidCtxMiddleware := uuidCtx(testKeyName)
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
// We need to mock the URL parameter as we don't setup an actual router in this test env.
|
|
rctx := chi.NewRouteContext()
|
|
rctx.URLParams.Add(testKeyName, tc.parameter)
|
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
|
|
|
w := httptest.NewRecorder()
|
|
called := false
|
|
|
|
// Mock endpoint that the middleware protects
|
|
handler := uuidCtxMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
_, ok := r.Context().Value(uuidCtxKey{Name: testKeyName}).(uuid.UUID)
|
|
assert.True(t, ok)
|
|
}))
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNoteCtxMiddleware(t *testing.T) {
|
|
testTitle := "Test title"
|
|
tesTContent := "## Test content\nData 123"
|
|
testVersion := int32(3)
|
|
noteID := uuid.New()
|
|
|
|
ownerUserID := uuid.New()
|
|
testOwnerClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: ownerUserID.String(),
|
|
}}
|
|
|
|
otherUserID := uuid.New()
|
|
testOtherClaims := userClaims{RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: otherUserID.String(),
|
|
}}
|
|
|
|
tests := []struct {
|
|
name string
|
|
resourceID *uuid.UUID
|
|
user *userClaims
|
|
mock func(*mockNoteStore)
|
|
statusCode int
|
|
expectedErr string
|
|
}{
|
|
{
|
|
"no resource id",
|
|
nil,
|
|
nil,
|
|
func(m *mockNoteStore) {},
|
|
http.StatusBadRequest,
|
|
"Resource ID missing",
|
|
},
|
|
{
|
|
"unauthorized",
|
|
¬eID,
|
|
nil,
|
|
func(m *mockNoteStore) {},
|
|
http.StatusUnauthorized,
|
|
"Unauthorized",
|
|
},
|
|
{
|
|
"note not found",
|
|
¬eID,
|
|
&testOwnerClaims,
|
|
func(m *mockNoteStore) {
|
|
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
|
return data.GetFullNoteRow{}, errors.New("not found")
|
|
}
|
|
},
|
|
http.StatusNotFound,
|
|
"Note not found",
|
|
},
|
|
{
|
|
"not owner",
|
|
¬eID,
|
|
&testOtherClaims,
|
|
func(m *mockNoteStore) {
|
|
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
|
assert.Equal(t, noteID, id)
|
|
testTs := time.Now()
|
|
return data.GetFullNoteRow{
|
|
NoteID: id,
|
|
OwnerID: ownerUserID,
|
|
Title: testTitle,
|
|
Content: tesTContent,
|
|
VersionNumber: testVersion,
|
|
VersionCreatedAt: &testTs,
|
|
NoteCreatedAt: &testTs,
|
|
NoteUpdatedAt: &testTs,
|
|
}, nil
|
|
}
|
|
},
|
|
http.StatusForbidden,
|
|
"Forbidden",
|
|
},
|
|
{
|
|
"success",
|
|
¬eID,
|
|
&testOwnerClaims,
|
|
func(m *mockNoteStore) {
|
|
m.GetFullNoteFunc = func(ctx context.Context, id uuid.UUID) (data.GetFullNoteRow, error) {
|
|
assert.Equal(t, noteID, id)
|
|
testTs := time.Now()
|
|
return data.GetFullNoteRow{
|
|
NoteID: id,
|
|
OwnerID: ownerUserID,
|
|
Title: testTitle,
|
|
Content: tesTContent,
|
|
VersionNumber: testVersion,
|
|
VersionCreatedAt: &testTs,
|
|
NoteCreatedAt: &testTs,
|
|
NoteUpdatedAt: &testTs,
|
|
}, nil
|
|
}
|
|
},
|
|
http.StatusOK,
|
|
"",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// UUID ctx. (mock) -> note ctx. (tested here)
|
|
mockStore := &mockNoteStore{}
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
tc.mock(mockStore)
|
|
|
|
// Mock endpoint that the middleware protects (where the attached note data is actually utilized)
|
|
handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fullNote, ok := r.Context().Value(noteCtxKey{}).(*data.GetFullNoteRow)
|
|
assert.True(t, ok)
|
|
assert.Equal(t, noteID, fullNote.NoteID)
|
|
assert.Equal(t, testTitle, fullNote.Title)
|
|
assert.Equal(t, tesTContent, fullNote.Content)
|
|
assert.Equal(t, testVersion, fullNote.VersionNumber)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Request parameters don't need to be mocked, as parsing of them isn't handled
|
|
// by this middleware, and thus that portion shouldn't be tested here.
|
|
|
|
if tc.resourceID != nil {
|
|
req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: noteUUIDCtxParameter}, *tc.resourceID))
|
|
}
|
|
|
|
if tc.user != nil {
|
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestVersionCtxMiddleware(t *testing.T) {
|
|
testTitle := "Test title"
|
|
tesTContent := "## Test content\nData 123"
|
|
testVersion := int32(3)
|
|
versionID := uuid.New()
|
|
|
|
noteID := uuid.New()
|
|
testNote := data.GetFullNoteRow{
|
|
NoteID: noteID,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
resourceID *uuid.UUID
|
|
note *data.GetFullNoteRow
|
|
mock func(*mockNoteStore)
|
|
statusCode int
|
|
expectedErr string
|
|
}{
|
|
{
|
|
"no note",
|
|
nil,
|
|
nil,
|
|
func(m *mockNoteStore) {},
|
|
http.StatusNotFound,
|
|
"Note not found",
|
|
},
|
|
{
|
|
"no resource id",
|
|
nil,
|
|
&testNote,
|
|
func(m *mockNoteStore) {},
|
|
http.StatusBadRequest,
|
|
"Resource ID missing",
|
|
},
|
|
{
|
|
"version not found",
|
|
&versionID,
|
|
&testNote,
|
|
func(m *mockNoteStore) {
|
|
m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
|
|
return data.GetVersionRow{}, errors.New("not found")
|
|
}
|
|
},
|
|
http.StatusNotFound,
|
|
"Version not found",
|
|
},
|
|
{
|
|
"success",
|
|
&versionID,
|
|
&testNote,
|
|
func(m *mockNoteStore) {
|
|
m.GetVersionFunc = func(ctx context.Context, gvp data.GetVersionParams) (data.GetVersionRow, error) {
|
|
assert.Equal(t, versionID, gvp.ID)
|
|
assert.Equal(t, noteID, gvp.NoteID)
|
|
testTs := time.Now()
|
|
return data.GetVersionRow{
|
|
VersionID: gvp.ID,
|
|
Title: testTitle,
|
|
Content: tesTContent,
|
|
VersionNumber: testVersion,
|
|
CreatedAt: &testTs,
|
|
}, nil
|
|
}
|
|
},
|
|
http.StatusOK,
|
|
"",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// note ctx. (mock) -> UUID ctx. (mock) -> version ctx. (tested here)
|
|
mockStore := &mockNoteStore{}
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
tc.mock(mockStore)
|
|
|
|
// Mock endpoint that the middleware protects (where the attached note data is actually utilized)
|
|
handler := versionCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fullVersion, ok := r.Context().Value(versionCtxKey{}).(*data.GetVersionRow)
|
|
assert.True(t, ok)
|
|
assert.Equal(t, versionID, fullVersion.VersionID)
|
|
assert.Equal(t, testTitle, fullVersion.Title)
|
|
assert.Equal(t, tesTContent, fullVersion.Content)
|
|
assert.Equal(t, testVersion, fullVersion.VersionNumber)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Request parameters don't need to be mocked, as parsing of them isn't handled
|
|
// by this middleware, and thus that portion shouldn't be tested here.
|
|
|
|
if tc.note != nil {
|
|
req = req.WithContext(context.WithValue(req.Context(), noteCtxKey{}, tc.note))
|
|
}
|
|
|
|
if tc.resourceID != nil {
|
|
req = req.WithContext(context.WithValue(req.Context(), uuidCtxKey{Name: versionUUIDCtxParameter}, *tc.resourceID))
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
|
t.Helper()
|
|
|
|
claims := &userClaims{
|
|
Admin: isAdmin,
|
|
TokenType: tokenType,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: userID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
},
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(claims)
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signedToken, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate test token: %v", err)
|
|
}
|
|
|
|
return signedToken
|
|
}
|