github.com/Goboolean/common@v0.0.0-20231130153141-cb54596b217d/pkg/kafka/consumer.go

package kafka

import (
	"context"
	"sync"
	"time"

	"github.com/Goboolean/common/pkg/resolver"
	"github.com/confluentinc/confluent-kafka-go/kafka"
	"github.com/confluentinc/confluent-kafka-go/schemaregistry"
	"github.com/confluentinc/confluent-kafka-go/schemaregistry/serde"
	"github.com/confluentinc/confluent-kafka-go/schemaregistry/serde/protobuf"
	log "github.com/sirupsen/logrus"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
)

// Deserializer deserializes a raw Kafka payload into msg.
type Deserializer interface {
	DeserializeInto(topic string, payload []byte, msg interface{}) error
}

// ProtoDeserializer is the fallback Deserializer used when no schema registry
// is configured; it expects the payload to be plain protobuf-encoded bytes.
type ProtoDeserializer struct{}

func (s *ProtoDeserializer) DeserializeInto(topic string, payload []byte, msg interface{}) error {
	pb, ok := msg.(proto.Message)
	if !ok {
		return ErrReceivedMsgIsNotProtoMessage
	}
	return proto.Unmarshal(payload, pb)
}

func defaultDeserializer() Deserializer {
	return &ProtoDeserializer{}
}
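
// Round-trip sketch for the fallback path, assuming no schema registry is
// configured (model.Event stands in for any generated protobuf message):
//
//   payload, _ := proto.Marshal(&model.Event{})
//   var out model.Event
//   err := defaultDeserializer().DeserializeInto("TOPIC", payload, &out)
//
// Payloads written in the Confluent Schema Registry wire format carry a framing
// header and generally will not decode through this plain-protobuf path.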

// SubscribeListener is implemented by the application to handle every message
// consumed from the subscribed topic.
type SubscribeListener[T proto.Message] interface {
	OnReceiveMessage(ctx context.Context, msg T) error
}
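
// Minimal listener sketch (illustrative only; model.Event is a placeholder for
// the application's generated protobuf type):
//
//   type eventListener struct{}
//
//   func (eventListener) OnReceiveMessage(ctx context.Context, msg *model.Event) error {
//       // Handle the message. A returned error is only logged by the
//       // consumer; the message is not redelivered.
//       return nil
//   }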

// Consumer reads protobuf-encoded messages of type T from a single topic and
// dispatches them to the registered SubscribeListener.
type Consumer[T proto.Message] struct {
	consumer *kafka.Consumer
	deserial Deserializer

	listener SubscribeListener[T]
	topic    string
	channel  chan T

	wg     sync.WaitGroup
	ctx    context.Context
	cancel context.CancelFunc
}

// NewConsumer creates a Consumer from the given configuration and starts the
// background reader and processor goroutines. Subscribe must still be called
// to bind the consumer to a topic.
//
// example:
// p, err := NewConsumer[*model.Event](&resolver.ConfigMap{
//   "BOOTSTRAP_HOST":  os.Getenv("KAFKA_BOOTSTRAP_HOST"),
//   "REGISTRY_URL":    os.Getenv("KAFKA_REGISTRY_URL"), // optional
//   "GROUP_ID":        "GROUP_ID",
//   "PROCESSOR_COUNT": os.Getenv("KAFKA_PROCESSOR_COUNT"),
//   "TOPIC":           "TOPIC",
// }, subscribeListenerImpl)
func NewConsumer[T proto.Message](c *resolver.ConfigMap, l SubscribeListener[T]) (*Consumer[T], error) {
	bootstrap_host, err := c.GetStringKey("BOOTSTRAP_HOST")
	if err != nil {
		return nil, err
	}

	group_id, err := c.GetStringKey("GROUP_ID")
	if err != nil {
		return nil, err
	}

	registry_url, exists, err := c.GetStringKeyOptional("REGISTRY_URL")
	if err != nil {
		return nil, err
	}

	processor_count, err := c.GetIntKey("PROCESSOR_COUNT")
	if err != nil {
		return nil, err
	}

	conn, err := kafka.NewConsumer(&kafka.ConfigMap{
		"bootstrap.servers": bootstrap_host,
		"group.id":          group_id,
		"auto.offset.reset": "earliest",
	})
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(context.Background())

	instance := &Consumer[T]{
		consumer: conn,
		listener: l,
		wg:       sync.WaitGroup{},
		ctx:      ctx,
		cancel:   cancel,
		channel:  make(chan T, 100),
	}

	if exists {
		sr, err := schemaregistry.NewClient(schemaregistry.NewConfig(registry_url))
		if err != nil {
			return nil, err
		}

		d, err := protobuf.NewDeserializer(sr, serde.ValueSerde, protobuf.NewDeserializerConfig())
		if err != nil {
			return nil, err
		}

		instance.deserial = d
	} else {
		instance.deserial = defaultDeserializer()
	}

	// Register goroutines with the WaitGroup before starting them so that
	// Close cannot observe the WaitGroup before Add has run.
	instance.wg.Add(1)
	go instance.readMessage(ctx, &instance.wg)
	for i := 0; i < processor_count; i++ {
		instance.wg.Add(1)
		go instance.consumeMessage(ctx, &instance.wg)
	}
	return instance, nil
}

// Subscribe registers the message schema with the deserializer (when a schema
// registry deserializer is in use) and subscribes the underlying consumer to
// the given topic. Only one topic can be subscribed per Consumer.
func (c *Consumer[T]) Subscribe(topic string, schema protoreflect.MessageType) error {
	if c.topic != "" {
		return ErrTopicAlreadySubscribed
	}

	if d, ok := c.deserial.(*protobuf.Deserializer); ok {
		if err := d.ProtoRegistry.RegisterMessage(schema); err != nil {
			return err
		}
	}

	if err := c.consumer.Subscribe(topic, nil); err != nil {
		return err
	}
	c.topic = topic
	return nil
}
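
// Usage sketch (model.Event is a placeholder generated type; its
// protoreflect.MessageType is what Subscribe expects as the schema):
//
//   err := p.Subscribe("TOPIC", (&model.Event{}).ProtoReflect().Type())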

// readMessage polls the broker, deserializes each record into a new T, and
// forwards it to the internal channel until ctx is cancelled.
func (c *Consumer[T]) readMessage(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	for {
		if err := ctx.Err(); err != nil {
			return
		}

		msg, err := c.consumer.ReadMessage(time.Second * 1)
		if err != nil {
			continue
		}

		// T is a pointer to a generated protobuf message, so a fresh instance
		// must be allocated for the deserializer to populate.
		var zero T
		event := zero.ProtoReflect().New().Interface().(T)
		if err := c.deserial.DeserializeInto(c.topic, msg.Value, event); err != nil {
			log.WithFields(log.Fields{
				"topic": *msg.TopicPartition.Topic,
				"data":  msg.Value,
				"error": err,
			}).Error("Failed to deserialize data")
			continue
		}

		log.WithFields(log.Fields{
			"topic":     *msg.TopicPartition.Topic,
			"data":      msg.Value,
			"partition": msg.TopicPartition.Partition,
			"offset":    msg.TopicPartition.Offset,
		}).Trace("Consumer received message")

		select {
		case c.channel <- event:
		case <-ctx.Done():
			return
		}
	}
}

// consumeMessage drains the internal channel and hands each message to the
// listener with a per-message timeout. Listener errors are logged and the
// message is dropped.
func (c *Consumer[T]) consumeMessage(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	for {
		select {
		case <-ctx.Done():
			return
		case event := <-c.channel:
			tctx, cancel := context.WithTimeout(ctx, time.Second*5)
			if err := c.listener.OnReceiveMessage(tctx, event); err != nil {
				log.WithFields(log.Fields{
					"event": event,
					"error": err,
				}).Error("Failed to process data")
			}
			cancel()
		}
	}
}

// Close stops the background goroutines, waits for them to exit, and then
// closes the underlying Kafka consumer.
func (c *Consumer[T]) Close() {
	c.cancel()
	c.wg.Wait()
	c.consumer.Close()
}

// Ping checks connectivity to the broker by fetching metadata. The given ctx
// must have a deadline set; an error is returned if there is no response
// before the deadline.
func (c *Consumer[T]) Ping(ctx context.Context) error {
	deadline, ok := ctx.Deadline()
	if !ok {
		return ErrDeadlineSettingRequired
	}

	remaining := time.Until(deadline)
	_, err := c.consumer.GetMetadata(nil, true, int(remaining.Milliseconds()))
	return err
}
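
// Usage sketch for Ping (the timeout value is illustrative):
//
//   ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
//   defer cancel()
//   if err := p.Ping(ctx); err != nil {
//       log.WithError(err).Error("kafka broker is not reachable")
//   }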