Compare commits
No commits in common. "9324bb5321dc5b7f83aba9e06b4cc632c2a1e916" and "3257b193137abbcad67b71c5f3c2955b50aa2d9a" have entirely different histories.
9324bb5321
...
3257b19313
@ -78,7 +78,7 @@ func ownerOnlyMiddleware(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
user, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
||||||
requestedID := chi.URLParam(r, "id")
|
requestedID := chi.URLParam(r, "id")
|
||||||
if !ok || user.Subject != requestedID {
|
if !ok || user.ID != requestedID {
|
||||||
respondError(w, http.StatusForbidden, "Forbidden")
|
respondError(w, http.StatusForbidden, "Forbidden")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1,283 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.umbrella.haus/ae/notatest/pkg/data"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthMiddleware(t *testing.T) {
|
|
||||||
secret := "test-secret"
|
|
||||||
validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
|
|
||||||
expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
|
|
||||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
token string
|
|
||||||
expectedErr string
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"no token",
|
|
||||||
"",
|
|
||||||
"Unauthorized",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid token",
|
|
||||||
"invalid",
|
|
||||||
"Invalid token",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"expired token",
|
|
||||||
expiredToken,
|
|
||||||
"Invalid token",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"wrong type",
|
|
||||||
generateTestToken(
|
|
||||||
t,
|
|
||||||
secret,
|
|
||||||
"refresh",
|
|
||||||
uuid.New().String(),
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
"Invalid token type",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"valid token",
|
|
||||||
validToken,
|
|
||||||
"",
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mw := requireAccessToken(secret)
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
if tc.token != "" {
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
called := false
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
called = true
|
|
||||||
_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
|
|
||||||
assert.True(t, ok)
|
|
||||||
}))
|
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
if tc.expectedErr != "" {
|
|
||||||
assert.Contains(t, w.Body.String(), tc.expectedErr)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdminOnlyMiddleware(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
user *userClaims
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"no user",
|
|
||||||
nil,
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"non admin user",
|
|
||||||
&userClaims{
|
|
||||||
Admin: false,
|
|
||||||
},
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"admin user",
|
|
||||||
&userClaims{
|
|
||||||
Admin: true,
|
|
||||||
},
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mw := adminOnlyMiddleware
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
if tc.user != nil {
|
|
||||||
req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
called := false
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
called = true
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
assert.Equal(t, tc.statusCode == http.StatusOK, called)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOwnerOnlyMiddleware(t *testing.T) {
|
|
||||||
userID := uuid.New().String()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
user *userClaims
|
|
||||||
urlID string
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"no user",
|
|
||||||
nil,
|
|
||||||
userID,
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"different ID",
|
|
||||||
&userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: uuid.New().String(),
|
|
||||||
}},
|
|
||||||
userID,
|
|
||||||
http.StatusForbidden,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"matching ID",
|
|
||||||
&userClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
userID,
|
|
||||||
http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
r := chi.NewRouter()
|
|
||||||
|
|
||||||
handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
r.With(
|
|
||||||
// Add user with the given claims to request's context
|
|
||||||
func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var ctx context.Context = r.Context()
|
|
||||||
if tc.user != nil {
|
|
||||||
ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
},
|
|
||||||
ownerOnlyMiddleware,
|
|
||||||
).Get("/{id}", handlerChain)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if tc.urlID == "invalid" {
|
|
||||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserCtxMiddleware(t *testing.T) {
|
|
||||||
validUserID := uuid.New()
|
|
||||||
invalidUserID := "invalid"
|
|
||||||
|
|
||||||
mockStore := &mockUserStore{
|
|
||||||
GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
|
|
||||||
if id == validUserID {
|
|
||||||
return data.User{ID: validUserID}, nil
|
|
||||||
}
|
|
||||||
return data.User{}, errors.New("not found")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
urlID string
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{"valid ID", validUserID.String(), http.StatusOK},
|
|
||||||
{"invalid ID", invalidUserID, http.StatusNotFound},
|
|
||||||
{"non existent ID", uuid.New().String(), http.StatusNotFound},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
mw := userCtx(mockStore)
|
|
||||||
r := chi.NewRouter()
|
|
||||||
r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user, ok := r.Context().Value(userCtxKey{}).(data.User)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, validUserID, user.ID)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.statusCode, w.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
claims := &userClaims{
|
|
||||||
Admin: isAdmin,
|
|
||||||
TokenType: tokenType,
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: userID,
|
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
||||||
signedToken, err := token.SignedString([]byte(secret))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to generate test token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return signedToken
|
|
||||||
}
|
|
@ -78,18 +78,9 @@ func TestCreateUser_InvalidUsername(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
body string
|
body string
|
||||||
}{
|
}{
|
||||||
{
|
{"Too short", `{"username": "a", "password": "validPass123!"}`},
|
||||||
"too short",
|
{"Invalid chars", `{"username": "user@name", "password": "validPass123!"}`},
|
||||||
`{"username": "a", "password": "validPass123!"}`,
|
{"Empty", `{"username": "", "password": "validPass123!"}`},
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid chars",
|
|
||||||
`{"username": "user@name", "password": "validPass123!"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"empty",
|
|
||||||
`{"username": "", "password": "validPass123!"}`,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
@ -110,18 +101,9 @@ func TestCreateUser_InvalidPassword(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
body string
|
body string
|
||||||
}{
|
}{
|
||||||
{
|
{"too short", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1))},
|
||||||
"too short",
|
{"too long", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1))},
|
||||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength-1)),
|
{"low entropy", fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength))},
|
||||||
},
|
|
||||||
{
|
|
||||||
"too long",
|
|
||||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", maxPasswordLength+1)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"low entropy",
|
|
||||||
fmt.Sprintf(`{"username": "valid", "password": "%s"}`, strings.Repeat("a", minPasswordLength)),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
@ -77,7 +77,7 @@ func TestIsPasswordCompromised(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("randomly generated", func(t *testing.T) {
|
t.Run("randomly generated", func(t *testing.T) {
|
||||||
randomStr := genRandomString(t, 12)
|
randomStr := genRandomString(12)
|
||||||
compromised, err := isPasswordCompromised(randomStr)
|
compromised, err := isPasswordCompromised(randomStr)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.False(t, compromised)
|
assert.False(t, compromised)
|
||||||
@ -173,9 +173,7 @@ func TestNormalizeUsername(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRandomString(t *testing.T, length int) string {
|
func genRandomString(length int) string {
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
b := make([]byte, length)
|
b := make([]byte, length)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user