Compare commits
4 Commits
3257b19313
...
9324bb5321
Author | SHA1 | Date | |
---|---|---|---|
9324bb5321 | |||
993b576d0d | |||
c5a56c8479 | |||
2b65bf70d8 |
@ -78,7 +78,7 @@ func ownerOnlyMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
requestedID := chi.URLParam(r, "id")
|
||||
if !ok || user.ID != requestedID {
|
||||
if !ok || user.Subject != requestedID {
|
||||
respondError(w, http.StatusForbidden, "Forbidden")
|
||||
return
|
||||
}
|
||||
|
283
server/pkg/service/middleware_test.go
Normal file
283
server/pkg/service/middleware_test.go
Normal file
@ -0,0 +1,283 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
secret := "test-secret"
|
||||
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
|
||||
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
|
||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectedErr string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"no token",
|
||||
"",
|
||||
"Unauthorized",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"invalid token",
|
||||
"invalid",
|
||||
"Invalid token",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"expired token",
|
||||
expiredToken,
|
||||
"Invalid token",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"wrong type",
|
||||
generateTestToken(
|
||||
t,
|
||||
secret,
|
||||
"refresh",
|
||||
uuid.New().String(),
|
||||
true,
|
||||
),
|
||||
"Invalid token type",
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"valid token",
|
||||
validToken,
|
||||
"",
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mw := requireAccessToken(secret)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.token != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
called := false
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||
assert.True(t, ok)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
if tc.expectedErr != "" {
|
||||
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
||||
}
|
||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminOnlyMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user *userClaims
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"no user",
|
||||
nil,
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"non admin user",
|
||||
&userClaims{
|
||||
Admin: false,
|
||||
},
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"admin user",
|
||||
&userClaims{
|
||||
Admin: true,
|
||||
},
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mw := adminOnlyMiddleware
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.user != nil {
|
||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
called := false
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOwnerOnlyMiddleware(t *testing.T) {
|
||||
userID := uuid.New().String()
|
||||
tests := []struct {
|
||||
name string
|
||||
user *userClaims
|
||||
urlID string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"no user",
|
||||
nil,
|
||||
userID,
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"different ID",
|
||||
&userClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: uuid.New().String(),
|
||||
}},
|
||||
userID,
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"matching ID",
|
||||
&userClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
},
|
||||
},
|
||||
userID,
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r.With(
|
||||
// Add user with the given claims to request's context
|
||||
func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var ctx context.Context = r.Context()
|
||||
if tc.user != nil {
|
||||
ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
},
|
||||
ownerOnlyMiddleware,
|
||||
).Get("/{id}", handlerChain)
|
||||
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if tc.urlID == "invalid" {
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
} else {
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserCtxMiddleware(t *testing.T) {
|
||||
validUserID := uuid.New()
|
||||
invalidUserID := "invalid"
|
||||
|
||||
mockStore := &mockUserStore{
|
||||
GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
|
||||
if id == validUserID {
|
||||
return data.User{ID: validUserID}, nil
|
||||
}
|
||||
return data.User{}, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlID string
|
||||
statusCode int
|
||||
}{
|
||||
{"valid ID", validUserID.String(), http.StatusOK},
|
||||
{"invalid ID", invalidUserID, http.StatusNotFound},
|
||||
{"non existent ID", uuid.New().String(), http.StatusNotFound},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mw := userCtx(mockStore)
|
||||
r := chi.NewRouter()
|
||||
r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, validUserID, user.ID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.statusCode, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
||||
t.Helper()
|
||||
|
||||
claims := &userClaims{
|
||||
Admin: isAdmin,
|
||||
TokenType: tokenType,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(claims)
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test token: %v", err)
|
||||
}
|
||||
|
||||
return signedToken
|
||||
}
|
@ -78,9 +78,18 @@ func TestCreateUser_InvalidUsername(t *testing.T) {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"Too short", `{"username": "a", "password": "validPass123!"}`},
|
||||
{"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`},
|
||||
{"Empty", `{"username": "", "password": "validPass123!"}`},
|
||||
{
|
||||
"too short",
|
||||
`{"username": "a", "password": "validPass123!"}`,
|
||||
},
|
||||
{
|
||||
"invalid chars",
|
||||
`{"username": "user@name", "password": "validPass123!"}`,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
`{"username": "", "password": "validPass123!"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@ -101,9 +110,18 @@ func TestCreateUser_InvalidPassword(t *testing.T) {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))},
|
||||
{"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))},
|
||||
{"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))},
|
||||
{
|
||||
"too short",
|
||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1)),
|
||||
},
|
||||
{
|
||||
"too long",
|
||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1)),
|
||||
},
|
||||
{
|
||||
"low entropy",
|
||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
@ -77,7 +77,7 @@ func TestIsPasswordCompromised(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("randomly generated", func(t *testing.T) {
|
||||
randomStr := genRandomString(12)
|
||||
randomStr := genRandomString(t, 12)
|
||||
compromised, err := isPasswordCompromised(randomStr)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, compromised)
|
||||
@ -173,7 +173,9 @@ func TestNormalizeUsername(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func genRandomString(length int) string {
|
||||
func genRandomString(t *testing.T, length int) string {
|
||||
t.Helper()
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
b := make([]byte, length)
|
||||
|
Loading…
x
Reference in New Issue
Block a user