github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/cmd/runner/rabbit.go (about)

     1  // Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
     2  
     3  package main
     4  
     5  import (
     6  	"context"
     7  	"net/url"
     8  	"os"
     9  	"regexp"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/leaf-ai/studio-go-runner/internal/runner"
    14  	"github.com/leaf-ai/studio-go-runner/internal/types"
    15  
    16  	"github.com/go-stack/stack"
    17  	"github.com/jjeffery/kv" // MIT License
    18  
    19  	"github.com/prometheus/client_golang/prometheus"
    20  )
    21  
    22  var (
    23  	wrapperFailSeen = false
    24  )
    25  
    26  // This file contains the implementation of a RabbitMQ service for
    27  // retrieving and handling StudioML workloads within a self hosted
    28  // queue context
    29  
    30  func initRMQ() (rmq *runner.RabbitMQ) {
    31  	// NewRabbitMQ takes a URL that has no credentials or tokens attached as the
    32  	// first parameter and the user name password as the second parameter
    33  	creds := ""
    34  	qURL, errGo := url.Parse(os.ExpandEnv(*amqpURL))
    35  	if errGo != nil {
    36  		logger.Warn(kv.Wrap(errGo).With("url", *amqpURL).With("stack", stack.Trace().TrimRuntime()).Error())
    37  	}
    38  	if qURL.User != nil {
    39  		creds = qURL.User.String()
    40  	} else {
    41  		logger.Warn(kv.NewError("missing credentials in url").With("url", *amqpURL).With("stack", stack.Trace().TrimRuntime()).Error())
    42  	}
    43  	qURL.User = nil
    44  
    45  	w, err := getWrapper()
    46  	if err != nil {
    47  		if !wrapperFailSeen {
    48  			logger.Warn(err.Error(), "stack", stack.Trace().TrimRuntime())
    49  			wrapperFailSeen = true
    50  		}
    51  	}
    52  
    53  	rmqRef, err := runner.NewRabbitMQ(qURL.String(), creds, w)
    54  	if err != nil {
    55  		logger.Warn(err.Error(), "stack", stack.Trace().TrimRuntime())
    56  	}
    57  	return rmqRef
    58  }
    59  
    60  func initRMQStructs() (matcher *regexp.Regexp, mismatcher *regexp.Regexp) {
    61  
    62  	// The regular expression is validated in the main.go file
    63  	matcher, errGo := regexp.Compile(*queueMatch)
    64  	if errGo != nil {
    65  		if len(*queueMatch) != 0 {
    66  			logger.Warn(kv.Wrap(errGo).With("matcher", *queueMatch).With("stack", stack.Trace().TrimRuntime()).Error())
    67  		}
    68  		matcher = nil
    69  	}
    70  
    71  	// If the length of the mismatcher is 0 then we will get a nil and because this
    72  	// was checked in the main we can ignore that as this is optional
    73  
    74  	if len(strings.Trim(*queueMismatch, " \n\r\t")) == 0 {
    75  		mismatcher = nil
    76  	} else {
    77  		mismatcher, errGo = regexp.Compile(*queueMismatch)
    78  		if errGo != nil {
    79  			if len(*queueMismatch) != 0 {
    80  				logger.Warn(kv.Wrap(errGo).With("mismatcher", *queueMismatch).With("stack", stack.Trace().TrimRuntime()).Error())
    81  			}
    82  			mismatcher = nil
    83  		}
    84  	}
    85  	return matcher, mismatcher
    86  }
    87  
    88  // serviceRMQ runs for the lifetime of the daemon and uses the ctx to perform orderly shutdowns
    89  //
    90  func serviceRMQ(ctx context.Context, checkInterval time.Duration, connTimeout time.Duration) {
    91  
    92  	logger.Debug("starting serviceRMQ", stack.Trace().TrimRuntime())
    93  	defer logger.Debug("stopping serviceRMQ", stack.Trace().TrimRuntime())
    94  
    95  	if len(*amqpURL) == 0 {
    96  		logger.Info("rabbitMQ services disabled", stack.Trace().TrimRuntime())
    97  		return
    98  	}
    99  
   100  	matcher, mismatcher := initRMQStructs()
   101  	rmq := initRMQ()
   102  
   103  	// Tracks all known queues and their cancel functions so they can have any
   104  	// running jobs terminated should they disappear
   105  	live := &Projects{
   106  		queueType: "rabbitMQ",
   107  		projects:  map[string]context.CancelFunc{},
   108  	}
   109  
   110  	lifecycleC := make(chan runner.K8sStateUpdate, 1)
   111  	id, err := k8sStateUpdates().Add(lifecycleC)
   112  	if err != nil {
   113  		logger.Warn(err.With("stack", stack.Trace().TrimRuntime()).Error())
   114  	}
   115  
   116  	defer func() {
   117  		// Ignore failures to cleanup resources we will never reuse
   118  		func() {
   119  			defer func() {
   120  				_ = recover()
   121  			}()
   122  			k8sStateUpdates().Delete(id)
   123  		}()
   124  		close(lifecycleC)
   125  	}()
   126  
   127  	host, errGo := os.Hostname()
   128  	if errGo != nil {
   129  		logger.Warn(errGo.Error())
   130  	}
   131  
   132  	// first time through make sure the credentials are checked immediately
   133  	qCheck := time.Duration(time.Second)
   134  	currentCheck := qCheck
   135  	qTicker := time.NewTicker(currentCheck)
   136  	defer qTicker.Stop()
   137  
   138  	// Watch for when the server should not be getting new work
   139  	state := runner.K8sStateUpdate{
   140  		State: types.K8sRunning,
   141  	}
   142  	for {
   143  		// Dont wait an excessive amount of time after server checks fail before
   144  		// retrying
   145  		if qCheck > time.Duration(3*time.Minute) {
   146  			qCheck = time.Duration(3 * time.Minute)
   147  		}
   148  
   149  		// If the interval between queue checks changes reset the ticker
   150  
   151  		if qCheck != currentCheck {
   152  			currentCheck = qCheck
   153  			qTicker.Stop()
   154  			qTicker = time.NewTicker(currentCheck)
   155  		}
   156  
   157  		select {
   158  		case <-ctx.Done():
   159  			live.Lock()
   160  			defer live.Unlock()
   161  
   162  			// When shutting down stop all projects
   163  			for _, quiter := range live.projects {
   164  				if quiter != nil {
   165  					quiter()
   166  				}
   167  			}
   168  			logger.Debug("quitC done for serviceRMQ", stack.Trace().TrimRuntime())
   169  			return
   170  		case state = <-lifecycleC:
   171  		case <-qTicker.C:
   172  
   173  			qCheck = checkInterval
   174  
   175  			// If the pulling of work is currently suspending bail out of checking the queues
   176  			if state.State != types.K8sRunning && state.State != types.K8sUnknown {
   177  				queueIgnored.With(prometheus.Labels{"host": host, "queue_type": live.queueType, "queue_name": "*"}).Inc()
   178  				logger.Trace("k8s has RMQ disabled", "stack", stack.Trace().TrimRuntime())
   179  				continue
   180  			}
   181  
   182  			connCtx, cancel := context.WithTimeout(ctx, connTimeout)
   183  
   184  			// Found returns a map that contains the queues that were found
   185  			// on the rabbitMQ server specified by the rmq data structure
   186  			found, err := rmq.GetKnown(connCtx, matcher, mismatcher)
   187  			cancel()
   188  
   189  			if err != nil {
   190  				qCheck = qCheck * 2
   191  				err = err.With("backoff", qCheck.String())
   192  				logger.Warn("unable to refresh RMQ manifest", err.Error())
   193  				continue
   194  			}
   195  			if len(found) == 0 {
   196  				items := []string{"no queues found", "identity", rmq.Identity, "matcher", matcher.String()}
   197  
   198  				if mismatcher != nil {
   199  					items = append(items, "mismatcher", mismatcher.String())
   200  				}
   201  				items = append(items, "stack", stack.Trace().TrimRuntime().String())
   202  				logger.Warn(items[0], items[1:])
   203  
   204  				qCheck = qCheck * 2
   205  				continue
   206  			}
   207  
   208  			// Found rneeds to just have the main queue servers as their keys, individual queues will be treated as subscriptions
   209  
   210  			filtered := make(map[string]string, len(found))
   211  			for k, v := range found {
   212  				qItems := strings.Split(k, "?")
   213  				filtered[qItems[0]] = v
   214  			}
   215  
   216  			// found contains a map of keys that have an uncredentialed URL, and the value which is the user name and password for the URL
   217  			//
   218  			// The URL path is going to be the vhost and the queue name
   219  			if err := live.Lifecycle(ctx, filtered); err != nil {
   220  				logger.Warn(err.Error())
   221  			}
   222  		}
   223  	}
   224  }