diff --git a/server/pkg/service/middleware_test.go b/server/pkg/service/middleware_test.go new file mode 100644 index 0000000..e9c4603 --- /dev/null +++ b/server/pkg/service/middleware_test.go @@ -0,0 +1,283 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.umbrella.haus/ae/notatest/pkg/data" + "github.com/go-chi/chi/v5" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestAuthMiddleware(t *testing.T) { + secret := "test-secret" + validToken := generateTestToken(t, secret, "access", uuid.New().String(), true) + expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) { + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)) + }) + + tests := []struct { + name string + token string + expectedErr string + statusCode int + }{ + { + "no token", + "", + "Unauthorized", + http.StatusUnauthorized, + }, + { + "invalid token", + "invalid", + "Invalid token", + http.StatusUnauthorized, + }, + { + "expired token", + expiredToken, + "Invalid token", + http.StatusUnauthorized, + }, + { + "wrong type", + generateTestToken( + t, + secret, + "refresh", + uuid.New().String(), + true, + ), + "Invalid token type", + http.StatusUnauthorized, + }, + { + "valid token", + validToken, + "", + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mw := requireAccessToken(secret) + req := httptest.NewRequest("GET", "/", nil) + if tc.token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token)) + } + + w := httptest.NewRecorder() + called := false + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + _, ok := r.Context().Value(userCtxKey{}).(*userClaims) + assert.True(t, ok) + })) + + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + if tc.expectedErr != "" { + assert.Contains(t, w.Body.String(), tc.expectedErr) + } + assert.Equal(t, tc.statusCode == http.StatusOK, called) + }) + } +} + +func TestAdminOnlyMiddleware(t *testing.T) { + tests := []struct { + name string + user *userClaims + statusCode int + }{ + { + "no user", + nil, + http.StatusForbidden, + }, + { + "non admin user", + &userClaims{ + Admin: false, + }, + http.StatusForbidden, + }, + { + "admin user", + &userClaims{ + Admin: true, + }, + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mw := adminOnlyMiddleware + req := httptest.NewRequest("GET", "/", nil) + if tc.user != nil { + req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user)) + } + + w := httptest.NewRecorder() + called := false + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + assert.Equal(t, tc.statusCode == http.StatusOK, called) + }) + } +} + +func TestOwnerOnlyMiddleware(t *testing.T) { + userID := uuid.New().String() + tests := []struct { + name string + user *userClaims + urlID string + statusCode int + }{ + { + "no user", + nil, + userID, + http.StatusForbidden, + }, + { + "different ID", + &userClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: uuid.New().String(), + }}, + userID, + http.StatusForbidden, + }, + { + "matching ID", + &userClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + }, + }, + userID, + http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := chi.NewRouter() + + handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + r.With( + // Add user with the given claims to request's context + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ctx context.Context = r.Context() + if tc.user != nil { + ctx = context.WithValue(ctx, userCtxKey{}, tc.user) + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }, + ownerOnlyMiddleware, + ).Get("/{id}", handlerChain) + + req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if tc.urlID == "invalid" { + assert.Equal(t, http.StatusNotFound, w.Code) + } else { + assert.Equal(t, tc.statusCode, w.Code) + } + }) + } +} + +func TestUserCtxMiddleware(t *testing.T) { + validUserID := uuid.New() + invalidUserID := "invalid" + + mockStore := &mockUserStore{ + GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) { + if id == validUserID { + return data.User{ID: validUserID}, nil + } + return data.User{}, errors.New("not found") + }, + } + + tests := []struct { + name string + urlID string + statusCode int + }{ + {"valid ID", validUserID.String(), http.StatusOK}, + {"invalid ID", invalidUserID, http.StatusNotFound}, + {"non existent ID", uuid.New().String(), http.StatusNotFound}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mw := userCtx(mockStore) + r := chi.NewRouter() + r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userCtxKey{}).(data.User) + assert.True(t, ok) + assert.Equal(t, validUserID, user.ID) + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, tc.statusCode, w.Code) + }) + } +} + +func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string { + t.Helper() + + claims := &userClaims{ + Admin: isAdmin, + TokenType: tokenType, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + }, + } + + for _, opt := range opts { + opt(claims) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, err := token.SignedString([]byte(secret)) + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + return signedToken +}