github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/user_auth_test.go (about) 1 package interceptor_test 2 3 import ( 4 "context" 5 "testing" 6 7 "connectrpc.com/connect" 8 "github.com/google/go-cmp/cmp" 9 "github.com/quickfeed/quickfeed/internal/qtest" 10 "github.com/quickfeed/quickfeed/qf" 11 "github.com/quickfeed/quickfeed/web" 12 "github.com/quickfeed/quickfeed/web/auth" 13 "github.com/quickfeed/quickfeed/web/interceptor" 14 "google.golang.org/protobuf/testing/protocmp" 15 ) 16 17 func TestUserVerifier(t *testing.T) { 18 db, cleanup := qtest.TestDB(t) 19 defer cleanup() 20 logger := qtest.Logger(t) 21 22 tm, err := auth.NewTokenManager(db) 23 if err != nil { 24 t.Fatal(err) 25 } 26 client := web.MockClient(t, db, connect.WithInterceptors( 27 interceptor.NewUserInterceptor(logger, tm), 28 )) 29 ctx := context.Background() 30 31 adminUser := qtest.CreateFakeUser(t, db) 32 student := qtest.CreateFakeUser(t, db) 33 34 adminCookie, err := tm.NewAuthCookie(adminUser.ID) 35 if err != nil { 36 t.Fatal(err) 37 } 38 studentCookie, err := tm.NewAuthCookie(student.ID) 39 if err != nil { 40 t.Fatal(err) 41 } 42 43 userTest := []struct { 44 code connect.Code 45 cookie string 46 wantUser *qf.User 47 }{ 48 {code: connect.CodeUnauthenticated, cookie: "", wantUser: nil}, 49 {code: connect.CodeUnauthenticated, cookie: "should fail", wantUser: nil}, 50 {code: 0, cookie: adminCookie.String(), wantUser: adminUser}, 51 {code: 0, cookie: studentCookie.String(), wantUser: student}, 52 } 53 54 for _, user := range userTest { 55 gotUser, err := client.GetUser(ctx, qtest.RequestWithCookie(&qf.Void{}, user.cookie)) 56 if err != nil { 57 // zero codes won't actually reach this check, but that's okay, since zero is CodeOK 58 if gotCode := connect.CodeOf(err); gotCode != user.code { 59 t.Errorf("GetUser() = %v, want %v", gotCode, user.code) 60 } 61 } 62 wantUser := user.wantUser 63 if gotUser == nil { 64 if wantUser != nil { 65 t.Errorf("GetUser(): %v, want: %v", gotUser, wantUser) 66 } 67 } else { 68 if diff := cmp.Diff(wantUser, gotUser.Msg, protocmp.Transform()); diff != "" { 69 t.Errorf("GetUser() mismatch (-wantUser +gotUser):\n%s", diff) 70 } 71 } 72 } 73 }