package service

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"git.umbrella.haus/ae/notatest/pkg/data"
	"github.com/go-chi/chi/v5"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
)

// TestAuthMiddleware checks that requireAccessToken answers 401 Unauthorized
// for requests without a usable access token (missing, malformed, expired, or
// a refresh token) and that a valid access token reaches the next handler with
// the token's claims stored in the request context under userCtxKey{}.
func TestAuthMiddleware(t *testing.T) {
	secret := "test-secret"

	validSubject := uuid.New().String()
	validToken := generateTestToken(t, secret, "access", validSubject, true)
	expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true,
		func(claims *userClaims) {
			claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-time.Hour))
		})
	refreshToken := generateTestToken(t, secret, "refresh", uuid.New().String(), true)

	tests := []struct {
		name        string
		token       string
		expectedErr string
		statusCode  int
	}{
		{"no token", "", "Unauthorized", http.StatusUnauthorized},
		{"invalid token", "invalid", "Invalid token", http.StatusUnauthorized},
		{"expired token", expiredToken, "Invalid token", http.StatusUnauthorized},
		{"wrong type", refreshToken, "Invalid token type", http.StatusUnauthorized},
		{"valid token", validToken, "", http.StatusOK},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			mw := requireAccessToken(secret)

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.token != "" {
				req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
			}
			w := httptest.NewRecorder()

			called := false
			handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				claims, ok := r.Context().Value(userCtxKey{}).(*userClaims)
				// Only the valid token should get this far, so the stored claims
				// must be the ones decoded from it, not merely a value of the
				// right type.
				if assert.True(t, ok) {
					assert.Equal(t, validSubject, claims.Subject)
				}
			}))
			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			if tc.expectedErr != "" {
				assert.Contains(t, w.Body.String(), tc.expectedErr)
			}
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
		})
	}
}

// TestAdminOnlyMiddleware checks that adminOnlyMiddleware answers 403 Forbidden
// when the request context holds no claims or non-admin claims, and lets
// requests with admin claims through to the next handler.
func TestAdminOnlyMiddleware(t *testing.T) {
	tests := []struct {
		name       string
		user       *userClaims
		statusCode int
	}{
		{"no user", nil, http.StatusForbidden},
		{"non admin user", &userClaims{Admin: false}, http.StatusForbidden},
		{"admin user", &userClaims{Admin: true}, http.StatusOK},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.user != nil {
				req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
			}
			w := httptest.NewRecorder()

			called := false
			handler := adminOnlyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				called = true
				w.WriteHeader(http.StatusOK)
			}))
			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
		})
	}
}

// TestOwnerOnlyMiddleware checks that ownerOnlyMiddleware answers 403 Forbidden
// unless the claims in the request context have a Subject equal to the {id}
// URL parameter, in which case the request reaches the next handler.
func TestOwnerOnlyMiddleware(t *testing.T) {
	userID := uuid.New().String()

	tests := []struct {
		name       string
		user       *userClaims
		urlID      string
		statusCode int
	}{
		{"no user", nil, userID, http.StatusForbidden},
		{
			"different ID",
			&userClaims{RegisteredClaims: jwt.RegisteredClaims{Subject: uuid.New().String()}},
			userID,
			http.StatusForbidden,
		},
		{
			"matching ID",
			&userClaims{RegisteredClaims: jwt.RegisteredClaims{Subject: userID}},
			userID,
			http.StatusOK,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// withUser puts tc.user (when set) into the request context so that
			// ownerOnlyMiddleware sees the claims it would normally get from the
			// authentication step.
			withUser := func(next http.Handler) http.Handler {
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					ctx := r.Context()
					if tc.user != nil {
						ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
					}
					next.ServeHTTP(w, r.WithContext(ctx))
				})
			}

			called := false
			final := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				called = true
				w.WriteHeader(http.StatusOK)
			})

			// Routing through chi is required so that the {id} URL parameter is
			// populated for ownerOnlyMiddleware.
			router := chi.NewRouter()
			router.With(withUser, ownerOnlyMiddleware).Get("/{id}", final)

			req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/%s", tc.urlID), nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
		})
	}
}

// TestUserCtxMiddleware checks that userCtx loads the user named by the {id}
// URL parameter from the store into the request context, and answers 404 Not
// Found for an unparsable ID or one the store does not know.
func TestUserCtxMiddleware(t *testing.T) {
	validUserID := uuid.New()
	mockStore := &mockUserStore{
		GetUserByIDFunc: func(_ context.Context, id uuid.UUID) (data.User, error) {
			if id == validUserID {
				return data.User{ID: validUserID}, nil
			}
			return data.User{}, errors.New("not found")
		},
	}

	tests := []struct {
		name       string
		urlID      string
		statusCode int
	}{
		{"valid ID", validUserID.String(), http.StatusOK},
		{"invalid ID", "invalid", http.StatusNotFound},
		{"non existent ID", uuid.New().String(), http.StatusNotFound},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			called := false
			router := chi.NewRouter()
			router.With(userCtx(mockStore)).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
				called = true
				user, ok := r.Context().Value(userCtxKey{}).(data.User)
				// Guard the field check so a missing value reports one clear
				// failure instead of also comparing against a zero User.
				if assert.True(t, ok) {
					assert.Equal(t, validUserID, user.ID)
				}
				w.WriteHeader(http.StatusOK)
			})

			req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/%s", tc.urlID), nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			assert.Equal(t, tc.statusCode, w.Code)
			assert.Equal(t, tc.statusCode == http.StatusOK, called)
		})
	}
}

// generateTestToken returns an HS256-signed JWT built from userClaims with the
// given token type, subject (userID) and admin flag, expiring one hour from
// now. Each opt may modify the claims before signing (for example to backdate
// ExpiresAt). The calling test fails immediately if signing returns an error.
func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
	t.Helper()

	claims := &userClaims{
		Admin:     isAdmin,
		TokenType: tokenType,
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   userID,
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
		},
	}
	for _, opt := range opts {
		opt(claims)
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	signedToken, err := token.SignedString([]byte(secret))
	if err != nil {
		t.Fatalf("Failed to generate test token: %v", err)
	}
	return signedToken
}