// github.com/hoveychen/kafka-go@v0.4.42/consumergroup_test.go

     1  package kafka
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"reflect"
     7  	"strings"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  )
    12  
    13  var _ coordinator = mockCoordinator{}
    14  
    15  type mockCoordinator struct {
    16  	closeFunc           func() error
    17  	findCoordinatorFunc func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error)
    18  	joinGroupFunc       func(joinGroupRequestV1) (joinGroupResponseV1, error)
    19  	syncGroupFunc       func(syncGroupRequestV0) (syncGroupResponseV0, error)
    20  	leaveGroupFunc      func(leaveGroupRequestV0) (leaveGroupResponseV0, error)
    21  	heartbeatFunc       func(heartbeatRequestV0) (heartbeatResponseV0, error)
    22  	offsetFetchFunc     func(offsetFetchRequestV1) (offsetFetchResponseV1, error)
    23  	offsetCommitFunc    func(offsetCommitRequestV2) (offsetCommitResponseV2, error)
    24  	readPartitionsFunc  func(...string) ([]Partition, error)
    25  }
    26  
    27  func (c mockCoordinator) Close() error {
    28  	if c.closeFunc != nil {
    29  		return c.closeFunc()
    30  	}
    31  	return nil
    32  }
    33  
    34  func (c mockCoordinator) findCoordinator(req findCoordinatorRequestV0) (findCoordinatorResponseV0, error) {
    35  	if c.findCoordinatorFunc == nil {
    36  		return findCoordinatorResponseV0{}, errors.New("no findCoordinator behavior specified")
    37  	}
    38  	return c.findCoordinatorFunc(req)
    39  }
    40  
    41  func (c mockCoordinator) joinGroup(req joinGroupRequestV1) (joinGroupResponseV1, error) {
    42  	if c.joinGroupFunc == nil {
    43  		return joinGroupResponseV1{}, errors.New("no joinGroup behavior specified")
    44  	}
    45  	return c.joinGroupFunc(req)
    46  }
    47  
    48  func (c mockCoordinator) syncGroup(req syncGroupRequestV0) (syncGroupResponseV0, error) {
    49  	if c.syncGroupFunc == nil {
    50  		return syncGroupResponseV0{}, errors.New("no syncGroup behavior specified")
    51  	}
    52  	return c.syncGroupFunc(req)
    53  }
    54  
    55  func (c mockCoordinator) leaveGroup(req leaveGroupRequestV0) (leaveGroupResponseV0, error) {
    56  	if c.leaveGroupFunc == nil {
    57  		return leaveGroupResponseV0{}, errors.New("no leaveGroup behavior specified")
    58  	}
    59  	return c.leaveGroupFunc(req)
    60  }
    61  
    62  func (c mockCoordinator) heartbeat(req heartbeatRequestV0) (heartbeatResponseV0, error) {
    63  	if c.heartbeatFunc == nil {
    64  		return heartbeatResponseV0{}, errors.New("no heartbeat behavior specified")
    65  	}
    66  	return c.heartbeatFunc(req)
    67  }
    68  
    69  func (c mockCoordinator) offsetFetch(req offsetFetchRequestV1) (offsetFetchResponseV1, error) {
    70  	if c.offsetFetchFunc == nil {
    71  		return offsetFetchResponseV1{}, errors.New("no offsetFetch behavior specified")
    72  	}
    73  	return c.offsetFetchFunc(req)
    74  }
    75  
    76  func (c mockCoordinator) offsetCommit(req offsetCommitRequestV2) (offsetCommitResponseV2, error) {
    77  	if c.offsetCommitFunc == nil {
    78  		return offsetCommitResponseV2{}, errors.New("no offsetCommit behavior specified")
    79  	}
    80  	return c.offsetCommitFunc(req)
    81  }
    82  
    83  func (c mockCoordinator) readPartitions(topics ...string) ([]Partition, error) {
    84  	if c.readPartitionsFunc == nil {
    85  		return nil, errors.New("no Readpartitions behavior specified")
    86  	}
    87  	return c.readPartitionsFunc(topics...)
    88  }
    89  
    90  func TestValidateConsumerGroupConfig(t *testing.T) {
    91  	tests := []struct {
    92  		config       ConsumerGroupConfig
    93  		errorOccured bool
    94  	}{
    95  		{config: ConsumerGroupConfig{}, errorOccured: true},
    96  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, HeartbeatInterval: 2}, errorOccured: true},
    97  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}}, errorOccured: true},
    98  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: -1}, errorOccured: true},
    99  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", SessionTimeout: -1}, errorOccured: true},
   100  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: -1}, errorOccured: true},
   101  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: -2}, errorOccured: true},
   102  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: -1}, errorOccured: true},
   103  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, StartOffset: 123}, errorOccured: true},
   104  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: -1}, errorOccured: true},
   105  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: 1, JoinGroupBackoff: -1}, errorOccured: true},
   106  		{config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: 1, JoinGroupBackoff: 1}, errorOccured: false},
   107  	}
   108  	for _, test := range tests {
   109  		err := test.config.Validate()
   110  		if test.errorOccured && err == nil {
   111  			t.Error("expected an error", test.config)
   112  		}
   113  		if !test.errorOccured && err != nil {
   114  			t.Error("expected no error, got", err, test.config)
   115  		}
   116  	}
   117  }
   118  
   119  func TestReaderAssignTopicPartitions(t *testing.T) {
   120  	conn := &mockCoordinator{
   121  		readPartitionsFunc: func(...string) ([]Partition, error) {
   122  			return []Partition{
   123  				{
   124  					Topic: "topic-1",
   125  					ID:    0,
   126  				},
   127  				{
   128  					Topic: "topic-1",
   129  					ID:    1,
   130  				},
   131  				{
   132  					Topic: "topic-1",
   133  					ID:    2,
   134  				},
   135  				{
   136  					Topic: "topic-2",
   137  					ID:    0,
   138  				},
   139  			}, nil
   140  		},
   141  	}
   142  
   143  	newJoinGroupResponseV1 := func(topicsByMemberID map[string][]string) joinGroupResponseV1 {
   144  		resp := joinGroupResponseV1{
   145  			GroupProtocol: RoundRobinGroupBalancer{}.ProtocolName(),
   146  		}
   147  
   148  		for memberID, topics := range topicsByMemberID {
   149  			resp.Members = append(resp.Members, joinGroupResponseMemberV1{
   150  				MemberID: memberID,
   151  				MemberMetadata: groupMetadata{
   152  					Topics: topics,
   153  				}.bytes(),
   154  			})
   155  		}
   156  
   157  		return resp
   158  	}
   159  
   160  	testCases := map[string]struct {
   161  		Members     joinGroupResponseV1
   162  		Assignments GroupMemberAssignments
   163  	}{
   164  		"nil": {
   165  			Members:     newJoinGroupResponseV1(nil),
   166  			Assignments: GroupMemberAssignments{},
   167  		},
   168  		"one member, one topic": {
   169  			Members: newJoinGroupResponseV1(map[string][]string{
   170  				"member-1": {"topic-1"},
   171  			}),
   172  			Assignments: GroupMemberAssignments{
   173  				"member-1": map[string][]int{
   174  					"topic-1": {0, 1, 2},
   175  				},
   176  			},
   177  		},
   178  		"one member, two topics": {
   179  			Members: newJoinGroupResponseV1(map[string][]string{
   180  				"member-1": {"topic-1", "topic-2"},
   181  			}),
   182  			Assignments: GroupMemberAssignments{
   183  				"member-1": map[string][]int{
   184  					"topic-1": {0, 1, 2},
   185  					"topic-2": {0},
   186  				},
   187  			},
   188  		},
   189  		"two members, one topic": {
   190  			Members: newJoinGroupResponseV1(map[string][]string{
   191  				"member-1": {"topic-1"},
   192  				"member-2": {"topic-1"},
   193  			}),
   194  			Assignments: GroupMemberAssignments{
   195  				"member-1": map[string][]int{
   196  					"topic-1": {0, 2},
   197  				},
   198  				"member-2": map[string][]int{
   199  					"topic-1": {1},
   200  				},
   201  			},
   202  		},
   203  		"two members, two unshared topics": {
   204  			Members: newJoinGroupResponseV1(map[string][]string{
   205  				"member-1": {"topic-1"},
   206  				"member-2": {"topic-2"},
   207  			}),
   208  			Assignments: GroupMemberAssignments{
   209  				"member-1": map[string][]int{
   210  					"topic-1": {0, 1, 2},
   211  				},
   212  				"member-2": map[string][]int{
   213  					"topic-2": {0},
   214  				},
   215  			},
   216  		},
   217  	}
   218  
   219  	for label, tc := range testCases {
   220  		t.Run(label, func(t *testing.T) {
   221  			cg := ConsumerGroup{}
   222  			cg.config.GroupBalancers = []GroupBalancer{
   223  				RangeGroupBalancer{},
   224  				RoundRobinGroupBalancer{},
   225  			}
   226  			assignments, err := cg.assignTopicPartitions(conn, tc.Members)
   227  			if err != nil {
   228  				t.Fatalf("bad err: %v", err)
   229  			}
   230  			if !reflect.DeepEqual(tc.Assignments, assignments) {
   231  				t.Errorf("expected %v; got %v", tc.Assignments, assignments)
   232  			}
   233  		})
   234  	}
   235  }
   236  
   237  func TestConsumerGroup(t *testing.T) {
   238  	tests := []struct {
   239  		scenario string
   240  		function func(*testing.T, context.Context, *ConsumerGroup)
   241  	}{
   242  		{
   243  			scenario: "Next returns generations",
   244  			function: func(t *testing.T, ctx context.Context, cg *ConsumerGroup) {
   245  				gen1, err := cg.Next(ctx)
   246  				if gen1 == nil {
   247  					t.Fatalf("expected generation 1 not to be nil")
   248  				}
   249  				if err != nil {
   250  					t.Fatalf("expected no error, but got %+v", err)
   251  				}
   252  				// returning from this function should cause the generation to
   253  				// exit.
   254  				gen1.Start(func(context.Context) {})
   255  
   256  				// if this fails due to context timeout, it would indicate that
   257  				// the
   258  				gen2, err := cg.Next(ctx)
   259  				if gen2 == nil {
   260  					t.Fatalf("expected generation 2 not to be nil")
   261  				}
   262  				if err != nil {
   263  					t.Fatalf("expected no error, but got %+v", err)
   264  				}
   265  
   266  				if gen1.ID == gen2.ID {
   267  					t.Errorf("generation ID should have changed, but it stayed as %d", gen1.ID)
   268  				}
   269  				if gen1.GroupID != gen2.GroupID {
   270  					t.Errorf("mismatched group ID between generations: %s and %s", gen1.GroupID, gen2.GroupID)
   271  				}
   272  				if gen1.MemberID != gen2.MemberID {
   273  					t.Errorf("mismatched member ID between generations: %s and %s", gen1.MemberID, gen2.MemberID)
   274  				}
   275  			},
   276  		},
   277  
   278  		{
   279  			scenario: "Next returns ctx.Err() on canceled context",
   280  			function: func(t *testing.T, _ context.Context, cg *ConsumerGroup) {
   281  				ctx, cancel := context.WithCancel(context.Background())
   282  				cancel()
   283  
   284  				gen, err := cg.Next(ctx)
   285  				if gen != nil {
   286  					t.Errorf("expected generation to be nil")
   287  				}
   288  				if !errors.Is(err, context.Canceled) {
   289  					t.Errorf("expected context.Canceled, but got %+v", err)
   290  				}
   291  			},
   292  		},
   293  
   294  		{
   295  			scenario: "Next returns ErrGroupClosed on closed group",
   296  			function: func(t *testing.T, ctx context.Context, cg *ConsumerGroup) {
   297  				if err := cg.Close(); err != nil {
   298  					t.Fatal(err)
   299  				}
   300  				gen, err := cg.Next(ctx)
   301  				if gen != nil {
   302  					t.Errorf("expected generation to be nil")
   303  				}
   304  				if !errors.Is(err, ErrGroupClosed) {
   305  					t.Errorf("expected ErrGroupClosed, but got %+v", err)
   306  				}
   307  			},
   308  		},
   309  	}
   310  
   311  	topic := makeTopic()
   312  	createTopic(t, topic, 1)
   313  	defer deleteTopic(t, topic)
   314  
   315  	for _, test := range tests {
   316  		t.Run(test.scenario, func(t *testing.T) {
   317  			group, err := NewConsumerGroup(ConsumerGroupConfig{
   318  				ID:                makeGroupID(),
   319  				Topics:            []string{topic},
   320  				Brokers:           []string{"localhost:9092"},
   321  				HeartbeatInterval: 2 * time.Second,
   322  				RebalanceTimeout:  2 * time.Second,
   323  				RetentionTime:     time.Hour,
   324  				Logger:            &testKafkaLogger{T: t},
   325  			})
   326  			if err != nil {
   327  				t.Fatal(err)
   328  			}
   329  			defer group.Close()
   330  
   331  			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
   332  			defer cancel()
   333  
   334  			test.function(t, ctx, group)
   335  		})
   336  	}
   337  }
   338  
   339  func TestConsumerGroupErrors(t *testing.T) {
   340  	var left []string
   341  	var lock sync.Mutex
   342  	mc := mockCoordinator{
   343  		leaveGroupFunc: func(req leaveGroupRequestV0) (leaveGroupResponseV0, error) {
   344  			lock.Lock()
   345  			left = append(left, req.MemberID)
   346  			lock.Unlock()
   347  			return leaveGroupResponseV0{}, nil
   348  		},
   349  	}
   350  	assertLeftGroup := func(t *testing.T, memberID string) {
   351  		lock.Lock()
   352  		if !reflect.DeepEqual(left, []string{memberID}) {
   353  			t.Errorf("expected abc to have left group once, members left: %v", left)
   354  		}
   355  		left = left[0:0]
   356  		lock.Unlock()
   357  	}
   358  
   359  	// NOTE : the mocked behavior is accumulated across the tests, so they are
   360  	// 		  NOT run in parallel.  this simplifies test setup so that each test
   361  	// 	 	  can specify only the error behavior required and leverage setup
   362  	//        from previous steps.
   363  	tests := []struct {
   364  		scenario string
   365  		prepare  func(*mockCoordinator)
   366  		function func(*testing.T, context.Context, *ConsumerGroup)
   367  	}{
   368  		{
   369  			scenario: "fails to find coordinator (general error)",
   370  			prepare: func(mc *mockCoordinator) {
   371  				mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) {
   372  					return findCoordinatorResponseV0{}, errors.New("dial error")
   373  				}
   374  			},
   375  			function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) {
   376  				gen, err := group.Next(ctx)
   377  				if err == nil {
   378  					t.Errorf("expected an error")
   379  				} else if err.Error() != "dial error" {
   380  					t.Errorf("got wrong error: %+v", err)
   381  				}
   382  				if gen != nil {
   383  					t.Error("expected a nil consumer group generation")
   384  				}
   385  			},
   386  		},
   387  
   388  		{
   389  			scenario: "fails to find coordinator (error code in response)",
   390  			prepare: func(mc *mockCoordinator) {
   391  				mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) {
   392  					return findCoordinatorResponseV0{
   393  						ErrorCode: int16(NotCoordinatorForGroup),
   394  					}, nil
   395  				}
   396  			},
   397  			function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) {
   398  				gen, err := group.Next(ctx)
   399  				if err == nil {
   400  					t.Errorf("expected an error")
   401  				} else if !errors.Is(err, NotCoordinatorForGroup) {
   402  					t.Errorf("got wrong error: %+v", err)
   403  				}
   404  				if gen != nil {
   405  					t.Error("expected a nil consumer group generation")
   406  				}
   407  			},
   408  		},
   409  
   410  		{
   411  			scenario: "fails to join group (general error)",
   412  			prepare: func(mc *mockCoordinator) {
   413  				mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) {
   414  					return findCoordinatorResponseV0{
   415  						Coordinator: findCoordinatorResponseCoordinatorV0{
   416  							NodeID: 1,
   417  							Host:   "foo.bar.com",
   418  							Port:   12345,
   419  						},
   420  					}, nil
   421  				}
   422  				mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) {
   423  					return joinGroupResponseV1{}, errors.New("join group failed")
   424  				}
   425  				// NOTE : no stub for leaving the group b/c the member never joined.
   426  			},
   427  			function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) {
   428  				gen, err := group.Next(ctx)
   429  				if err == nil {
   430  					t.Errorf("expected an error")
   431  				} else if err.Error() != "join group failed" {
   432  					t.Errorf("got wrong error: %+v", err)
   433  				}
   434  				if gen != nil {
   435  					t.Error("expected a nil consumer group generation")
   436  				}
   437  			},
   438  		},
   439  
   440  		{
   441  			scenario: "fails to join group (error code)",
   442  			prepare: func(mc *mockCoordinator) {
   443  				mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) {
   444  					return findCoordinatorResponseV0{
   445  						Coordinator: findCoordinatorResponseCoordinatorV0{
   446  							NodeID: 1,
   447  							Host:   "foo.bar.com",
   448  							Port:   12345,
   449  						},
   450  					}, nil
   451  				}
   452  				mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) {
   453  					return joinGroupResponseV1{
   454  						ErrorCode: int16(InvalidTopic),
   455  					}, nil
   456  				}
   457  				// NOTE : no stub for leaving the group b/c the member never joined.
   458  			},
   459  			function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) {
   460  				gen, err := group.Next(ctx)
   461  				if err == nil {
   462  					t.Errorf("expected an error")
   463  				} else if !errors.Is(err, InvalidTopic) {
   464  					t.Errorf("got wrong error: %+v", err)
   465  				}
   466  				if gen != nil {
   467  					t.Error("expected a nil consumer group generation")
   468  				}
   469  			},
   470  		},
   471  
   472  		{
   473  			scenario: "fails to join group (leader, unsupported protocol)",
   474  			prepare: func(mc *mockCoordinator) {
   475  				mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) {
   476  					return joinGroupResponseV1{
   477  						GenerationID:  12345,
   478  						GroupProtocol: "foo",
   479  						LeaderID:      "abc",
   480  						MemberID:      "abc",
   481  					}, nil
   482  				}
   483  			},
   484  			function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) {
   485  				gen, err := group.Next(ctx)
   486  				if err == nil {
   487  					t.Errorf("expected an error")
   488  				} else if !strings.HasPrefix(err.Error(), "unable to find selected balancer") {
   489  					t.Errorf("got wrong error: %+v", err)
   490  				}
   491  				if gen != nil {
   492  					t.Error("expected a nil consumer group generation")
   493  				}
   494  				assertLeftGroup(t, "abc")
   495  			},
   496  		},
   497  
   498  		{
   499  			scenario: "fails to sync group (general error)",
   500  			prepare: func(mc *mockCoordinator) {
   501  				mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) {
   502  					return joinGroupResponseV1{
   503  						GenerationID:  12345,
   504  						GroupProtocol: "range",
   505  						LeaderID:      "abc",
   506  						MemberID:      "abc",
   507  					}, nil
   508  				}
   509  				mc.readPartitionsFunc = func(...string) ([]Partition, error) {
   510  					return []Partition{}, nil
   511  				}
   512  				mc.syncGroupFunc = func(syncGroupRequestV0) (syncGroupResponseV0, error) {
   513  					return syncGroupResponseV0{}, errors.New("sync group failed")
   514  				}
   515  			},
   516  			function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) {
   517  				gen, err := group.Next(ctx)
   518  				if err == nil {
   519  					t.Errorf("expected an error")
   520  				} else if err.Error() != "sync group failed" {
   521  					t.Errorf("got wrong error: %+v", err)
   522  				}
   523  				if gen != nil {
   524  					t.Error("expected a nil consumer group generation")
   525  				}
   526  				assertLeftGroup(t, "abc")
   527  			},
   528  		},
   529  
   530  		{
   531  			scenario: "fails to sync group (error code)",
   532  			prepare: func(mc *mockCoordinator) {
   533  				mc.syncGroupFunc = func(syncGroupRequestV0) (syncGroupResponseV0, error) {
   534  					return syncGroupResponseV0{
   535  						ErrorCode: int16(InvalidTopic),
   536  					}, nil
   537  				}
   538  			},
   539  			function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) {
   540  				gen, err := group.Next(ctx)
   541  				if err == nil {
   542  					t.Errorf("expected an error")
   543  				} else if !errors.Is(err, InvalidTopic) {
   544  					t.Errorf("got wrong error: %+v", err)
   545  				}
   546  				if gen != nil {
   547  					t.Error("expected a nil consumer group generation")
   548  				}
   549  				assertLeftGroup(t, "abc")
   550  			},
   551  		},
   552  	}
   553  
   554  	for _, tt := range tests {
   555  		t.Run(tt.scenario, func(t *testing.T) {
   556  
   557  			tt.prepare(&mc)
   558  
   559  			group, err := NewConsumerGroup(ConsumerGroupConfig{
   560  				ID:                makeGroupID(),
   561  				Topics:            []string{"test"},
   562  				Brokers:           []string{"no-such-broker"}, // should not attempt to actually dial anything
   563  				HeartbeatInterval: 2 * time.Second,
   564  				RebalanceTimeout:  time.Second,
   565  				JoinGroupBackoff:  time.Second,
   566  				RetentionTime:     time.Hour,
   567  				connect: func(*Dialer, ...string) (coordinator, error) {
   568  					return mc, nil
   569  				},
   570  				Logger: &testKafkaLogger{T: t},
   571  			})
   572  			if err != nil {
   573  				t.Fatal(err)
   574  			}
   575  
   576  			// these tests should all execute fairly quickly since they're
   577  			// mocking the coordinator.
   578  			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   579  			defer cancel()
   580  
   581  			tt.function(t, ctx, group)
   582  
   583  			if err := group.Close(); err != nil {
   584  				t.Errorf("error on close: %+v", err)
   585  			}
   586  		})
   587  	}
   588  }
   589  
// TODO: add a test covering multi-topic consumer groups.
   591  
   592  func TestGenerationExitsOnPartitionChange(t *testing.T) {
   593  	var count int
   594  	partitions := [][]Partition{
   595  		{
   596  			Partition{
   597  				Topic: "topic-1",
   598  				ID:    0,
   599  			},
   600  		},
   601  		{
   602  			Partition{
   603  				Topic: "topic-1",
   604  				ID:    0,
   605  			},
   606  			{
   607  				Topic: "topic-1",
   608  				ID:    1,
   609  			},
   610  		},
   611  	}
   612  
   613  	conn := mockCoordinator{
   614  		readPartitionsFunc: func(...string) ([]Partition, error) {
   615  			p := partitions[count]
   616  			// cap the count at len(partitions) -1 so ReadPartitions doesn't even go out of bounds
   617  			// and long running tests don't fail
   618  			if count < len(partitions) {
   619  				count++
   620  			}
   621  			return p, nil
   622  		},
   623  	}
   624  
   625  	// Sadly this test is time based, so at the end will be seeing if the runGroup run to completion within the
   626  	// allotted time. The allotted time is 4x the PartitionWatchInterval.
   627  	now := time.Now()
   628  	watchTime := 500 * time.Millisecond
   629  
   630  	gen := Generation{
   631  		conn:     conn,
   632  		done:     make(chan struct{}),
   633  		joined:   make(chan struct{}),
   634  		log:      func(func(Logger)) {},
   635  		logError: func(func(Logger)) {},
   636  	}
   637  
   638  	done := make(chan struct{})
   639  	go func() {
   640  		gen.partitionWatcher(watchTime, "topic-1")
   641  		close(done)
   642  	}()
   643  
   644  	select {
   645  	case <-time.After(5 * time.Second):
   646  		t.Fatal("timed out waiting for partition watcher to exit")
   647  	case <-done:
   648  		if time.Since(now).Seconds() > watchTime.Seconds()*4 {
   649  			t.Error("partitionWatcher didn't see update")
   650  		}
   651  	}
   652  }
   653  
   654  func TestGenerationStartsFunctionAfterClosed(t *testing.T) {
   655  	gen := Generation{
   656  		conn:     &mockCoordinator{},
   657  		done:     make(chan struct{}),
   658  		joined:   make(chan struct{}),
   659  		log:      func(func(Logger)) {},
   660  		logError: func(func(Logger)) {},
   661  	}
   662  
   663  	gen.close()
   664  
   665  	ch := make(chan error)
   666  	gen.Start(func(ctx context.Context) {
   667  		<-ctx.Done()
   668  		ch <- ctx.Err()
   669  	})
   670  
   671  	select {
   672  	case <-time.After(time.Second):
   673  		t.Fatal("timed out waiting for func to run")
   674  	case err := <-ch:
   675  		if !errors.Is(err, ErrGenerationEnded) {
   676  			t.Fatalf("expected %v but got %v", ErrGenerationEnded, err)
   677  		}
   678  	}
   679  }