github.com/wfusion/gofusion@v1.1.14/common/infra/watermill/pubsub/sql/subscriber.go (about)

     1  package sql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/oklog/ulid/v2"
    11  	"github.com/pkg/errors"
    12  	"github.com/spf13/cast"
    13  	"go.uber.org/multierr"
    14  
    15  	"github.com/wfusion/gofusion/common/infra/watermill"
    16  	"github.com/wfusion/gofusion/common/infra/watermill/message"
    17  	"github.com/wfusion/gofusion/common/utils"
    18  )
    19  
    20  var (
    21  	ErrSubscriberClosed = errors.New("subscriber is closed")
    22  )
    23  
    24  type SubscriberConfig struct {
    25  	ConsumerGroup string
    26  
    27  	// PollInterval is the interval to wait between subsequent SELECT queries,
    28  	// if no more messages were found in the database (Prefer using the BackoffManager instead).
    29  	// Must be non-negative. Defaults to 1s.
    30  	PollInterval time.Duration
    31  
    32  	// ResendInterval is the time to wait before resending a nacked message.
    33  	// Must be non-negative. Defaults to 1s.
    34  	ResendInterval time.Duration
    35  
    36  	// RetryInterval is the time to wait before resuming querying for messages
    37  	// after an error (Prefer using the BackoffManager instead).
    38  	// Must be non-negative. Defaults to 1s.
    39  	RetryInterval time.Duration
    40  
    41  	// BackoffManager defines how much to backoff when receiving errors.
    42  	BackoffManager BackoffManager
    43  
    44  	// SchemaAdapter provides the schema-dependent queries and arguments for them, based on topic/message etc.
    45  	SchemaAdapter SchemaAdapter
    46  
    47  	// OffsetsAdapter provides mechanism for saving acks and offsets of consumers.
    48  	OffsetsAdapter OffsetsAdapter
    49  
    50  	// InitializeSchema option enables initializing schema on making subscription.
    51  	InitializeSchema bool
    52  
    53  	// DisablePersistent option delete message after consumed
    54  	DisablePersistent bool
    55  }
    56  
    57  func (c *SubscriberConfig) setDefaults() {
    58  	if c.PollInterval == 0 {
    59  		c.PollInterval = time.Second
    60  	}
    61  	if c.ResendInterval == 0 {
    62  		c.ResendInterval = time.Second
    63  	}
    64  	if c.RetryInterval == 0 {
    65  		c.RetryInterval = time.Second
    66  	}
    67  	if c.BackoffManager == nil {
    68  		c.BackoffManager = NewDefaultBackoffManager(c.PollInterval, c.RetryInterval)
    69  	}
    70  }
    71  
    72  func (c SubscriberConfig) validate() error {
    73  	if c.PollInterval <= 0 {
    74  		return errors.New("poll interval must be a positive duration")
    75  	}
    76  	if c.ResendInterval <= 0 {
    77  		return errors.New("resend interval must be a positive duration")
    78  	}
    79  	if c.RetryInterval <= 0 {
    80  		return errors.New("resend interval must be a positive duration")
    81  	}
    82  	if c.SchemaAdapter == nil {
    83  		return errors.New("schema adapter is nil")
    84  	}
    85  	if c.OffsetsAdapter == nil {
    86  		return errors.New("offsets adapter is nil")
    87  	}
    88  
    89  	return nil
    90  }
    91  
    92  // Subscriber makes SELECT queries on the chosen table with the interval defined in the config.
    93  // The rows are unmarshaled into Watermill messages.
    94  type Subscriber struct {
    95  	consumerIdBytes  []byte
    96  	consumerIdString string
    97  
    98  	db     Beginner
    99  	config SubscriberConfig
   100  
   101  	subscribeWg *sync.WaitGroup
   102  	closing     chan struct{}
   103  	closed      uint32
   104  
   105  	logger watermill.LoggerAdapter
   106  }
   107  
   108  func NewSubscriber(db Beginner, config SubscriberConfig, logger watermill.LoggerAdapter) (*Subscriber, error) {
   109  	if db == nil {
   110  		return nil, errors.New("db is nil")
   111  	}
   112  	config.setDefaults()
   113  	err := config.validate()
   114  	if err != nil {
   115  		return nil, errors.Wrap(err, "invalid config")
   116  	}
   117  
   118  	if logger == nil {
   119  		logger = watermill.NopLogger{}
   120  	}
   121  
   122  	idBytes, idStr, err := newSubscriberID()
   123  	if err != nil {
   124  		return &Subscriber{}, errors.Wrap(err, "cannot generate subscriber id")
   125  	}
   126  	logger = logger.With(watermill.LogFields{"subscriber_id": idStr})
   127  
   128  	sub := &Subscriber{
   129  		consumerIdBytes:  idBytes,
   130  		consumerIdString: idStr,
   131  
   132  		db:     db,
   133  		config: config,
   134  
   135  		subscribeWg: &sync.WaitGroup{},
   136  		closing:     make(chan struct{}),
   137  
   138  		logger: logger,
   139  	}
   140  
   141  	return sub, nil
   142  }
   143  
   144  func newSubscriberID() ([]byte, string, error) {
   145  	id := utils.ULID()
   146  	idBytes, err := ulid.MustParseStrict(id).MarshalBinary()
   147  	if err != nil {
   148  		return nil, "", errors.Wrap(err, "cannot marshal subscriber id")
   149  	}
   150  
   151  	return idBytes, id, nil
   152  }
   153  
   154  func (s *Subscriber) Subscribe(ctx context.Context, topic string) (o <-chan *message.Message, err error) {
   155  	if atomic.LoadUint32(&s.closed) == 1 {
   156  		return nil, ErrSubscriberClosed
   157  	}
   158  
   159  	if err = validateTopicName(topic); err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	if s.config.InitializeSchema {
   164  		if err := s.SubscribeInitialize(topic); err != nil {
   165  			return nil, err
   166  		}
   167  	}
   168  
   169  	// the information about closing the subscriber is propagated through ctx
   170  	ctx, cancel := context.WithCancel(ctx)
   171  	out := make(chan *message.Message)
   172  
   173  	s.subscribeWg.Add(1)
   174  	go func() {
   175  		s.consume(ctx, topic, out)
   176  		close(out)
   177  		cancel()
   178  	}()
   179  
   180  	return out, nil
   181  }
   182  
   183  func (s *Subscriber) consume(ctx context.Context, topic string, out chan *message.Message) {
   184  	defer s.subscribeWg.Done()
   185  
   186  	logger := s.logger.With(watermill.LogFields{
   187  		"topic":          topic,
   188  		"consumer_group": s.config.ConsumerGroup,
   189  	})
   190  
   191  	var sleepTime time.Duration = 0
   192  	for {
   193  		select {
   194  		case <-s.closing:
   195  			logger.Info("Discarding queued message, subscriber closing", nil)
   196  			return
   197  
   198  		case <-ctx.Done():
   199  			logger.Info("Stopping consume, context canceled", nil)
   200  			return
   201  
   202  		case <-time.After(sleepTime): // Wait if needed
   203  		}
   204  
   205  		noMsg, err := s.query(ctx, topic, out, logger)
   206  		backoff := s.config.BackoffManager.HandleError(logger, noMsg, err)
   207  		if backoff != 0 {
   208  			if err != nil {
   209  				logger = logger.With(watermill.LogFields{"err": err.Error()})
   210  			}
   211  			logger.Trace("Backing off querying", watermill.LogFields{
   212  				"wait_time": backoff,
   213  			})
   214  		}
   215  		sleepTime = backoff
   216  	}
   217  }
   218  
   219  func (s *Subscriber) query(
   220  	ctx context.Context,
   221  	topic string,
   222  	out chan *message.Message,
   223  	logger watermill.LoggerAdapter,
   224  ) (noMsg bool, err error) {
   225  	txOptions := &sql.TxOptions{
   226  		Isolation: s.config.SchemaAdapter.SubscribeIsolationLevel(),
   227  	}
   228  	tx, err := s.db.BeginTx(ctx, txOptions)
   229  	if err != nil {
   230  		return false, errors.Wrap(err, "could not begin tx for querying")
   231  	}
   232  
   233  	defer func() {
   234  		if err != nil {
   235  			rollbackErr := tx.Rollback()
   236  			if rollbackErr != nil && rollbackErr != sql.ErrTxDone {
   237  				logger.Error("could not rollback tx for querying message", rollbackErr, watermill.LogFields{
   238  					"query_err": err,
   239  				})
   240  			}
   241  		} else {
   242  			commitErr := tx.Commit()
   243  			if commitErr != nil && commitErr != sql.ErrTxDone {
   244  				logger.Error("could not commit tx for querying message", commitErr, nil)
   245  			}
   246  		}
   247  	}()
   248  
   249  	selectQuery, selectQueryArgs := s.config.SchemaAdapter.SelectQuery(
   250  		topic,
   251  		s.config.ConsumerGroup,
   252  		s.config.OffsetsAdapter,
   253  	)
   254  	logger.Trace("[Common] watermill querying message", watermill.LogFields{
   255  		"query":      selectQuery,
   256  		"query_args": sqlArgsToLog(selectQueryArgs),
   257  	})
   258  	rows, err := tx.QueryContext(ctx, selectQuery, selectQueryArgs...)
   259  	if err != nil {
   260  		return false, errors.Wrap(err, "could not query message")
   261  	}
   262  
   263  	defer func() {
   264  		if rowsCloseErr := rows.Close(); rowsCloseErr != nil {
   265  			err = multierr.Append(err, errors.Wrap(err, "could not close rows"))
   266  		}
   267  	}()
   268  
   269  	var lastOffset int64
   270  	var lastRow Row
   271  
   272  	messageRows := make([]Row, 0)
   273  
   274  	for rows.Next() {
   275  		row, err := s.config.SchemaAdapter.UnmarshalMessage(rows)
   276  		if errors.Cause(err) == sql.ErrNoRows {
   277  			return true, nil
   278  		} else if err != nil {
   279  			return false, errors.Wrap(err, "could not unmarshal message from query")
   280  		}
   281  
   282  		messageRows = append(messageRows, row)
   283  	}
   284  
   285  	for _, row := range messageRows {
   286  		consumedQuery, consumedArgs := s.config.OffsetsAdapter.ConsumedMessageQuery(
   287  			topic,
   288  			row,
   289  			s.config.ConsumerGroup,
   290  			s.consumerIdBytes,
   291  		)
   292  		if consumedQuery != "" {
   293  			logger.Trace("[Common] watermill executing query to confirm message consumed", watermill.LogFields{
   294  				"query":      consumedQuery,
   295  				"query_args": sqlArgsToLog(consumedArgs),
   296  			})
   297  
   298  			_, err := tx.ExecContext(ctx, consumedQuery, consumedArgs...)
   299  
   300  			if err != nil {
   301  				return false, errors.Wrap(err, "cannot send consumed query")
   302  			}
   303  
   304  			logger.Trace("[Common] watermill executed query to confirm message consumed", nil)
   305  		}
   306  		logger = logger.With(watermill.LogFields{
   307  			"message_uuid":   row.Msg.UUID,
   308  			"message_raw_id": row.Offset,
   309  		})
   310  		logger.Trace("[Common] watermill received message", nil)
   311  
   312  		msgCtx := setTxToContext(ctx, tx)
   313  		msgCtx = context.WithValue(msgCtx, watermill.ContextKeyMessageUUID, row.Msg.UUID)
   314  		msgCtx = context.WithValue(msgCtx, watermill.ContextKeyRawMessageID, cast.ToString(row.Offset))
   315  		row.Msg.Metadata[watermill.ContextKeyMessageUUID] = string(row.UUID)
   316  		row.Msg.Metadata[watermill.ContextKeyRawMessageID] = cast.ToString(row.Offset)
   317  		acked := s.sendMessage(msgCtx, row.Msg, out, logger)
   318  		if !acked {
   319  			break
   320  		}
   321  		_ = s.deleteMessage(ctx, tx, topic, row.Offset, logger)
   322  		lastOffset = row.Offset
   323  		lastRow = row
   324  	}
   325  
   326  	if lastOffset == 0 {
   327  		return true, nil
   328  	}
   329  
   330  	ackQuery, ackArgs := s.config.OffsetsAdapter.AckMessageQuery(
   331  		topic,
   332  		lastRow,
   333  		s.config.ConsumerGroup,
   334  	)
   335  
   336  	logger.Trace("[Common] watermill executing ack message query", watermill.LogFields{
   337  		"query":      ackQuery,
   338  		"query_args": sqlArgsToLog(ackArgs),
   339  	})
   340  
   341  	result, err := tx.ExecContext(ctx, ackQuery, ackArgs...)
   342  	if err != nil {
   343  		return false, errors.Wrap(err, "could not get args for acking the message")
   344  	}
   345  
   346  	rowsAffected, _ := result.RowsAffected()
   347  
   348  	logger.Trace("[Common] watermill executed ack message query", watermill.LogFields{
   349  		"rows_affected": rowsAffected,
   350  	})
   351  
   352  	return false, nil
   353  }
   354  
   355  // sendMessages sends messages on the output channel.
   356  func (s *Subscriber) sendMessage(
   357  	ctx context.Context,
   358  	msg *message.Message,
   359  	out chan *message.Message,
   360  	logger watermill.LoggerAdapter,
   361  ) (acked bool) {
   362  	msgCtx, cancel := context.WithCancel(ctx)
   363  	msg.SetContext(msgCtx)
   364  	defer cancel()
   365  
   366  ResendLoop:
   367  	for {
   368  
   369  		select {
   370  		case out <- msg:
   371  
   372  		case <-s.closing:
   373  			logger.Info("[Common] watermill discarding queued message, subscriber closing", nil)
   374  			return false
   375  
   376  		case <-ctx.Done():
   377  			logger.Info("[Common] watermill discarding queued message, context canceled", nil)
   378  			return false
   379  		}
   380  
   381  		select {
   382  		case <-msg.Acked():
   383  			logger.Debug("[Common] watermill message acked by subscriber", nil)
   384  			return true
   385  
   386  		case <-msg.Nacked():
   387  			// message nacked, try resending
   388  			logger.Debug("[Common] watermill message nacked, resending", nil)
   389  			msg = msg.Copy()
   390  			msg.SetContext(msgCtx)
   391  
   392  			if s.config.ResendInterval != 0 {
   393  				time.Sleep(s.config.ResendInterval)
   394  			}
   395  
   396  			continue ResendLoop
   397  
   398  		case <-s.closing:
   399  			logger.Info("[Common] watermill discarding queued message, subscriber closing", nil)
   400  			return false
   401  
   402  		case <-ctx.Done():
   403  			logger.Info("[Common] watermill discarding queued message, context canceled", nil)
   404  			return false
   405  		}
   406  	}
   407  }
   408  
   409  func (s *Subscriber) deleteMessage(ctx context.Context, tx *sql.Tx,
   410  	topic string, offset int64, logger watermill.LoggerAdapter) (err error) {
   411  	if !s.config.DisablePersistent {
   412  		return
   413  	}
   414  
   415  	deleteQuery, deleteArgs := s.config.SchemaAdapter.DeleteQuery(topic, offset)
   416  	logger.Trace("[Common] watermill executing delete message query", watermill.LogFields{
   417  		"query":      deleteQuery,
   418  		"query_args": sqlArgsToLog(deleteArgs),
   419  	})
   420  	result, err := tx.ExecContext(ctx, deleteQuery, deleteArgs...)
   421  	if err != nil {
   422  		return errors.Wrap(err, "execute delete message query failed")
   423  	}
   424  
   425  	rowsAffected, _ := result.RowsAffected()
   426  
   427  	logger.Trace("[Common] watermill executed delete message query", watermill.LogFields{
   428  		"rows_affected": rowsAffected,
   429  	})
   430  
   431  	return
   432  }
   433  
   434  func (s *Subscriber) Close() error {
   435  	if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
   436  		return nil
   437  	}
   438  
   439  	close(s.closing)
   440  	s.subscribeWg.Wait()
   441  
   442  	return nil
   443  }
   444  
   445  func (s *Subscriber) SubscribeInitialize(topic string) error {
   446  	return initializeSchema(
   447  		context.Background(),
   448  		topic,
   449  		s.logger,
   450  		s.db,
   451  		s.config.SchemaAdapter,
   452  		s.config.OffsetsAdapter,
   453  	)
   454  }