github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/addpartitionstotxn_test.go (about)

     1  package kafka
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"strconv"
     7  	"testing"
     8  	"time"
     9  
    10  	ktesting "github.com/segmentio/kafka-go/testing"
    11  )
    12  
    13  func TestClientAddPartitionsToTxn(t *testing.T) {
    14  	if !ktesting.KafkaIsAtLeast("0.11.0") {
    15  		t.Skip("Skipping test because kafka version is not high enough.")
    16  	}
    17  	topic1 := makeTopic()
    18  	topic2 := makeTopic()
    19  
    20  	client, shutdown := newLocalClient()
    21  	defer shutdown()
    22  
    23  	err := clientCreateTopic(client, topic1, 3)
    24  	if err != nil {
    25  		t.Fatal(err)
    26  	}
    27  
    28  	err = clientCreateTopic(client, topic2, 3)
    29  	if err != nil {
    30  		t.Fatal(err)
    31  	}
    32  
    33  	transactionalID := makeTransactionalID()
    34  
    35  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
    36  	defer cancel()
    37  	respc, err := waitForCoordinatorIndefinitely(ctx, client, &FindCoordinatorRequest{
    38  		Addr:    client.Addr,
    39  		Key:     transactionalID,
    40  		KeyType: CoordinatorKeyTypeTransaction,
    41  	})
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  
    46  	transactionCoordinator := TCP(net.JoinHostPort(respc.Coordinator.Host, strconv.Itoa(int(respc.Coordinator.Port))))
    47  	client, shutdown = newClient(transactionCoordinator)
    48  	defer shutdown()
    49  
    50  	ipResp, err := client.InitProducerID(ctx, &InitProducerIDRequest{
    51  		TransactionalID:      transactionalID,
    52  		TransactionTimeoutMs: 10000,
    53  	})
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  
    58  	if ipResp.Error != nil {
    59  		t.Fatal(ipResp.Error)
    60  	}
    61  
    62  	defer func() {
    63  		err := clientEndTxn(client, &EndTxnRequest{
    64  			TransactionalID: transactionalID,
    65  			ProducerID:      ipResp.Producer.ProducerID,
    66  			ProducerEpoch:   ipResp.Producer.ProducerEpoch,
    67  			Committed:       false,
    68  		})
    69  		if err != nil {
    70  			t.Fatal(err)
    71  		}
    72  	}()
    73  
    74  	ctx, cancel = context.WithTimeout(context.Background(), time.Second*30)
    75  	defer cancel()
    76  	resp, err := client.AddPartitionsToTxn(ctx, &AddPartitionsToTxnRequest{
    77  		TransactionalID: transactionalID,
    78  		ProducerID:      ipResp.Producer.ProducerID,
    79  		ProducerEpoch:   ipResp.Producer.ProducerEpoch,
    80  		Topics: map[string][]AddPartitionToTxn{
    81  			topic1: {
    82  				{
    83  					Partition: 0,
    84  				},
    85  				{
    86  					Partition: 1,
    87  				},
    88  				{
    89  					Partition: 2,
    90  				},
    91  			},
    92  			topic2: {
    93  				{
    94  					Partition: 0,
    95  				},
    96  				{
    97  					Partition: 2,
    98  				},
    99  			},
   100  		},
   101  	})
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  
   106  	if len(resp.Topics) != 2 {
   107  		t.Errorf("expected responses for 2 topics; got: %d", len(resp.Topics))
   108  	}
   109  	for topic, partitions := range resp.Topics {
   110  		if topic == topic1 {
   111  			if len(partitions) != 3 {
   112  				t.Errorf("expected 3 partitions in response for topic %s; got: %d", topic, len(partitions))
   113  			}
   114  		}
   115  		if topic == topic2 {
   116  			if len(partitions) != 2 {
   117  				t.Errorf("expected 2 partitions in response for topic %s; got: %d", topic, len(partitions))
   118  			}
   119  		}
   120  		for _, partition := range partitions {
   121  			if partition.Error != nil {
   122  				t.Error(partition.Error)
   123  			}
   124  		}
   125  	}
   126  }