test: middleware unit tests
This commit is contained in:
parent
c5a56c8479
commit
993b576d0d
283
server/pkg/service/middleware_test.go
Normal file
283
server/pkg/service/middleware_test.go
Normal file
@ -0,0 +1,283 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
secret := "test-secret"
|
||||
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
|
||||
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
|
||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectedErr string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"no token",
|
||||
"",
|
||||
"Unauthorized",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"invalid token",
|
||||
"invalid",
|
||||
"Invalid token",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"expired token",
|
||||
expiredToken,
|
||||
"Invalid token",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"wrong type",
|
||||
generateTestToken(
|
||||
t,
|
||||
secret,
|
||||
"refresh",
|
||||
uuid.New().String(),
|
||||
true,
|
||||
),
|
||||
"Invalid token type",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"valid token",
|
||||
validToken,
|
||||
"",
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mw := requireAccessToken(secret)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.token != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
called := false
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
assert.True(t, ok)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
if tc.expectedErr != "" {
|
||||
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||
}
|
||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminOnlyMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user *userClaims
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"no user",
|
||||
nil,
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"non admin user",
|
||||
&userClaims{
|
||||
Admin: false,
|
||||
},
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"admin user",
|
||||
&userClaims{
|
||||
Admin: true,
|
||||
},
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mw := adminOnlyMiddleware
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.user != nil {
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
called := false
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOwnerOnlyMiddleware(t *testing.T) {
|
||||
userID := uuid.New().String()
|
||||
tests := []struct {
|
||||
name string
|
||||
user *userClaims
|
||||
urlID string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"no user",
|
||||
nil,
|
||||
userID,
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"different ID",
|
||||
&userClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: uuid.New().String(),
|
||||
}},
|
||||
userID,
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"matching ID",
|
||||
&userClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
},
|
||||
},
|
||||
userID,
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r.With(
|
||||
// Add user with the given claims to request's context
|
||||
func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var ctx context.Context = r.Context()
|
||||
if tc.user != nil {
|
||||
ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
},
|
||||
ownerOnlyMiddleware,
|
||||
).Get("/{id}", handlerChain)
|
||||
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if tc.urlID == "invalid" {
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
} else {
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserCtxMiddleware(t *testing.T) {
|
||||
validUserID := uuid.New()
|
||||
invalidUserID := "invalid"
|
||||
|
||||
mockStore := &mockUserStore{
|
||||
GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
|
||||
if id == validUserID {
|
||||
return data.User{ID: validUserID}, nil
|
||||
}
|
||||
return data.User{}, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlID string
|
||||
statusCode int
|
||||
}{
|
||||
{"valid ID", validUserID.String(), http.StatusOK},
|
||||
{"invalid ID", invalidUserID, http.StatusNotFound},
|
||||
{"non existent ID", uuid.New().String(), http.StatusNotFound},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mw := userCtx(mockStore)
|
||||
r := chi.NewRouter()
|
||||
r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, validUserID, user.ID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
||||
t.Helper()
|
||||
|
||||
claims := &userClaims{
|
||||
Admin: isAdmin,
|
||||
TokenType: tokenType,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(claims)
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test token: %v", err)
|
||||
}
|
||||
|
||||
return signedToken
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user