github.com/Axway/agent-sdk@v1.1.101/pkg/traceability/traceability_test.go (about)

     1  package traceability
     2  
     3  import (
     4  	"compress/gzip"
     5  	"context"
     6  	"encoding/json"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"net/url"
    12  	"os"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/Axway/agent-sdk/pkg/agent"
    17  	"github.com/Axway/agent-sdk/pkg/config"
    18  	"github.com/Axway/agent-sdk/pkg/traceability/sampling"
    19  	"github.com/elastic/beats/v7/libbeat/beat"
    20  	"github.com/elastic/beats/v7/libbeat/common"
    21  	"github.com/elastic/beats/v7/libbeat/outputs"
    22  	"github.com/elastic/beats/v7/libbeat/publisher"
    23  	"github.com/stretchr/testify/assert"
    24  )
    25  
    26  var logstashClientCreateCalled = false
    27  
    28  func init() {
    29  	logstashFactory := func(
    30  		indexManager outputs.IndexManager,
    31  		beat beat.Info,
    32  		observer outputs.Observer,
    33  		cfg *common.Config,
    34  	) (outputs.Group, error) {
    35  		logstashClientCreateCalled = true
    36  		return outputs.SuccessNet(false, 1, 1, nil)
    37  	}
    38  	outputs.RegisterType("logstash", logstashFactory)
    39  }
    40  
    41  func createCentralCfg(url, env string) *config.CentralConfiguration {
    42  	cfg := config.NewCentralConfig(config.DiscoveryAgent).(*config.CentralConfiguration)
    43  	cfg.URL = url
    44  	cfg.SingleURL = ""
    45  	cfg.TenantID = "123456"
    46  	cfg.Environment = env
    47  	authCfg := cfg.Auth.(*config.AuthConfiguration)
    48  	authCfg.URL = url + "/auth"
    49  	authCfg.Realm = "Broker"
    50  	authCfg.ClientID = "serviceaccount_1234"
    51  	authCfg.PrivateKey = "../transaction/testdata/private_key.pem"
    52  	authCfg.PublicKey = "../transaction/testdata/public_key"
    53  	cfg.GetMetricReportingConfig().(*config.MetricReportingConfiguration).Schedule = "* * * * *" // every minute
    54  	cfg.GetUsageReportingConfig().(*config.UsageReportingConfiguration).Offline = false
    55  	return cfg
    56  }
    57  
    58  func createTransport(config *Config) (outputs.Group, error) {
    59  	info := beat.Info{
    60  		Beat:        "test-beat",
    61  		IndexPrefix: "",
    62  		Version:     "1.0",
    63  	}
    64  	// defcfg := DefaultConfig()
    65  	commonCfg, _ := common.NewConfigFrom(config)
    66  	return makeTraceabilityAgent(nil, info, nil, commonCfg)
    67  }
    68  
    69  func createBatch(msgValue string) *MockBatch {
    70  	return &MockBatch{
    71  		acked:      false,
    72  		retryCount: 0,
    73  		events:     createEvent(msgValue),
    74  	}
    75  }
    76  
    77  func createEvent(msgValue string) []publisher.Event {
    78  	fieldsData := common.MapStr{
    79  		"message": msgValue,
    80  	}
    81  	return []publisher.Event{
    82  		{
    83  			Content: beat.Event{
    84  				Timestamp: time.Now(),
    85  				Meta:      common.MapStr{sampling.SampleKey: true},
    86  				Private:   nil,
    87  				Fields:    fieldsData,
    88  			},
    89  		},
    90  	}
    91  }
    92  
    93  type mockHTTPServer struct {
    94  	serverMessages   []map[string]interface{}
    95  	responseStatus   int
    96  	requestUserAgent string
    97  
    98  	server *httptest.Server
    99  }
   100  
   101  func newMockHTTPServer() *mockHTTPServer {
   102  	mockServer := &mockHTTPServer{}
   103  	mockServer.server = httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
   104  		switch req.RequestURI {
   105  		case "/auth/realms/Broker/protocol/openid-connect/token":
   106  			token := "{\"access_token\":\"somevalue\",\"expires_in\": 12235677}"
   107  			resp.Write([]byte(token))
   108  		case "/":
   109  			if req.Method == "POST" {
   110  				if mockServer.responseStatus != 0 {
   111  					resp.WriteHeader(mockServer.responseStatus)
   112  					return
   113  				}
   114  				mockServer.requestUserAgent = req.Header.Get("User-Agent")
   115  				mockServer.ResetMessages()
   116  				var body []byte
   117  				contentEncoding := req.Header["Content-Encoding"]
   118  				if contentEncoding != nil && contentEncoding[0] == "gzip" {
   119  					body, _ = mockServer.decompressGzipContent(req.Body)
   120  				} else {
   121  					body, _ = ioutil.ReadAll(req.Body)
   122  				}
   123  				json.Unmarshal(body, &mockServer.serverMessages)
   124  				resp.Write([]byte("ok"))
   125  			}
   126  			resp.Write([]byte("ok"))
   127  		}
   128  	}))
   129  	return mockServer
   130  }
   131  
   132  func (s *mockHTTPServer) ResetStatus() {
   133  	s.responseStatus = 0
   134  }
   135  
   136  func (s *mockHTTPServer) ResetMessages() {
   137  	s.serverMessages = make([]map[string]interface{}, 0)
   138  }
   139  
   140  func (s *mockHTTPServer) GetMessages() []map[string]interface{} {
   141  	return s.serverMessages
   142  }
   143  
   144  func (s *mockHTTPServer) GetUserAgent() string {
   145  	return s.requestUserAgent
   146  }
   147  
   148  func (s *mockHTTPServer) Close() {
   149  	s.server.Close()
   150  }
   151  func (s *mockHTTPServer) decompressGzipContent(gzipBufferReader io.Reader) ([]byte, error) {
   152  	gzipReader, err := gzip.NewReader(gzipBufferReader)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	plainContent, err := ioutil.ReadAll(gzipReader)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	return plainContent, nil
   161  }
   162  
   163  type MockBatch struct {
   164  	acked      bool
   165  	retryCount int
   166  
   167  	events []publisher.Event
   168  }
   169  
   170  func (b *MockBatch) Events() []publisher.Event                { return b.events }
   171  func (b *MockBatch) ACK()                                     { b.acked = true }
   172  func (b *MockBatch) Drop()                                    {}
   173  func (b *MockBatch) Retry()                                   {}
   174  func (b *MockBatch) Cancelled()                               {}
   175  func (b *MockBatch) RetryEvents(events []publisher.Event)     { b.retryCount++ }
   176  func (b *MockBatch) CancelledEvents(events []publisher.Event) {}
   177  
   178  type testEventProcessor struct {
   179  	msgValue string
   180  }
   181  
   182  func (t *testEventProcessor) Process(events []publisher.Event) []publisher.Event {
   183  	return createEvent(t.msgValue)
   184  }
   185  
   186  func TestCreateLogstashClient(t *testing.T) {
   187  	s := newMockHTTPServer()
   188  	defer s.Close()
   189  
   190  	cfg := createCentralCfg(s.server.URL, "v7")
   191  	agent.Initialize(cfg)
   192  
   193  	group, err := createTransport(nil)
   194  
   195  	assert.NotNil(t, err)
   196  	assert.Contains(t, err.Error(), "config is nil")
   197  	assert.NotNil(t, group)
   198  	assert.Nil(t, group.Clients)
   199  	assert.False(t, logstashClientCreateCalled)
   200  	testConfig := DefaultConfig()
   201  
   202  	group, err = createTransport(testConfig)
   203  	assert.NotNil(t, err)
   204  	assert.Contains(t, err.Error(), "empty array accessing 'hosts'")
   205  	assert.NotNil(t, group)
   206  	assert.Nil(t, group.Clients)
   207  	assert.False(t, logstashClientCreateCalled)
   208  
   209  	testConfig.Hosts = []string{
   210  		"somehost",
   211  		"someotherhost",
   212  	}
   213  	group, err = createTransport(testConfig)
   214  	assert.Nil(t, err)
   215  	assert.NotNil(t, group)
   216  	assert.NotNil(t, group.Clients)
   217  	assert.True(t, logstashClientCreateCalled)
   218  
   219  	testConfig.Pipelining = 5
   220  	testConfig.Hosts = []string{
   221  		"somehost2",
   222  	}
   223  	group, err = createTransport(testConfig)
   224  	assert.Nil(t, err)
   225  	assert.NotNil(t, group)
   226  	assert.True(t, logstashClientCreateCalled)
   227  	traceabilityClient := group.Clients[0].(*Client)
   228  	assert.NotNil(t, traceabilityClient)
   229  	assert.False(t, IsHTTPTransport())
   230  	assert.Equal(t, 3, GetMaxRetries())
   231  }
   232  
   233  func TestCreateLogstashClientWithSingleEntry(t *testing.T) {
   234  	cfg := createCentralCfg("http://localhost:8888", "v7")
   235  	cfg.SingleURL = "http://localhost:9999"
   236  	agent.Initialize(cfg)
   237  	logstashClientCreateCalled = false
   238  
   239  	testConfig := DefaultConfig()
   240  	testConfig.Protocol = "http"
   241  	testConfig.Hosts = []string{
   242  		"somehost",
   243  	}
   244  	group, err := createTransport(testConfig)
   245  	assert.Nil(t, err)
   246  	assert.NotNil(t, group)
   247  	assert.NotNil(t, group.Clients)
   248  	assert.True(t, logstashClientCreateCalled)
   249  	assert.Equal(t, "tcp", traceCfg.Protocol)
   250  	transportProxy := os.Getenv("TRACEABILITY_PROXYURL")
   251  	assert.Equal(t, "sni://"+traceCfg.Hosts[0], transportProxy)
   252  
   253  	testConfig.Proxy = ProxyConfig{
   254  		URL:          "http://localhost:9999",
   255  		LocalResolve: false,
   256  	}
   257  
   258  	testConfig.Hosts = []string{
   259  		"somehost",
   260  	}
   261  	group, err = createTransport(testConfig)
   262  	assert.Nil(t, err)
   263  	assert.NotNil(t, group)
   264  	assert.NotNil(t, group.Clients)
   265  	assert.True(t, logstashClientCreateCalled)
   266  	assert.Equal(t, "tcp", traceCfg.Protocol)
   267  	assert.Equal(t, "http://localhost:9999", traceCfg.Proxy.URL)
   268  	transportProxy = os.Getenv("TRACEABILITY_PROXYURL")
   269  	assert.Equal(t, "sni://"+traceCfg.Hosts[0], transportProxy)
   270  }
   271  
   272  func TestCreateHTTPClient(t *testing.T) {
   273  	logstashClientCreateCalled = false
   274  	cfg := createCentralCfg("http://localhost:8888", "v7")
   275  	agent.Initialize(cfg)
   276  
   277  	testConfig := DefaultConfig()
   278  	testConfig.Protocol = "http"
   279  
   280  	testConfig.Hosts = []string{
   281  		"somehost:invalidport",
   282  	}
   283  
   284  	group, err := createTransport(testConfig)
   285  	assert.NotNil(t, err)
   286  	assert.Contains(t, err.Error(), "invalid port")
   287  	assert.NotNil(t, group)
   288  	assert.Nil(t, group.Clients)
   289  	assert.False(t, logstashClientCreateCalled)
   290  
   291  	testConfig.Hosts = []string{
   292  		"somehost",
   293  	}
   294  	testConfig.Proxy = ProxyConfig{
   295  		URL: "bogus\\:bogus",
   296  	}
   297  
   298  	group, err = createTransport(testConfig)
   299  	assert.NotNil(t, err)
   300  	assert.NotNil(t, group)
   301  	assert.Nil(t, group.Clients)
   302  	assert.False(t, logstashClientCreateCalled)
   303  
   304  	testConfig.Proxy = ProxyConfig{}
   305  	testConfig.CompressionLevel = 20
   306  	group, err = createTransport(testConfig)
   307  	assert.NotNil(t, err)
   308  	assert.Contains(t, err.Error(), "requires value <= 9 accessing 'compression_level'")
   309  	assert.NotNil(t, group)
   310  	assert.Nil(t, group.Clients)
   311  	assert.False(t, logstashClientCreateCalled)
   312  
   313  	testConfig.CompressionLevel = 0
   314  	group, err = createTransport(testConfig)
   315  	assert.Nil(t, err)
   316  	assert.NotNil(t, group)
   317  	assert.Equal(t, 1, len(group.Clients))
   318  	traceabilityClient := group.Clients[0].(*Client)
   319  	assert.NotNil(t, traceabilityClient)
   320  	assert.False(t, logstashClientCreateCalled)
   321  	assert.True(t, IsHTTPTransport())
   322  	assert.Equal(t, 3, GetMaxRetries())
   323  }
   324  
   325  func TestHTTPTransportWithJSONEncoding(t *testing.T) {
   326  	s := newMockHTTPServer()
   327  	defer s.Close()
   328  	config.AgentTypeName = "TraceabilityAgent"
   329  	config.AgentVersion = "0.0.1-abc"
   330  	config.SDKVersion = "0.0.1"
   331  
   332  	cfg := createCentralCfg(s.server.URL, "v7")
   333  	agent.Initialize(cfg)
   334  
   335  	url, _ := url.Parse(s.server.URL)
   336  	testConfig := DefaultConfig()
   337  	testConfig.Protocol = "http"
   338  	testConfig.CompressionLevel = 0
   339  	testConfig.Hosts = []string{url.Hostname() + ":" + url.Port()}
   340  
   341  	group, err := createTransport(testConfig)
   342  	assert.Nil(t, err)
   343  	assert.NotNil(t, group)
   344  	traceabilityClient := group.Clients[0].(*Client)
   345  	batch := createBatch("{\"f1\":\"test\"}")
   346  	traceabilityClient.Connect()
   347  	agent.StartAgentStatusUpdate()
   348  	err = traceabilityClient.Publish(context.Background(), batch)
   349  	traceabilityClient.Close()
   350  
   351  	assert.Nil(t, err)
   352  	publishedMessages := s.GetMessages()
   353  	reqUA := s.GetUserAgent()
   354  	assert.NotEmpty(t, reqUA)
   355  	assert.NotNil(t, publishedMessages)
   356  	assert.Equal(t, 1, len(publishedMessages))
   357  	event := publishedMessages[0]
   358  	assert.Nil(t, err)
   359  	assert.Equal(t, "test", event["f1"])
   360  	assert.True(t, batch.acked)
   361  }
   362  
   363  func TestHTTPTransportWithOutputProcessor(t *testing.T) {
   364  	s := newMockHTTPServer()
   365  	defer s.Close()
   366  
   367  	cfg := createCentralCfg(s.server.URL, "v7")
   368  	agent.Initialize(cfg)
   369  
   370  	url, _ := url.Parse(s.server.URL)
   371  	testConfig := DefaultConfig()
   372  	testConfig.Protocol = "http"
   373  	testConfig.CompressionLevel = 0
   374  	testConfig.Hosts = []string{
   375  		url.Hostname() + ":" + url.Port(),
   376  	}
   377  
   378  	eventProcessor := &testEventProcessor{msgValue: "{\"f1\":\"test\"}"}
   379  	SetOutputEventProcessor(eventProcessor)
   380  	group, err := createTransport(testConfig)
   381  	assert.Nil(t, err)
   382  	traceabilityClient := group.Clients[0].(*Client)
   383  	batch := createBatch("{\"f0\":\"dummy\"}")
   384  
   385  	traceabilityClient.Connect()
   386  	agent.StartAgentStatusUpdate()
   387  	err = traceabilityClient.Publish(context.Background(), batch)
   388  	traceabilityClient.Close()
   389  	assert.Nil(t, err)
   390  
   391  	publishedMessages := s.GetMessages()
   392  	assert.NotNil(t, publishedMessages)
   393  	assert.Equal(t, 1, len(publishedMessages))
   394  	event := publishedMessages[0]
   395  	assert.Equal(t, "test", event["f1"])
   396  	assert.Nil(t, event["f0"])
   397  	assert.True(t, batch.acked)
   398  
   399  	SetOutputEventProcessor(nil)
   400  }
   401  
   402  func TestHTTPTransportWithGzipEncoding(t *testing.T) {
   403  	s := newMockHTTPServer()
   404  	defer s.Close()
   405  
   406  	cfg := createCentralCfg(s.server.URL, "v7")
   407  	agent.Initialize(cfg)
   408  
   409  	url, _ := url.Parse(s.server.URL)
   410  	testConfig := DefaultConfig()
   411  	testConfig.Protocol = "http"
   412  	testConfig.CompressionLevel = 3
   413  	testConfig.Hosts = []string{
   414  		url.Hostname() + ":" + url.Port(),
   415  	}
   416  
   417  	group, err := createTransport(testConfig)
   418  	assert.Nil(t, err)
   419  	assert.NotNil(t, group)
   420  	traceabilityClient := group.Clients[0].(*Client)
   421  	batch := createBatch("{\"f1\":\"test\"}")
   422  
   423  	traceabilityClient.Connect()
   424  	err = traceabilityClient.Publish(context.Background(), batch)
   425  	assert.Nil(t, err)
   426  	traceabilityClient.Close()
   427  
   428  	publishedMessages := s.GetMessages()
   429  	assert.NotNil(t, publishedMessages)
   430  	assert.Equal(t, 1, len(publishedMessages))
   431  
   432  	event := publishedMessages[0]
   433  
   434  	assert.Nil(t, err)
   435  	assert.Equal(t, "test", event["f1"])
   436  	assert.True(t, batch.acked)
   437  }
   438  
   439  func TestHTTPTransportRetries(t *testing.T) {
   440  	s := newMockHTTPServer()
   441  	defer s.Close()
   442  
   443  	cfg := createCentralCfg(s.server.URL, "v7")
   444  	agent.Initialize(cfg)
   445  
   446  	url, _ := url.Parse(s.server.URL)
   447  	testConfig := DefaultConfig()
   448  	testConfig.Protocol = "http"
   449  	testConfig.CompressionLevel = 0
   450  	testConfig.Hosts = []string{
   451  		url.Hostname() + ":" + url.Port(),
   452  	}
   453  
   454  	group, err := createTransport(testConfig)
   455  	assert.Nil(t, err)
   456  	traceabilityClient := group.Clients[0].(*Client)
   457  	batch := createBatch("somemessage")
   458  
   459  	s.responseStatus = 404
   460  	traceabilityClient.Connect()
   461  	err = traceabilityClient.Publish(context.Background(), batch)
   462  	traceabilityClient.Close()
   463  	assert.NotNil(t, err)
   464  	assert.False(t, batch.acked)
   465  	assert.Equal(t, 1, batch.retryCount)
   466  
   467  	s.responseStatus = 500
   468  	batch = createBatch("somemessage")
   469  	group, err = createTransport(testConfig)
   470  	assert.Nil(t, err)
   471  
   472  	traceabilityClient = group.Clients[0].(*Client)
   473  	traceabilityClient.Connect()
   474  	err = traceabilityClient.Publish(context.Background(), batch)
   475  	traceabilityClient.Close()
   476  	assert.NotNil(t, err)
   477  	assert.False(t, batch.acked)
   478  	assert.Equal(t, 1, batch.retryCount)
   479  	publishedMessages := s.GetMessages()
   480  	assert.Nil(t, publishedMessages)
   481  
   482  	SetOutputEventProcessor(nil)
   483  }