github.com/oinume/lekcije@v0.0.0-20231017100347-5b4c5eb6ab24/backend/infrastructure/mysql/following_teacher_test.go (about) 1 package mysql_test 2 3 import ( 4 "context" 5 "testing" 6 "time" 7 8 "github.com/oinume/lekcije/backend/infrastructure/mysql" 9 "github.com/oinume/lekcije/backend/internal/assertion" 10 "github.com/oinume/lekcije/backend/internal/modeltest" 11 "github.com/oinume/lekcije/backend/internal/mysqltest" 12 "github.com/oinume/lekcije/backend/internal/slice_util" 13 "github.com/oinume/lekcije/backend/model2" 14 ) 15 16 func Test_followingTeacherRepository_FindTeacherIDsByUserID(t *testing.T) { 17 repo := mysql.NewFollowingTeacherRepository(helper.DB(t).DB()) 18 repos := mysqltest.NewRepositories(helper.DB(t).DB()) 19 20 type testCase struct { 21 userID uint 22 lastLessonAt time.Time 23 wantTeacherIDs []uint 24 } 25 26 tests := map[string]struct { 27 setup func(ctx context.Context) *testCase 28 wantErr bool 29 }{ 30 "normal": { 31 setup: func(ctx context.Context) *testCase { 32 helper.TruncateAllTables(t) 33 //boil.DebugMode = true 34 //boil.DebugWriter = os.Stdout 35 36 user := modeltest.NewUser() 37 repos.CreateUsers(ctx, t, user) 38 39 now := time.Now().UTC() 40 teacher1 := modeltest.NewTeacher(func(t *model2.Teacher) { 41 t.LastLessonAt = now.Add(1 * time.Hour) 42 }) 43 ft1 := modeltest.NewFollowingTeacher(func(ft *model2.FollowingTeacher) { 44 ft.UserID = user.ID 45 ft.TeacherID = teacher1.ID 46 }) 47 teacher2 := modeltest.NewTeacher(func(t *model2.Teacher) { 48 t.LastLessonAt = now.Add(24 * time.Hour) 49 }) 50 ft2 := modeltest.NewFollowingTeacher(func(ft *model2.FollowingTeacher) { 51 ft.UserID = user.ID 52 ft.TeacherID = teacher2.ID 53 }) 54 repos.CreateTeachers(ctx, t, teacher1, teacher2) 55 repos.CreateFollowingTeachers(ctx, t, ft1, ft2) 56 57 teacherIDs := []uint{teacher1.ID, teacher2.ID} 58 slice_util.Sort(teacherIDs) 59 return &testCase{ 60 userID: user.ID, 61 lastLessonAt: now, 62 wantTeacherIDs: teacherIDs, 63 } 64 }, 65 }, 66 } 67 68 for name, tt := range tests { 69 t.Run(name, func(t *testing.T) { 70 ctx := context.Background() 71 tc := tt.setup(ctx) 72 gotTeacherIDs, err := repo.FindTeacherIDsByUserID(ctx, tc.userID, 5, tc.lastLessonAt) 73 if (err != nil) != tt.wantErr { 74 t.Errorf("FindTeacherIDsByUserID() error = %v, wantErr %v", err, tt.wantErr) 75 return 76 } 77 78 slice_util.Sort(gotTeacherIDs) 79 assertion.AssertEqual(t, tc.wantTeacherIDs, gotTeacherIDs, "") 80 }) 81 } 82 }