test: middleware unit tests
This commit is contained in:
parent
c5a56c8479
commit
993b576d0d
283
server/pkg/service/middleware_test.go
Normal file
283
server/pkg/service/middleware_test.go
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthMiddleware(t *testing.T) {
|
||||||
|
secret := "test-secret"
|
||||||
|
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
|
||||||
|
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
|
||||||
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
expectedErr string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no token",
|
||||||
|
"",
|
||||||
|
"Unauthorized",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid token",
|
||||||
|
"invalid",
|
||||||
|
"Invalid token",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"expired token",
|
||||||
|
expiredToken,
|
||||||
|
"Invalid token",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong type",
|
||||||
|
generateTestToken(
|
||||||
|
t,
|
||||||
|
secret,
|
||||||
|
"refresh",
|
||||||
|
uuid.New().String(),
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
"Invalid token type",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"valid token",
|
||||||
|
validToken,
|
||||||
|
"",
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mw := requireAccessToken(secret)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
if tc.token != "" {
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
called := false
|
||||||
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}))
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
if tc.expectedErr != "" {
|
||||||
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminOnlyMiddleware(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
user *userClaims
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no user",
|
||||||
|
nil,
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"non admin user",
|
||||||
|
&userClaims{
|
||||||
|
Admin: false,
|
||||||
|
},
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"admin user",
|
||||||
|
&userClaims{
|
||||||
|
Admin: true,
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mw := adminOnlyMiddleware
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
if tc.user != nil {
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
called := false
|
||||||
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOwnerOnlyMiddleware(t *testing.T) {
|
||||||
|
userID := uuid.New().String()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
user *userClaims
|
||||||
|
urlID string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"no user",
|
||||||
|
nil,
|
||||||
|
userID,
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"different ID",
|
||||||
|
&userClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: uuid.New().String(),
|
||||||
|
}},
|
||||||
|
userID,
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"matching ID",
|
||||||
|
&userClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
userID,
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
r.With(
|
||||||
|
// Add user with the given claims to request's context
|
||||||
|
func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var ctx context.Context = r.Context()
|
||||||
|
if tc.user != nil {
|
||||||
|
ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
ownerOnlyMiddleware,
|
||||||
|
).Get("/{id}", handlerChain)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if tc.urlID == "invalid" {
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserCtxMiddleware(t *testing.T) {
|
||||||
|
validUserID := uuid.New()
|
||||||
|
invalidUserID := "invalid"
|
||||||
|
|
||||||
|
mockStore := &mockUserStore{
|
||||||
|
GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
|
||||||
|
if id == validUserID {
|
||||||
|
return data.User{ID: validUserID}, nil
|
||||||
|
}
|
||||||
|
return data.User{}, errors.New("not found")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
urlID string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"valid ID", validUserID.String(), http.StatusOK},
|
||||||
|
{"invalid ID", invalidUserID, http.StatusNotFound},
|
||||||
|
{"non existent ID", uuid.New().String(), http.StatusNotFound},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mw := userCtx(mockStore)
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, validUserID, user.ID)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.statusCode, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
claims := &userClaims{
|
||||||
|
Admin: isAdmin,
|
||||||
|
TokenType: tokenType,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
signedToken, err := token.SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate test token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signedToken
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user