github.com/Axway/agent-sdk@v1.1.101/pkg/traceability/traceability.go (about)

     1  package traceability
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"net/url"
     8  	"os"
     9  	"path"
    10  	"reflect"
    11  	"sync"
    12  	"unsafe"
    13  
    14  	"github.com/Axway/agent-sdk/pkg/agent"
    15  	"github.com/Axway/agent-sdk/pkg/jobs"
    16  	"github.com/Axway/agent-sdk/pkg/traceability/sampling"
    17  	"github.com/Axway/agent-sdk/pkg/util"
    18  	"github.com/Axway/agent-sdk/pkg/util/log"
    19  	"github.com/elastic/beats/v7/libbeat/beat"
    20  	"github.com/elastic/beats/v7/libbeat/common"
    21  	"github.com/elastic/beats/v7/libbeat/common/transport"
    22  	"github.com/elastic/beats/v7/libbeat/common/transport/tlscommon"
    23  	"github.com/elastic/beats/v7/libbeat/outputs"
    24  	"github.com/elastic/beats/v7/libbeat/paths"
    25  	"github.com/elastic/beats/v7/libbeat/publisher"
    26  	"golang.org/x/net/proxy"
    27  
    28  	hc "github.com/Axway/agent-sdk/pkg/util/healthcheck"
    29  )
    30  
    31  const (
    32  	countStr     = "count"
    33  	eventTypeStr = "event-type"
    34  )
    35  
    36  // OutputEventProcessor - P
    37  type OutputEventProcessor interface {
    38  	Process(events []publisher.Event) []publisher.Event
    39  }
    40  
    41  var outputEventProcessor OutputEventProcessor
    42  var pathDataMutex sync.Mutex = sync.Mutex{}
    43  
    44  const (
    45  	minWindowSize             int = 1
    46  	defaultStartMaxWindowSize int = 10
    47  	defaultPort                   = 5044
    48  	traceabilityStr               = "traceability"
    49  	HealthCheckEndpoint           = traceabilityStr
    50  )
    51  
    52  var traceabilityClients []*Client
    53  var clientMutex *sync.Mutex
    54  var traceCfg *Config
    55  
    56  // GetClient - returns a random client from the clients array
    57  var GetClient = getClient
    58  
    59  func addClient(c *Client) {
    60  	clientMutex.Lock()
    61  	defer clientMutex.Unlock()
    62  	traceabilityClients = append(traceabilityClients, c)
    63  }
    64  
    65  func getClient() (*Client, error) {
    66  	clientMutex.Lock()
    67  	defer clientMutex.Unlock()
    68  	switch clients := len(traceabilityClients); clients {
    69  	case 0:
    70  		return nil, fmt.Errorf("no traceability clients, can't publish metrics")
    71  	case 1:
    72  		return traceabilityClients[0], nil
    73  	default:
    74  		randomIndex := rand.Intn(len(traceabilityClients))
    75  		return traceabilityClients[randomIndex], nil
    76  	}
    77  }
    78  
    79  // Client - struct
    80  type Client struct {
    81  	sync.Mutex
    82  	transportClient outputs.Client
    83  	logger          log.FieldLogger
    84  }
    85  
    86  func init() {
    87  	clientMutex = &sync.Mutex{}
    88  	outputs.RegisterType(traceabilityStr, makeTraceabilityAgent)
    89  }
    90  
    91  // SetOutputEventProcessor -
    92  func SetOutputEventProcessor(eventProcessor OutputEventProcessor) {
    93  	outputEventProcessor = eventProcessor
    94  }
    95  
    96  // GetDataDirPath - Returns the path of the data directory
    97  func GetDataDirPath() string {
    98  	pathDataMutex.Lock()
    99  	defer pathDataMutex.Unlock()
   100  	return paths.Paths.Data
   101  }
   102  
   103  // SetDataDirPath - Sets the path of the data directory
   104  func SetDataDirPath(path string) {
   105  	pathDataMutex.Lock()
   106  	defer pathDataMutex.Unlock()
   107  	paths.Paths.Data = path
   108  }
   109  
   110  // checkCreateDir
   111  func createDirIfNotExist(dirPath string) {
   112  	_, err := os.Stat(dirPath)
   113  	if os.IsNotExist(err) {
   114  		// Create the directory with the same permissions as the data dir
   115  		dataInfo, _ := os.Stat(GetDataDirPath())
   116  		os.MkdirAll(dirPath, dataInfo.Mode().Perm())
   117  	}
   118  }
   119  
   120  // GetCacheDirPath - Returns the path of the cache directory
   121  func GetCacheDirPath() string {
   122  	cacheDir := path.Join(GetDataDirPath(), "cache")
   123  	createDirIfNotExist(cacheDir)
   124  	return cacheDir
   125  }
   126  
   127  // GetReportsDirPath - Returns the path of the reports directory
   128  func GetReportsDirPath() string {
   129  	reportDir := path.Join(GetDataDirPath(), "reports")
   130  	createDirIfNotExist(reportDir)
   131  	return reportDir
   132  }
   133  
   134  func makeTraceabilityAgent(
   135  	indexManager outputs.IndexManager,
   136  	beat beat.Info,
   137  	observer outputs.Observer,
   138  	libbeatCfg *common.Config,
   139  ) (outputs.Group, error) {
   140  	logger := log.NewFieldLogger().
   141  		WithPackage("sdk.traceability").
   142  		WithComponent("makeTraceabilityAgent")
   143  
   144  	var err error
   145  
   146  	logger.Trace("reading config")
   147  	traceCfg, err = readConfig(libbeatCfg, beat)
   148  
   149  	defer func() {
   150  		if err != nil {
   151  			// skip hc register if err hit making agent
   152  			return
   153  		}
   154  
   155  		if !agent.GetCentralConfig().GetUsageReportingConfig().IsOfflineMode() && util.IsNotTest() {
   156  			err := registerHealthCheckers(traceCfg)
   157  			if err != nil {
   158  				logger.WithError(err).Error("could not register healthcheck")
   159  			}
   160  		}
   161  	}()
   162  
   163  	if err != nil {
   164  		agent.UpdateStatusWithPrevious(agent.AgentFailed, agent.AgentRunning, err.Error())
   165  		logger.WithError(err).Error("reading config")
   166  		return outputs.Fail(err)
   167  	}
   168  	logger = logger.WithField("config", traceCfg)
   169  
   170  	if err := libbeatCfg.Merge(HostConfig{Hosts: traceCfg.Hosts, Protocol: traceCfg.Protocol}); err != nil {
   171  		agent.UpdateStatusWithPrevious(agent.AgentFailed, agent.AgentRunning, err.Error())
   172  		logger.WithError(err).Error("merging host config")
   173  		return outputs.Fail(err)
   174  	}
   175  
   176  	hosts, err := outputs.ReadHostList(libbeatCfg)
   177  
   178  	if err != nil {
   179  		agent.UpdateStatusWithPrevious(agent.AgentFailed, agent.AgentRunning, err.Error())
   180  		logger.WithError(err).Error("reading hosts")
   181  		return outputs.Fail(err)
   182  	}
   183  	logger = logger.WithField("hosts", hosts)
   184  
   185  	var transportGroup outputs.Group
   186  	logger.Tracef("initializing traceability client")
   187  	isSingleEntry := agent.GetCentralConfig().GetSingleURL() != ""
   188  	if !isSingleEntry && IsHTTPTransport() {
   189  		transportGroup, err = makeHTTPClient(beat, observer, traceCfg, hosts, agent.GetUserAgent())
   190  	} else {
   191  		// For Single entry point register dialer factory for sni scheme and set the
   192  		// proxy url with sni scheme. When libbeat will register its dialer and sees
   193  		// proxy url with sni scheme, it will invoke the factory to construct the dialer
   194  		// The dialer will be invoked as proxy dialer in the libbeat dialer chain
   195  		// (proxy dialer, stat dialer, tls dialer).
   196  		if isSingleEntry {
   197  			if IsHTTPTransport() {
   198  				traceCfg.Protocol = "tcp"
   199  				logger.Warn("switching to tcp protocol instead of http because agent will use single entry endpoint")
   200  			}
   201  			// Register dialer factory with sni scheme for single entry point
   202  			proxy.RegisterDialerType("sni", ingestionSingleEntryDialer)
   203  			// If real proxy configured(not the sni proxy set here), validate the scheme
   204  			// since libbeats proxy dialer will not be invoked.
   205  			if traceCfg.Proxy.URL != "" {
   206  				proxCfg := &transport.ProxyConfig{
   207  					URL:          traceCfg.Proxy.URL,
   208  					LocalResolve: traceCfg.Proxy.LocalResolve,
   209  				}
   210  				err := proxCfg.Validate()
   211  				if err != nil {
   212  					logger.WithError(err).Error("validating proxy config")
   213  					outputs.Fail(err)
   214  				}
   215  			}
   216  			// Replace the proxy URL to sni by setting the environment variable
   217  			// Libbeat parses the yaml file and replaces the value from yaml
   218  			// with overridden environment variable.
   219  			// Set the sni host to the ingestion service host to allow the
   220  			// single entry dialer to receive the target address
   221  			os.Setenv("TRACEABILITY_PROXYURL", "sni://"+traceCfg.Hosts[0])
   222  		}
   223  		transportGroup, err = makeLogstashClient(indexManager, beat, observer, libbeatCfg)
   224  	}
   225  
   226  	if err != nil {
   227  		logger.WithError(err).Error("creating traceability client")
   228  		return outputs.Fail(err)
   229  	}
   230  
   231  	traceabilityGroup := outputs.Group{
   232  		BatchSize: transportGroup.BatchSize,
   233  		Retry:     transportGroup.Retry,
   234  	}
   235  	clients := make([]outputs.Client, 0)
   236  
   237  	for _, client := range transportGroup.Clients {
   238  		outputClient := &Client{
   239  			transportClient: client,
   240  			logger:          logger.WithComponent("traceabilityClient").WithPackage("sdk.traceability"),
   241  		}
   242  		clients = append(clients, outputClient)
   243  		addClient(outputClient)
   244  	}
   245  	traceabilityGroup.Clients = clients
   246  	return traceabilityGroup, nil
   247  }
   248  
   249  func makeLogstashClient(indexManager outputs.IndexManager,
   250  	beat beat.Info,
   251  	observer outputs.Observer,
   252  	libbeatCfg *common.Config,
   253  ) (outputs.Group, error) {
   254  	factory := outputs.FindFactory("logstash")
   255  	if factory == nil {
   256  		return outputs.Group{}, nil
   257  	}
   258  	group, err := factory(indexManager, beat, observer, libbeatCfg)
   259  	return group, err
   260  }
   261  
   262  // Factory method for creating dialer for sni scheme
   263  // Setup the single entry point dialer with single entry host mapping based
   264  // on central config and traceability proxy url from original config that gets
   265  // read by traceability output factory(makeTraceabilityAgent)
   266  func ingestionSingleEntryDialer(proxyURL *url.URL, parentDialer proxy.Dialer) (proxy.Dialer, error) {
   267  	var traceProxyURL *url.URL
   268  	var err error
   269  	if traceCfg != nil && traceCfg.Proxy.URL != "" {
   270  		traceProxyURL, err = url.Parse(traceCfg.Proxy.URL)
   271  		if err != nil {
   272  			return nil, fmt.Errorf("proxy could not be parsed. %s", err.Error())
   273  		}
   274  	}
   275  	var singleEntryHostMap map[string]string
   276  	if agent.GetCentralConfig() != nil {
   277  		cfgSingleURL := agent.GetCentralConfig().GetSingleURL()
   278  		if cfgSingleURL != "" {
   279  			// cfgSingleURL should not be empty as the factory method is registered based on that check
   280  			singleEntryURL, err := url.Parse(cfgSingleURL)
   281  			if err == nil && traceCfg != nil {
   282  				singleEntryHostMap = map[string]string{
   283  					traceCfg.Hosts[0]: util.ParseAddr(singleEntryURL),
   284  				}
   285  			}
   286  		}
   287  	}
   288  
   289  	dialer := util.NewDialer(traceProxyURL, singleEntryHostMap)
   290  	return dialer, nil
   291  }
   292  
   293  func makeHTTPClient(beat beat.Info, observer outputs.Observer, traceCfg *Config, hosts []string, userAgent string) (outputs.Group, error) {
   294  	tls, err := tlscommon.LoadTLSConfig(traceCfg.TLS)
   295  	if err != nil {
   296  		agent.UpdateStatusWithPrevious(agent.AgentFailed, agent.AgentRunning, err.Error())
   297  		return outputs.Fail(err)
   298  	}
   299  
   300  	clients := make([]outputs.NetworkClient, len(hosts))
   301  	for i, host := range hosts {
   302  		hostURL, err := common.MakeURL(traceCfg.Protocol, "/", host, 443)
   303  		if err != nil {
   304  			return outputs.Fail(err)
   305  		}
   306  		proxyURL, err := url.Parse(traceCfg.Proxy.URL)
   307  		if err != nil {
   308  			return outputs.Fail(err)
   309  		}
   310  		var client outputs.NetworkClient
   311  		client, err = NewHTTPClient(HTTPClientSettings{
   312  			BeatInfo:         beat,
   313  			URL:              hostURL,
   314  			Proxy:            proxyURL,
   315  			TLS:              tls,
   316  			Timeout:          traceCfg.Timeout,
   317  			CompressionLevel: traceCfg.CompressionLevel,
   318  			Observer:         observer,
   319  			UserAgent:        userAgent,
   320  		})
   321  
   322  		if err != nil {
   323  			return outputs.Fail(err)
   324  		}
   325  		client = outputs.WithBackoff(client, traceCfg.Backoff.Init, traceCfg.Backoff.Max)
   326  		clients[i] = client
   327  	}
   328  
   329  	return outputs.SuccessNet(traceCfg.LoadBalance, traceCfg.BulkMaxSize, traceCfg.MaxRetries, clients)
   330  }
   331  
   332  // SetTransportClient - set the transport client
   333  func (client *Client) SetTransportClient(outputClient outputs.Client) {
   334  	client.Lock()
   335  	defer client.Unlock()
   336  	client.transportClient = outputClient
   337  }
   338  
   339  // SetTransportClient - set the transport client
   340  func (client *Client) getTransportClient() outputs.Client {
   341  	client.Lock()
   342  	defer client.Unlock()
   343  	return client.transportClient
   344  }
   345  
   346  // SetLogger - set the logger
   347  func (client *Client) SetLogger(logger log.FieldLogger) {
   348  	client.logger = logger
   349  }
   350  
   351  // Connect establishes a connection to the clients sink.
   352  func (client *Client) Connect() error {
   353  	// do not attempt to establish a connection in offline mode
   354  	if agent.GetCentralConfig().GetUsageReportingConfig().IsOfflineMode() {
   355  		return nil
   356  	}
   357  
   358  	networkClient := client.getTransportClient().(outputs.NetworkClient)
   359  	err := networkClient.Connect()
   360  	if err != nil {
   361  		return err
   362  	}
   363  	return nil
   364  }
   365  
   366  // Close publish a single event to output.
   367  func (client *Client) Close() error {
   368  	// do not attempt to close a connection in offline mode, it was never established
   369  	if agent.GetCentralConfig().GetUsageReportingConfig().IsOfflineMode() {
   370  		return nil
   371  	}
   372  
   373  	err := client.getTransportClient().Close()
   374  	if err != nil {
   375  		return err
   376  	}
   377  	return nil
   378  }
   379  
   380  // Publish sends events to the clients sink.
   381  func (client *Client) Publish(ctx context.Context, batch publisher.Batch) error {
   382  	events := batch.Events()
   383  	if len(events) == 0 {
   384  		batch.ACK()
   385  		return nil // nothing to do
   386  	}
   387  	_, isMetric := events[0].Content.Meta["metric"]
   388  
   389  	if agent.GetCentralConfig().GetUsageReportingConfig().IsOfflineMode() {
   390  		if outputEventProcessor != nil && !isMetric {
   391  			outputEventProcessor.Process(events)
   392  		}
   393  		batch.ACK()
   394  		return nil
   395  	}
   396  
   397  	logger := client.logger.WithField(eventTypeStr, "metric")
   398  
   399  	if !isMetric {
   400  		logger = logger.WithField(eventTypeStr, "transaction")
   401  		if outputEventProcessor != nil {
   402  			updatedEvents := outputEventProcessor.Process(events)
   403  			updateEvent(batch, updatedEvents)
   404  		}
   405  
   406  		sampledEvents, err := sampling.FilterEvents(batch.Events())
   407  		if err != nil {
   408  			logger.Error(err.Error())
   409  		}
   410  		updateEvent(batch, sampledEvents)
   411  	}
   412  
   413  	events = batch.Events()
   414  	if len(events) == 0 {
   415  		batch.ACK()
   416  		return nil // nothing to do
   417  	}
   418  
   419  	logger = logger.WithField(countStr, len(events))
   420  	logger.Info("publishing events")
   421  
   422  	err := client.getTransportClient().Publish(ctx, batch)
   423  	if err != nil {
   424  		logger.WithError(err).Error("failed to publish events")
   425  		return err
   426  	}
   427  
   428  	logger.Info("published events")
   429  
   430  	return nil
   431  }
   432  
   433  func (client *Client) String() string {
   434  	return traceabilityStr
   435  }
   436  
   437  // updateEvent - updates the private field events in publisher.Batch
   438  func updateEvent(batch publisher.Batch, events []publisher.Event) {
   439  	pointerVal := reflect.ValueOf(batch)
   440  	val := reflect.Indirect(pointerVal)
   441  
   442  	member := val.FieldByName("events")
   443  	ptrToEvents := unsafe.Pointer(member.UnsafeAddr())
   444  	realPtrToEvents := (*[]publisher.Event)(ptrToEvents)
   445  	*realPtrToEvents = events
   446  }
   447  
   448  func registerHealthCheckers(config *Config) error {
   449  	hcJob := newTraceabilityHealthCheckJob()
   450  
   451  	_, err := jobs.RegisterIntervalJobWithName(hcJob, config.Timeout, "Traceability Health Check")
   452  	if err != nil {
   453  		return err
   454  	}
   455  
   456  	_, err = hc.RegisterHealthcheck("Traceability Agent", HealthCheckEndpoint, hcJob.healthcheck)
   457  	if err != nil {
   458  		return err
   459  	}
   460  	return nil
   461  }