284 lines
6.0 KiB
Go
284 lines
6.0 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestAuthMiddleware(t *testing.T) {
|
|
secret := "test-secret"
|
|
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
|
|
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
|
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
expectedErr string
|
|
statusCode int
|
|
}{
|
|
{
|
|
"no token",
|
|
"",
|
|
"Unauthorized",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"invalid token",
|
|
"invalid",
|
|
"Invalid token",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"expired token",
|
|
expiredToken,
|
|
"Invalid token",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"wrong type",
|
|
generateTestToken(
|
|
t,
|
|
secret,
|
|
"refresh",
|
|
uuid.New().String(),
|
|
true,
|
|
),
|
|
"Invalid token type",
|
|
http.StatusUnauthorized,
|
|
},
|
|
{
|
|
"valid token",
|
|
validToken,
|
|
"",
|
|
http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
mw := requireAccessToken(secret)
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tc.token != "" {
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
called := false
|
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
assert.True(t, ok)
|
|
}))
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
}
|
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAdminOnlyMiddleware(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
user *userClaims
|
|
statusCode int
|
|
}{
|
|
{
|
|
"no user",
|
|
nil,
|
|
http.StatusForbidden,
|
|
},
|
|
{
|
|
"non admin user",
|
|
&userClaims{
|
|
Admin: false,
|
|
},
|
|
http.StatusForbidden,
|
|
},
|
|
{
|
|
"admin user",
|
|
&userClaims{
|
|
Admin: true,
|
|
},
|
|
http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
mw := adminOnlyMiddleware
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tc.user != nil {
|
|
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
called := false
|
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOwnerOnlyMiddleware(t *testing.T) {
|
|
userID := uuid.New().String()
|
|
tests := []struct {
|
|
name string
|
|
user *userClaims
|
|
urlID string
|
|
statusCode int
|
|
}{
|
|
{
|
|
"no user",
|
|
nil,
|
|
userID,
|
|
http.StatusForbidden,
|
|
},
|
|
{
|
|
"different ID",
|
|
&userClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: uuid.New().String(),
|
|
}},
|
|
userID,
|
|
http.StatusForbidden,
|
|
},
|
|
{
|
|
"matching ID",
|
|
&userClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: userID,
|
|
},
|
|
},
|
|
userID,
|
|
http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
r := chi.NewRouter()
|
|
|
|
handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
r.With(
|
|
// Add user with the given claims to request's context
|
|
func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var ctx context.Context = r.Context()
|
|
if tc.user != nil {
|
|
ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
},
|
|
ownerOnlyMiddleware,
|
|
).Get("/{id}", handlerChain)
|
|
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if tc.urlID == "invalid" {
|
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
|
} else {
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUserCtxMiddleware(t *testing.T) {
|
|
validUserID := uuid.New()
|
|
invalidUserID := "invalid"
|
|
|
|
mockStore := &mockUserStore{
|
|
GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
|
|
if id == validUserID {
|
|
return data.User{ID: validUserID}, nil
|
|
}
|
|
return data.User{}, errors.New("not found")
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
urlID string
|
|
statusCode int
|
|
}{
|
|
{"valid ID", validUserID.String(), http.StatusOK},
|
|
{"invalid ID", invalidUserID, http.StatusNotFound},
|
|
{"non existent ID", uuid.New().String(), http.StatusNotFound},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
mw := userCtx(mockStore)
|
|
r := chi.NewRouter()
|
|
r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
|
|
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
|
assert.True(t, ok)
|
|
assert.Equal(t, validUserID, user.ID)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.statusCode, w.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
|
t.Helper()
|
|
|
|
claims := &userClaims{
|
|
Admin: isAdmin,
|
|
TokenType: tokenType,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: userID,
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
},
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(claims)
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signedToken, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate test token: %v", err)
|
|
}
|
|
|
|
return signedToken
|
|
}
|