notatest/server/pkg/service/middleware_test.go

387 lines
8.2 KiB
Go

package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.umbrella.haus/ae/notatest/pkg/data"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestAuthMiddleware(t *testing.T) {
secret := "test-secret"
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
})
tests := []struct {
name string
token string
expectedErr string
statusCode int
}{
{
"no token",
"",
"Unauthorized",
http.StatusUnauthorized,
},
{
"invalid token",
"invalid",
"Invalid token",
http.StatusUnauthorized,
},
{
"expired token",
expiredToken,
"Invalid token",
http.StatusUnauthorized,
},
{
"wrong type",
generateTestToken(
t,
secret,
"refresh",
uuid.New().String(),
true,
),
"Invalid token type",
http.StatusUnauthorized,
},
{
"valid token",
validToken,
"",
http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mw := requireAccessToken(secret)
req := httptest.NewRequest("GET", "/", nil)
if tc.token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
}
w := httptest.NewRecorder()
called := false
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
assert.True(t, ok)
}))
handler.ServeHTTP(w, req)
assert.Equal(t, tc.statusCode, w.Code)
if tc.expectedErr != "" {
assert.Contains(t, w.Body.String(), tc.expectedErr)
}
assert.Equal(t, tc.statusCode == http.StatusOK, called)
})
}
}
// TestAdminOnlyMiddleware checks that adminOnlyMiddleware passes a request
// through only when the context carries *userClaims with Admin set. A missing
// user or a non-admin user must get 403 without the next handler running.
func TestAdminOnlyMiddleware(t *testing.T) {
	type testCase struct {
		name       string
		user       *userClaims
		wantStatus int
	}
	cases := []testCase{
		{name: "no user", user: nil, wantStatus: http.StatusForbidden},
		{name: "non admin user", user: &userClaims{Admin: false}, wantStatus: http.StatusForbidden},
		{name: "admin user", user: &userClaims{Admin: true}, wantStatus: http.StatusOK},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/", nil)
			if c.user != nil {
				withUser := context.WithValue(req.Context(), userCtxKey{}, c.user)
				req = req.WithContext(withUser)
			}

			reached := false
			protected := adminOnlyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				reached = true
				w.WriteHeader(http.StatusOK)
			}))

			rec := httptest.NewRecorder()
			protected.ServeHTTP(rec, req)

			assert.Equal(t, c.wantStatus, rec.Code)
			assert.Equal(t, c.wantStatus == http.StatusOK, reached)
		})
	}
}
// TestOwnerOnlyMiddleware checks that ownerOnlyMiddleware lets a request
// through only when the Subject of the *userClaims in the request context
// matches the {id} URL parameter. The route is mounted on a real chi router,
// so the middleware reads the URL parameter the same way it would in
// production. For every case it asserts the response status and that the
// protected handler runs if and only if access is granted.
func TestOwnerOnlyMiddleware(t *testing.T) {
	userID := uuid.New().String()

	tests := []struct {
		name       string
		user       *userClaims
		urlID      string
		statusCode int
	}{
		{
			name:       "no user",
			user:       nil,
			urlID:      userID,
			statusCode: http.StatusForbidden,
		},
		{
			name: "different ID",
			user: &userClaims{
				RegisteredClaims: jwt.RegisteredClaims{
					Subject: uuid.New().String(),
				},
			},
			urlID:      userID,
			statusCode: http.StatusForbidden,
		},
		{
			name: "matching ID",
			user: &userClaims{
				RegisteredClaims: jwt.RegisteredClaims{
					Subject: userID,
				},
			},
			urlID:      userID,
			statusCode: http.StatusOK,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// injectUser stands in for the auth middleware: it stores the
			// case's claims (if any) in the request context under userCtxKey{}.
			injectUser := func(next http.Handler) http.Handler {
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					ctx := r.Context()
					if tc.user != nil {
						ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
					}
					next.ServeHTTP(w, r.WithContext(ctx))
				})
			}

			called := false
			router := chi.NewRouter()
			router.With(injectUser, ownerOnlyMiddleware).Get("/{id}", func(w http.ResponseWriter, _ *http.Request) {
				called = true
				w.WriteHeader(http.StatusOK)
			})

			req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			// The protected handler must run only when access is granted.
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
		})
	}
}
// TestUserCtxMiddleware checks that userCtx looks up the {id} URL parameter
// through the store's GetUserByID and exposes the resulting data.User in the
// request context. A non-UUID ID and an ID unknown to the store must both
// yield 404.
func TestUserCtxMiddleware(t *testing.T) {
	knownID := uuid.New()
	store := &mockUserStore{
		GetUserByIDFunc: func(_ context.Context, id uuid.UUID) (data.User, error) {
			if id != knownID {
				return data.User{}, errors.New("not found")
			}
			return data.User{ID: knownID}, nil
		},
	}

	for _, tc := range []struct {
		name       string
		urlID      string
		wantStatus int
	}{
		{name: "valid ID", urlID: knownID.String(), wantStatus: http.StatusOK},
		{name: "invalid ID", urlID: "invalid", wantStatus: http.StatusNotFound},
		{name: "non existent ID", urlID: uuid.New().String(), wantStatus: http.StatusNotFound},
	} {
		t.Run(tc.name, func(t *testing.T) {
			router := chi.NewRouter()
			router.With(userCtx(store)).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
				got, ok := r.Context().Value(userCtxKey{}).(data.User)
				assert.True(t, ok)
				assert.Equal(t, knownID, got.ID)
				w.WriteHeader(http.StatusOK)
			})

			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil))
			assert.Equal(t, tc.wantStatus, rec.Code)
		})
	}
}
// TestNoteCtxMiddleware checks the behaviour of noteCtx:
//   - it parses the {id} URL parameter and requires *userClaims in the
//     request context;
//   - it loads the note through GetNote, scoped to the claims' Subject;
//   - only on success does it call the next handler, with the data.Note
//     stored under noteCtxKey{}.
//
// Error cases must return the expected status without reaching the next
// handler.
func TestNoteCtxMiddleware(t *testing.T) {
	userID := uuid.New()
	noteID := uuid.New()

	// ownerClaims builds a fresh *userClaims whose Subject is userID.
	ownerClaims := func() *userClaims {
		return &userClaims{RegisteredClaims: jwt.RegisteredClaims{
			Subject: userID.String(),
		}}
	}

	tests := []struct {
		name   string
		noteID string
		user   *userClaims
		// mock configures the store for the case. It receives the running
		// subtest's *testing.T so assertions made inside the stub are
		// reported against that subtest rather than the parent test.
		mock       func(t *testing.T, m *mockNoteStore)
		statusCode int
	}{
		{
			name:       "invalid note ID",
			noteID:     "invalid",
			user:       ownerClaims(),
			mock:       func(*testing.T, *mockNoteStore) {},
			statusCode: http.StatusNotFound,
		},
		{
			name:       "unauthorized user",
			noteID:     noteID.String(),
			user:       nil,
			mock:       func(*testing.T, *mockNoteStore) {},
			statusCode: http.StatusUnauthorized,
		},
		{
			name:   "note not found",
			noteID: noteID.String(),
			user:   ownerClaims(),
			mock: func(_ *testing.T, m *mockNoteStore) {
				m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
					return data.Note{}, errors.New("not found")
				}
			},
			statusCode: http.StatusNotFound,
		},
		{
			name:   "success",
			noteID: noteID.String(),
			user:   ownerClaims(),
			mock: func(t *testing.T, m *mockNoteStore) {
				m.GetNoteFunc = func(ctx context.Context, arg data.GetNoteParams) (data.Note, error) {
					assert.Equal(t, noteID, arg.ID)
					assert.Equal(t, userID, arg.UserID)
					return data.Note{ID: noteID}, nil
				}
			},
			statusCode: http.StatusOK,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			mockStore := &mockNoteStore{}
			tc.mock(t, mockStore)

			called := false
			var gotNote data.Note
			handler := noteCtx(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				note, ok := r.Context().Value(noteCtxKey{}).(data.Note)
				assert.True(t, ok)
				gotNote = note
				w.WriteHeader(http.StatusOK)
			}))

			req := httptest.NewRequest("GET", fmt.Sprintf("/notes/%s", tc.noteID), nil)
			// noteCtx is invoked directly rather than through a router, so
			// install a chi route context carrying the {id} URL parameter.
			rctx := chi.NewRouteContext()
			rctx.URLParams.Add("id", tc.noteID)
			ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx)
			if tc.user != nil {
				ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
			}
			req = req.WithContext(ctx)

			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			// The next handler must only be reached when the note was loaded.
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
			if called {
				// The note handed downstream must be the one the store returned.
				assert.Equal(t, noteID, gotNote.ID)
			}
		})
	}
}
// generateTestToken signs a userClaims token with HS256 using secret and
// returns the signed string. The claims start with the given admin flag,
// token type and subject, and expire one hour from now. Each opt is then
// applied in order, so callers can override any field (e.g. ExpiresAt).
// The test is aborted via t.Fatalf if signing fails.
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
	t.Helper()

	expiry := time.Now().Add(time.Hour)
	claims := &userClaims{
		Admin:     isAdmin,
		TokenType: tokenType,
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   userID,
			ExpiresAt: jwt.NewNumericDate(expiry),
		},
	}
	for _, apply := range opts {
		apply(claims)
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
	if err != nil {
		t.Fatalf("Failed to generate test token: %v", err)
	}
	return signed
}