github.com/argoproj/argo-events@v1.9.1/eventbus/jetstream/sensor/trigger_conn.go (about)

     1  package sensor
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/Knetic/govaluate"
    13  	cloudevents "github.com/cloudevents/sdk-go/v2"
    14  	nats "github.com/nats-io/nats.go"
    15  
    16  	eventbuscommon "github.com/argoproj/argo-events/eventbus/common"
    17  	jetstreambase "github.com/argoproj/argo-events/eventbus/jetstream/base"
    18  )
    19  
// JetstreamTriggerConn is a Jetstream connection scoped to a single
// sensor/trigger pair. It subscribes to the subjects derived from the trigger's
// dependencies, records which dependencies have been satisfied (in a NATS
// Key/Value store named after the sensor), and fires the trigger's action once
// the dependency expression evaluates to true.
type JetstreamTriggerConn struct {
	*jetstreambase.JetstreamConnection
	sensorName           string
	triggerName          string
	keyValueStore        nats.KeyValue // per-sensor K/V store holding saved dependency state
	dependencyExpression string        // boolean expression over dependency names
	requiresANDLogic     bool          // true if the expression contains "&": deps must be accumulated across messages
	evaluableExpression  *govaluate.EvaluableExpression
	deps                 []eventbuscommon.Dependency
	sourceDepMap         map[string][]string // maps EventSource and EventName to dependency name
	recentMsgsByID       map[string]*msg     // prevent re-processing the same message as before (map of msg ID to time)
	recentMsgsByTime     []*msg              // same entries ordered oldest-first, so purgeOldMsgs can expire them
}
    33  
// msg is one entry in the recent-message cache used for de-duplication.
type msg struct {
	time  int64  // receipt time (UnixNano), used to expire old entries
	msgID string // CloudEvent ID of the message
}
    38  
    39  func NewJetstreamTriggerConn(conn *jetstreambase.JetstreamConnection,
    40  	sensorName string,
    41  	triggerName string,
    42  	dependencyExpression string,
    43  	deps []eventbuscommon.Dependency) (*JetstreamTriggerConn, error) {
    44  	var err error
    45  
    46  	sourceDepMap := make(map[string][]string)
    47  	for _, d := range deps {
    48  		key := d.EventSourceName + "__" + d.EventName
    49  		_, found := sourceDepMap[key]
    50  		if !found {
    51  			sourceDepMap[key] = make([]string, 0)
    52  		}
    53  		sourceDepMap[key] = append(sourceDepMap[key], d.Name)
    54  	}
    55  
    56  	connection := &JetstreamTriggerConn{
    57  		JetstreamConnection:  conn,
    58  		sensorName:           sensorName,
    59  		triggerName:          triggerName,
    60  		dependencyExpression: dependencyExpression,
    61  		requiresANDLogic:     strings.Contains(dependencyExpression, "&"),
    62  		deps:                 deps,
    63  		sourceDepMap:         sourceDepMap,
    64  		recentMsgsByID:       make(map[string]*msg),
    65  		recentMsgsByTime:     make([]*msg, 0)}
    66  	connection.Logger = connection.Logger.With("triggerName", connection.triggerName, "sensorName", connection.sensorName)
    67  
    68  	connection.evaluableExpression, err = govaluate.NewEvaluableExpression(strings.ReplaceAll(dependencyExpression, "-", "\\-"))
    69  	if err != nil {
    70  		errStr := fmt.Sprintf("failed to evaluate expression %s: %v", dependencyExpression, err)
    71  		connection.Logger.Error(errStr)
    72  		return nil, fmt.Errorf(errStr)
    73  	}
    74  
    75  	connection.keyValueStore, err = conn.JSContext.KeyValue(sensorName)
    76  	if err != nil {
    77  		errStr := fmt.Sprintf("failed to get K/V store for sensor %s: %v", sensorName, err)
    78  		connection.Logger.Error(errStr)
    79  		return nil, fmt.Errorf(errStr)
    80  	}
    81  
    82  	connection.Logger.Infof("Successfully located K/V store for sensor %s", sensorName)
    83  	return connection, nil
    84  }
    85  
    86  func (conn *JetstreamTriggerConn) IsClosed() bool {
    87  	return conn == nil || conn.JetstreamConnection.IsClosed()
    88  }
    89  
    90  func (conn *JetstreamTriggerConn) Close() error {
    91  	if conn == nil {
    92  		return fmt.Errorf("can't close Jetstream trigger connection, JetstreamTriggerConn is nil")
    93  	}
    94  	return conn.JetstreamConnection.Close()
    95  }
    96  
    97  func (conn *JetstreamTriggerConn) String() string {
    98  	if conn == nil {
    99  		return ""
   100  	}
   101  	return fmt.Sprintf("JetstreamTriggerConn{Sensor:%s,Trigger:%s}", conn.sensorName, conn.triggerName)
   102  }
   103  
   104  func (conn *JetstreamTriggerConn) Subscribe(ctx context.Context,
   105  	closeCh <-chan struct{},
   106  	resetConditionsCh <-chan struct{},
   107  	lastResetTime time.Time,
   108  	transform func(depName string, event cloudevents.Event) (*cloudevents.Event, error),
   109  	filter func(string, cloudevents.Event) bool,
   110  	action func(map[string]cloudevents.Event),
   111  	defaultSubject *string) error {
   112  	if conn == nil {
   113  		return fmt.Errorf("Subscribe() failed; JetstreamTriggerConn is nil")
   114  	}
   115  
   116  	var err error
   117  	log := conn.Logger
   118  	// derive subjects that we'll subscribe with using the dependencies passed in
   119  	subjects := make(map[string]eventbuscommon.Dependency)
   120  	for _, dep := range conn.deps {
   121  		subjects[fmt.Sprintf("default.%s.%s", dep.EventSourceName, dep.EventName)] = dep
   122  	}
   123  
   124  	if !lastResetTime.IsZero() {
   125  		err = conn.clearAllDependencies(&lastResetTime)
   126  		if err != nil {
   127  			errStr := fmt.Sprintf("failed to clear all dependencies as a result of condition reset time; err=%v", err)
   128  			log.Error(errStr)
   129  		}
   130  	}
   131  
   132  	ch := make(chan *nats.Msg) // channel with no buffer (I believe this should be okay - we will block writing messages to this channel while a message is still being processed but volume of messages shouldn't be so high as to cause a problem)
   133  	wg := sync.WaitGroup{}
   134  	processMsgsCloseCh := make(chan struct{})
   135  	pullSubscribeCloseCh := make(map[string]chan struct{}, len(subjects))
   136  
   137  	subscriptions := make([]*nats.Subscription, len(subjects))
   138  	subscriptionIndex := 0
   139  
   140  	// start the goroutines that will listen to the individual subscriptions
   141  	for subject, dependency := range subjects {
   142  		// set durable name separately for each subscription
   143  		durableName := getDurableName(conn.sensorName, conn.triggerName, dependency.Name)
   144  
   145  		conn.Logger.Debugf("durable name for sensor='%s', trigger='%s', dep='%s': '%s'", conn.sensorName, conn.triggerName, dependency.Name, durableName)
   146  		log.Infof("Subscribing to subject %s with durable name %s", subject, durableName)
   147  		subscriptions[subscriptionIndex], err = conn.JSContext.PullSubscribe(subject, durableName, nats.AckExplicit(), nats.DeliverNew())
   148  		if err != nil {
   149  			errorStr := fmt.Sprintf("Failed to subscribe to subject %s using group %s: %v", subject, durableName, err)
   150  			log.Error(errorStr)
   151  			return fmt.Errorf(errorStr)
   152  		} else {
   153  			log.Debugf("successfully subscribed to subject %s with durable name %s", subject, durableName)
   154  		}
   155  
   156  		pullSubscribeCloseCh[subject] = make(chan struct{})
   157  		go conn.pullSubscribe(subscriptions[subscriptionIndex], ch, pullSubscribeCloseCh[subject], &wg)
   158  		wg.Add(1)
   159  		log.Debug("adding 1 to WaitGroup (pullSubscribe)")
   160  
   161  		subscriptionIndex++
   162  	}
   163  
   164  	// create a single goroutine which which handle receiving messages to ensure that all of the processing is occurring on that
   165  	// one goroutine and we don't need to worry about race conditions
   166  	go conn.processMsgs(ch, processMsgsCloseCh, resetConditionsCh, transform, filter, action, &wg)
   167  	wg.Add(1)
   168  	log.Debug("adding 1 to WaitGroup (processMsgs)")
   169  
   170  	for {
   171  		select {
   172  		case <-ctx.Done():
   173  			log.Info("exiting, closing connection...")
   174  			conn.shutdownSubscriptions(processMsgsCloseCh, pullSubscribeCloseCh, &wg)
   175  			return nil
   176  		case <-closeCh:
   177  			log.Info("closing connection...")
   178  			conn.shutdownSubscriptions(processMsgsCloseCh, pullSubscribeCloseCh, &wg)
   179  			return nil
   180  		}
   181  	}
   182  }
   183  
// shutdownSubscriptions signals the processMsgs goroutine and every
// pullSubscribe goroutine to stop, waits for all of them to call wg.Done(),
// and then closes the underlying NATS connection. The unbuffered sends block
// until each goroutine reaches its next read of the close channel.
func (conn *JetstreamTriggerConn) shutdownSubscriptions(processMsgsCloseCh chan struct{}, pullSubscribeCloseCh map[string]chan struct{}, wg *sync.WaitGroup) {
	// stop the single message-processing goroutine first
	processMsgsCloseCh <- struct{}{}
	// then stop each per-subject fetch loop
	for _, ch := range pullSubscribeCloseCh {
		ch <- struct{}{}
	}
	wg.Wait()
	conn.NATSConn.Close()
	conn.Logger.Debug("closed NATSConn")
}
   193  
// pullSubscribe is the per-subject fetch loop: it repeatedly fetches one
// message at a time (with a 1s wait) from the given pull subscription and
// pushes it onto msgChannel for the single processMsgs goroutine to consume.
// It exits (and calls wg.Done()) when closeCh is signaled. Fetch errors other
// than timeouts are logged at most once per 10 seconds to avoid log spew.
func (conn *JetstreamTriggerConn) pullSubscribe(
	subscription *nats.Subscription,
	msgChannel chan<- *nats.Msg,
	closeCh <-chan struct{},
	wg *sync.WaitGroup) {
	// track the last non-timeout error and when it was logged, for rate limiting
	var previousErr error
	var previousErrTime time.Time

	for {
		// call Fetch with timeout
		msgs, fetchErr := subscription.Fetch(1, nats.MaxWait(time.Second*1))
		if fetchErr != nil && !errors.Is(fetchErr, nats.ErrTimeout) {
			if previousErr != fetchErr || time.Since(previousErrTime) > 10*time.Second {
				// avoid log spew - only log error every 10 seconds
				conn.Logger.Errorf("failed to fetch messages for subscription %+v, %v, previousErr=%v, previousErrTime=%v", subscription, fetchErr, previousErr, previousErrTime)
			}
			previousErr = fetchErr
			previousErrTime = time.Now()
		}

		// read from close channel but don't block if it's empty
		// (checked before pushing messages so shutdown isn't delayed by a blocked send)
		select {
		case <-closeCh:
			wg.Done()
			conn.Logger.Debug("wg.Done(): pullSubscribe")
			conn.Logger.Infof("exiting pullSubscribe() for subscription %+v", subscription)
			return
		default:
		}
		// on a real fetch error, skip the (empty) msgs and retry
		if fetchErr != nil && !errors.Is(fetchErr, nats.ErrTimeout) {
			continue
		}

		// then push the msgs to the channel which will consume them
		for _, msg := range msgs {
			msgChannel <- msg
		}
	}
}
   233  
   234  func (conn *JetstreamTriggerConn) processMsgs(
   235  	receiveChannel <-chan *nats.Msg,
   236  	closeCh <-chan struct{},
   237  	resetConditionsCh <-chan struct{},
   238  	transform func(depName string, event cloudevents.Event) (*cloudevents.Event, error),
   239  	filter func(string, cloudevents.Event) bool,
   240  	action func(map[string]cloudevents.Event),
   241  	wg *sync.WaitGroup) {
   242  	defer func() {
   243  		wg.Done()
   244  		conn.Logger.Debug("wg.Done(): processMsgs")
   245  	}()
   246  
   247  	for {
   248  		select {
   249  		case msg := <-receiveChannel:
   250  			conn.processMsg(msg, transform, filter, action)
   251  		case <-resetConditionsCh:
   252  			conn.Logger.Info("reset conditions")
   253  			_ = conn.clearAllDependencies(nil)
   254  		case <-closeCh:
   255  			conn.Logger.Info("shutting down processMsgs routine")
   256  			return
   257  		}
   258  	}
   259  }
   260  
   261  func (conn *JetstreamTriggerConn) processMsg(
   262  	m *nats.Msg,
   263  	transform func(depName string, event cloudevents.Event) (*cloudevents.Event, error),
   264  	filter func(string, cloudevents.Event) bool,
   265  	action func(map[string]cloudevents.Event)) {
   266  	meta, err := m.Metadata()
   267  	if err != nil {
   268  		conn.Logger.Errorf("can't get Metadata() for message %+v??", m)
   269  	}
   270  
   271  	done := make(chan bool)
   272  	go func() {
   273  		ticker := time.NewTicker(500 * time.Millisecond)
   274  		defer ticker.Stop()
   275  		for {
   276  			select {
   277  			case <-done:
   278  				err = m.AckSync()
   279  				if err != nil {
   280  					errStr := fmt.Sprintf("Error performing AckSync() on message: %v", err)
   281  					conn.Logger.Error(errStr)
   282  				}
   283  				conn.Logger.Debugf("acked message of Stream seq: %s:%d, Consumer seq: %s:%d", meta.Stream, meta.Sequence.Stream, meta.Consumer, meta.Sequence.Consumer)
   284  				return
   285  			case <-ticker.C:
   286  				err = m.InProgress()
   287  				if err != nil {
   288  					errStr := fmt.Sprintf("Error performing InProgess() on message: %v", err)
   289  					conn.Logger.Error(errStr)
   290  				}
   291  				conn.Logger.Debugf("InProgess message of Stream seq: %s:%d, Consumer seq: %s:%d", meta.Stream, meta.Sequence.Stream, meta.Consumer, meta.Sequence.Consumer)
   292  			}
   293  		}
   294  	}()
   295  
   296  	defer func() {
   297  		done <- true
   298  	}()
   299  
   300  	log := conn.Logger
   301  
   302  	var event *cloudevents.Event
   303  	if err := json.Unmarshal(m.Data, &event); err != nil {
   304  		log.Errorf("Failed to convert to a cloudevent, discarding it... err: %v", err)
   305  		return
   306  	}
   307  
   308  	// De-duplication
   309  	// In the off chance that we receive the same message twice, don't re-process
   310  	_, alreadyReceived := conn.recentMsgsByID[event.ID()]
   311  	if alreadyReceived {
   312  		log.Debugf("already received message of ID %d, ignore this", event.ID())
   313  		return
   314  	}
   315  
   316  	// get all dependencies for this Trigger that match
   317  	depNames, err := conn.getDependencyNames(event.Source(), event.Subject())
   318  	if err != nil || len(depNames) == 0 {
   319  		log.Errorf("Failed to get the dependency names, discarding it... err: %v", err)
   320  		return
   321  	}
   322  
   323  	log.Debugf("New incoming Event Source Message, dependency names=%s, Stream seq: %s:%d, Consumer seq: %s:%d",
   324  		depNames, meta.Stream, meta.Sequence.Stream, meta.Consumer, meta.Sequence.Consumer)
   325  
   326  	for _, depName := range depNames {
   327  		conn.processDependency(m, event, depName, transform, filter, action)
   328  	}
   329  
   330  	// Save message for de-duplication purposes
   331  	conn.storeMessageID(event.ID())
   332  	conn.purgeOldMsgs()
   333  }
   334  
// processDependency applies the transform and filter for a single dependency
// and then either fires the trigger immediately (pure OR expression) or
// evaluates the full dependency expression against the previously-saved
// dependencies in the K/V store (AND logic). When the expression becomes true
// it fires the action with all accumulated events and clears the saved state;
// otherwise it saves this event's metadata for future evaluations.
func (conn *JetstreamTriggerConn) processDependency(
	m *nats.Msg,
	event *cloudevents.Event,
	depName string,
	transform func(depName string, event cloudevents.Event) (*cloudevents.Event, error),
	filter func(string, cloudevents.Event) bool,
	action func(map[string]cloudevents.Event)) {
	log := conn.Logger
	event, err := transform(depName, *event)
	if err != nil {
		log.Errorw("failed to apply event transformation, ", err)
		return
	}

	if !filter(depName, *event) {
		// message not interested
		log.Infof("not interested in dependency %s (didn't pass filter)", depName)
		return
	}

	if !conn.requiresANDLogic {
		// this is the simple case: we can just perform the trigger
		messages := make(map[string]cloudevents.Event)
		messages[depName] = *event
		log.Infof("Triggering actions after receiving dependency %s", depName)

		action(messages)
	} else {
		// check Dependency expression (need to retrieve previous dependencies from Key/Value store)

		prevMsgs, err := conn.getSavedDependencies()
		if err != nil {
			// error already logged downstream; skip this message
			return
		}

		// populate 'parameters' map to indicate which dependencies have been received and which haven't
		parameters := make(map[string]interface{}, len(conn.deps))
		for _, dep := range conn.deps {
			parameters[dep.Name] = false
		}
		for prevDep := range prevMsgs {
			parameters[prevDep] = true
		}
		// the current event also counts as received
		parameters[depName] = true
		log.Infof("Current state of dependencies: %v", parameters)

		// evaluate the filter expression
		result, err := conn.evaluableExpression.Evaluate(parameters)
		if err != nil {
			errStr := fmt.Sprintf("failed to evaluate dependency expression: %v", err)
			log.Error(errStr)
			return
		}

		// if expression is true, trigger and clear the K/V store
		// else save the new message in the K/V store
		if result == true {
			log.Debugf("dependency expression successfully evaluated to true: '%s'", conn.dependencyExpression)

			// combine the saved events with the current one and fire the trigger
			messages := make(map[string]cloudevents.Event, len(prevMsgs)+1)
			for prevDep, msgInfo := range prevMsgs {
				messages[prevDep] = *msgInfo.Event
			}
			messages[depName] = *event
			log.Infof("Triggering actions after receiving dependency %s", depName)

			action(messages)

			// reset saved state so the next trigger starts fresh
			_ = conn.clearAllDependencies(nil)
		} else {
			log.Debugf("dependency expression false: %s", conn.dependencyExpression)
			msgMetadata, err := m.Metadata()
			if err != nil {
				errStr := fmt.Sprintf("message %+v is not a jetstream message???: %v", m, err)
				log.Error(errStr)
				return
			}
			// remember this dependency (with sequence info and the event itself)
			// for future expression evaluations
			_ = conn.saveDependency(depName,
				MsgInfo{
					StreamSeq:   msgMetadata.Sequence.Stream,
					ConsumerSeq: msgMetadata.Sequence.Consumer,
					Timestamp:   msgMetadata.Timestamp,
					Event:       event})
		}
	}
}
   421  
   422  func (conn *JetstreamTriggerConn) getSavedDependencies() (map[string]MsgInfo, error) {
   423  	// dependencies are formatted "<Sensor>/<Trigger>/<Dependency>""
   424  	prevMsgs := make(map[string]MsgInfo)
   425  
   426  	// for each dependency that's in our dependency expression, look for it:
   427  	for _, dep := range conn.deps {
   428  		msgInfo, found, err := conn.getSavedDependency(dep.Name)
   429  		if err != nil {
   430  			return prevMsgs, err
   431  		}
   432  		if found {
   433  			prevMsgs[dep.Name] = msgInfo
   434  		}
   435  	}
   436  
   437  	return prevMsgs, nil
   438  }
   439  
   440  func (conn *JetstreamTriggerConn) getSavedDependency(depName string) (msg MsgInfo, found bool, err error) {
   441  	key := getDependencyKey(conn.triggerName, depName)
   442  	entry, err := conn.keyValueStore.Get(key)
   443  	if err == nil {
   444  		if entry != nil {
   445  			var msgInfo MsgInfo
   446  			err := json.Unmarshal(entry.Value(), &msgInfo)
   447  			if err != nil {
   448  				errStr := fmt.Sprintf("error unmarshalling value %s for key %s: %v", string(entry.Value()), key, err)
   449  				conn.Logger.Error(errStr)
   450  				return MsgInfo{}, true, fmt.Errorf(errStr)
   451  			}
   452  			return msgInfo, true, nil
   453  		}
   454  	} else if err != nats.ErrKeyNotFound {
   455  		return MsgInfo{}, false, err
   456  	}
   457  
   458  	return MsgInfo{}, false, nil
   459  }
   460  
   461  func (conn *JetstreamTriggerConn) saveDependency(depName string, msgInfo MsgInfo) error {
   462  	log := conn.Logger
   463  	jsonEncodedMsg, err := json.Marshal(msgInfo)
   464  	if err != nil {
   465  		errorStr := fmt.Sprintf("failed to convert msgInfo struct into JSON: %+v", msgInfo)
   466  		log.Error(errorStr)
   467  		return fmt.Errorf(errorStr)
   468  	}
   469  	key := getDependencyKey(conn.triggerName, depName)
   470  
   471  	_, err = conn.keyValueStore.Put(key, jsonEncodedMsg)
   472  	if err != nil {
   473  		errorStr := fmt.Sprintf("failed to store dependency under key %s, value:%s: %+v", key, jsonEncodedMsg, err)
   474  		log.Error(errorStr)
   475  		return fmt.Errorf(errorStr)
   476  	}
   477  
   478  	return nil
   479  }
   480  
   481  func (conn *JetstreamTriggerConn) clearAllDependencies(beforeTimeOpt *time.Time) error {
   482  	for _, dep := range conn.deps {
   483  		if beforeTimeOpt != nil && !beforeTimeOpt.IsZero() {
   484  			err := conn.clearDependencyIfExistsBeforeTime(dep.Name, *beforeTimeOpt)
   485  			if err != nil {
   486  				return err
   487  			}
   488  		} else {
   489  			err := conn.clearDependencyIfExists(dep.Name)
   490  			if err != nil {
   491  				return err
   492  			}
   493  		}
   494  	}
   495  	return nil
   496  }
   497  
   498  func (conn *JetstreamTriggerConn) clearDependencyIfExistsBeforeTime(depName string, beforeTime time.Time) error {
   499  	key := getDependencyKey(conn.triggerName, depName)
   500  
   501  	// first get the value (if it exists) to determine if it occurred before or after the time in question
   502  	msgInfo, found, err := conn.getSavedDependency(depName)
   503  	if err != nil {
   504  		return err
   505  	}
   506  	if found {
   507  		// determine if the dependency is from before the time in question
   508  		if msgInfo.Timestamp.Before(beforeTime) {
   509  			conn.Logger.Debugf("clearing key %s from the K/V store since its message time %+v occurred before %+v; MsgInfo:%+v",
   510  				key, msgInfo.Timestamp.Local(), beforeTime.Local(), msgInfo)
   511  			err := conn.keyValueStore.Delete(key)
   512  			if err != nil && err != nats.ErrKeyNotFound {
   513  				conn.Logger.Error(err)
   514  				return err
   515  			}
   516  		}
   517  	}
   518  
   519  	return nil
   520  }
   521  
   522  func (conn *JetstreamTriggerConn) clearDependencyIfExists(depName string) error {
   523  	key := getDependencyKey(conn.triggerName, depName)
   524  	conn.Logger.Debugf("clearing key %s from the K/V store", key)
   525  	err := conn.keyValueStore.Delete(key)
   526  	if err != nil && err != nats.ErrKeyNotFound {
   527  		conn.Logger.Error(err)
   528  		return err
   529  	}
   530  	return nil
   531  }
   532  
   533  func (conn *JetstreamTriggerConn) getDependencyNames(eventSourceName, eventName string) ([]string, error) {
   534  	deps, found := conn.sourceDepMap[eventSourceName+"__"+eventName]
   535  	if !found {
   536  		errStr := fmt.Sprintf("incoming event source and event not associated with any dependencies, event source=%s, event=%s",
   537  			eventSourceName, eventName)
   538  		conn.Logger.Error(errStr)
   539  		return nil, fmt.Errorf(errStr)
   540  	}
   541  
   542  	return deps, nil
   543  }
   544  
   545  // save the message in our recent messages list (for de-duplication purposes)
   546  func (conn *JetstreamTriggerConn) storeMessageID(id string) {
   547  	now := time.Now().UnixNano()
   548  	saveMsg := &msg{msgID: id, time: now}
   549  	conn.recentMsgsByID[id] = saveMsg
   550  	conn.recentMsgsByTime = append(conn.recentMsgsByTime, saveMsg)
   551  }
   552  
   553  func (conn *JetstreamTriggerConn) purgeOldMsgs() {
   554  	now := time.Now().UnixNano()
   555  
   556  	// evict any old messages from our message cache
   557  	for _, msg := range conn.recentMsgsByTime {
   558  		if now-msg.time > 60*1000*1000*1000 { // older than 1 minute
   559  			conn.Logger.Debugf("deleting message %v from cache", *msg)
   560  			delete(conn.recentMsgsByID, msg.msgID)
   561  			conn.recentMsgsByTime = conn.recentMsgsByTime[1:]
   562  		} else {
   563  			break // these are ordered by time so we can break when we hit one that's still valid
   564  		}
   565  	}
   566  }