From 993b576d0d0e13722a6773d86017acb21b54dcae Mon Sep 17 00:00:00 2001
From: ae <git@golfed.xyz>
Date: Tue, 1 Apr 2025 12:20:00 +0300
Subject: [PATCH] test: middleware unit tests

---
 server/pkg/service/middleware_test.go | 283 ++++++++++++++++++++++++++
 1 file changed, 283 insertions(+)
 create mode 100644 server/pkg/service/middleware_test.go

diff --git a/server/pkg/service/middleware_test.go b/server/pkg/service/middleware_test.go
new file mode 100644
index 0000000..e9c4603
--- /dev/null
+++ b/server/pkg/service/middleware_test.go
@@ -0,0 +1,283 @@
+package service
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"git.umbrella.haus/ae/notatest/pkg/data"
+	"github.com/go-chi/chi/v5"
+	"github.com/golang-jwt/jwt/v5"
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestAuthMiddleware(t *testing.T) {
+	secret := "test-secret"
+	validToken := generateTestToken(t, secret, "access", uuid.New().String(), true)
+	expiredToken := generateTestToken(t, secret, "access", uuid.New().String(), true, func(claims *userClaims) {
+		claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
+	})
+
+	tests := []struct {
+		name        string
+		token       string
+		expectedErr string
+		statusCode  int
+	}{
+		{
+			"no token",
+			"",
+			"Unauthorized",
+			http.StatusUnauthorized,
+		},
+		{
+			"invalid token",
+			"invalid",
+			"Invalid token",
+			http.StatusUnauthorized,
+		},
+		{
+			"expired token",
+			expiredToken,
+			"Invalid token",
+			http.StatusUnauthorized,
+		},
+		{
+			"wrong type",
+			generateTestToken(
+				t,
+				secret,
+				"refresh",
+				uuid.New().String(),
+				true,
+			),
+			"Invalid token type",
+			http.StatusUnauthorized,
+		},
+		{
+			"valid token",
+			validToken,
+			"",
+			http.StatusOK,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			mw := requireAccessToken(secret)
+			req := httptest.NewRequest("GET", "/", nil)
+			if tc.token != "" {
+				req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tc.token))
+			}
+
+			w := httptest.NewRecorder()
+			called := false
+			handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				called = true
+				_, ok := r.Context().Value(userCtxKey{}).(*userClaims)
+				assert.True(t, ok)
+			}))
+
+			handler.ServeHTTP(w, req)
+
+			assert.Equal(t, tc.statusCode, w.Code)
+			if tc.expectedErr != "" {
+				assert.Contains(t, w.Body.String(), tc.expectedErr)
+			}
+			assert.Equal(t, tc.statusCode == http.StatusOK, called)
+		})
+	}
+}
+
+func TestAdminOnlyMiddleware(t *testing.T) {
+	tests := []struct {
+		name       string
+		user       *userClaims
+		statusCode int
+	}{
+		{
+			"no user",
+			nil,
+			http.StatusForbidden,
+		},
+		{
+			"non admin user",
+			&userClaims{
+				Admin: false,
+			},
+			http.StatusForbidden,
+		},
+		{
+			"admin user",
+			&userClaims{
+				Admin: true,
+			},
+			http.StatusOK,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			mw := adminOnlyMiddleware
+			req := httptest.NewRequest("GET", "/", nil)
+			if tc.user != nil {
+				req = req.WithContext(context.WithValue(req.Context(), userCtxKey{}, tc.user))
+			}
+
+			w := httptest.NewRecorder()
+			called := false
+			handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				called = true
+				w.WriteHeader(http.StatusOK)
+			}))
+
+			handler.ServeHTTP(w, req)
+
+			assert.Equal(t, tc.statusCode, w.Code)
+			assert.Equal(t, tc.statusCode == http.StatusOK, called)
+		})
+	}
+}
+
+func TestOwnerOnlyMiddleware(t *testing.T) {
+	userID := uuid.New().String()
+	tests := []struct {
+		name       string
+		user       *userClaims
+		urlID      string
+		statusCode int
+	}{
+		{
+			"no user",
+			nil,
+			userID,
+			http.StatusForbidden,
+		},
+		{
+			"different ID",
+			&userClaims{
+				RegisteredClaims: jwt.RegisteredClaims{
+					Subject: uuid.New().String(),
+				}},
+			userID,
+			http.StatusForbidden,
+		},
+		{
+			"matching ID",
+			&userClaims{
+				RegisteredClaims: jwt.RegisteredClaims{
+					Subject: userID,
+				},
+			},
+			userID,
+			http.StatusOK,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			r := chi.NewRouter()
+
+			handlerChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusOK)
+			})
+
+			r.With(
+				// Add user with the given claims to request's context
+				func(next http.Handler) http.Handler {
+					return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+						var ctx context.Context = r.Context()
+						if tc.user != nil {
+							ctx = context.WithValue(ctx, userCtxKey{}, tc.user)
+						}
+						next.ServeHTTP(w, r.WithContext(ctx))
+					})
+				},
+				ownerOnlyMiddleware,
+			).Get("/{id}", handlerChain)
+
+			req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
+			w := httptest.NewRecorder()
+			r.ServeHTTP(w, req)
+
+			if tc.urlID == "invalid" {
+				assert.Equal(t, http.StatusNotFound, w.Code)
+			} else {
+				assert.Equal(t, tc.statusCode, w.Code)
+			}
+		})
+	}
+}
+
+func TestUserCtxMiddleware(t *testing.T) {
+	validUserID := uuid.New()
+	invalidUserID := "invalid"
+
+	mockStore := &mockUserStore{
+		GetUserByIDFunc: func(ctx context.Context, id uuid.UUID) (data.User, error) {
+			if id == validUserID {
+				return data.User{ID: validUserID}, nil
+			}
+			return data.User{}, errors.New("not found")
+		},
+	}
+
+	tests := []struct {
+		name       string
+		urlID      string
+		statusCode int
+	}{
+		{"valid ID", validUserID.String(), http.StatusOK},
+		{"invalid ID", invalidUserID, http.StatusNotFound},
+		{"non existent ID", uuid.New().String(), http.StatusNotFound},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			mw := userCtx(mockStore)
+			r := chi.NewRouter()
+			r.With(mw).Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
+				user, ok := r.Context().Value(userCtxKey{}).(data.User)
+				assert.True(t, ok)
+				assert.Equal(t, validUserID, user.ID)
+				w.WriteHeader(http.StatusOK)
+			})
+
+			req := httptest.NewRequest("GET", fmt.Sprintf("/%s", tc.urlID), nil)
+			w := httptest.NewRecorder()
+			r.ServeHTTP(w, req)
+
+			assert.Equal(t, tc.statusCode, w.Code)
+		})
+	}
+}
+
+func generateTestToken(t *testing.T, secret, tokenType, userID string, isAdmin bool, opts ...func(*userClaims)) string {
+	t.Helper()
+
+	claims := &userClaims{
+		Admin:     isAdmin,
+		TokenType: tokenType,
+		RegisteredClaims: jwt.RegisteredClaims{
+			Subject:   userID,
+			ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+		},
+	}
+
+	for _, opt := range opts {
+		opt(claims)
+	}
+
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+	signedToken, err := token.SignedString([]byte(secret))
+	if err != nil {
+		t.Fatalf("Failed to generate test token: %v", err)
+	}
+
+	return signedToken
+}