github.com/Axway/agent-sdk@v1.1.101/pkg/traceability/traceability_test.go (about) 1 package traceability 2 3 import ( 4 "compress/gzip" 5 "context" 6 "encoding/json" 7 "io" 8 "io/ioutil" 9 "net/http" 10 "net/http/httptest" 11 "net/url" 12 "os" 13 "testing" 14 "time" 15 16 "github.com/Axway/agent-sdk/pkg/agent" 17 "github.com/Axway/agent-sdk/pkg/config" 18 "github.com/Axway/agent-sdk/pkg/traceability/sampling" 19 "github.com/elastic/beats/v7/libbeat/beat" 20 "github.com/elastic/beats/v7/libbeat/common" 21 "github.com/elastic/beats/v7/libbeat/outputs" 22 "github.com/elastic/beats/v7/libbeat/publisher" 23 "github.com/stretchr/testify/assert" 24 ) 25 26 var logstashClientCreateCalled = false 27 28 func init() { 29 logstashFactory := func( 30 indexManager outputs.IndexManager, 31 beat beat.Info, 32 observer outputs.Observer, 33 cfg *common.Config, 34 ) (outputs.Group, error) { 35 logstashClientCreateCalled = true 36 return outputs.SuccessNet(false, 1, 1, nil) 37 } 38 outputs.RegisterType("logstash", logstashFactory) 39 } 40 41 func createCentralCfg(url, env string) *config.CentralConfiguration { 42 cfg := config.NewCentralConfig(config.DiscoveryAgent).(*config.CentralConfiguration) 43 cfg.URL = url 44 cfg.SingleURL = "" 45 cfg.TenantID = "123456" 46 cfg.Environment = env 47 authCfg := cfg.Auth.(*config.AuthConfiguration) 48 authCfg.URL = url + "/auth" 49 authCfg.Realm = "Broker" 50 authCfg.ClientID = "serviceaccount_1234" 51 authCfg.PrivateKey = "../transaction/testdata/private_key.pem" 52 authCfg.PublicKey = "../transaction/testdata/public_key" 53 cfg.GetMetricReportingConfig().(*config.MetricReportingConfiguration).Schedule = "* * * * *" // every minute 54 cfg.GetUsageReportingConfig().(*config.UsageReportingConfiguration).Offline = false 55 return cfg 56 } 57 58 func createTransport(config *Config) (outputs.Group, error) { 59 info := beat.Info{ 60 Beat: "test-beat", 61 IndexPrefix: "", 62 Version: "1.0", 63 } 64 // defcfg := DefaultConfig() 65 commonCfg, _ := common.NewConfigFrom(config) 66 return makeTraceabilityAgent(nil, info, nil, commonCfg) 67 } 68 69 func createBatch(msgValue string) *MockBatch { 70 return &MockBatch{ 71 acked: false, 72 retryCount: 0, 73 events: createEvent(msgValue), 74 } 75 } 76 77 func createEvent(msgValue string) []publisher.Event { 78 fieldsData := common.MapStr{ 79 "message": msgValue, 80 } 81 return []publisher.Event{ 82 { 83 Content: beat.Event{ 84 Timestamp: time.Now(), 85 Meta: common.MapStr{sampling.SampleKey: true}, 86 Private: nil, 87 Fields: fieldsData, 88 }, 89 }, 90 } 91 } 92 93 type mockHTTPServer struct { 94 serverMessages []map[string]interface{} 95 responseStatus int 96 requestUserAgent string 97 98 server *httptest.Server 99 } 100 101 func newMockHTTPServer() *mockHTTPServer { 102 mockServer := &mockHTTPServer{} 103 mockServer.server = httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { 104 switch req.RequestURI { 105 case "/auth/realms/Broker/protocol/openid-connect/token": 106 token := "{\"access_token\":\"somevalue\",\"expires_in\": 12235677}" 107 resp.Write([]byte(token)) 108 case "/": 109 if req.Method == "POST" { 110 if mockServer.responseStatus != 0 { 111 resp.WriteHeader(mockServer.responseStatus) 112 return 113 } 114 mockServer.requestUserAgent = req.Header.Get("User-Agent") 115 mockServer.ResetMessages() 116 var body []byte 117 contentEncoding := req.Header["Content-Encoding"] 118 if contentEncoding != nil && contentEncoding[0] == "gzip" { 119 body, _ = mockServer.decompressGzipContent(req.Body) 120 } else { 121 body, _ = ioutil.ReadAll(req.Body) 122 } 123 json.Unmarshal(body, &mockServer.serverMessages) 124 resp.Write([]byte("ok")) 125 } 126 resp.Write([]byte("ok")) 127 } 128 })) 129 return mockServer 130 } 131 132 func (s *mockHTTPServer) ResetStatus() { 133 s.responseStatus = 0 134 } 135 136 func (s *mockHTTPServer) ResetMessages() { 137 s.serverMessages = make([]map[string]interface{}, 0) 138 } 139 140 func (s *mockHTTPServer) GetMessages() []map[string]interface{} { 141 return s.serverMessages 142 } 143 144 func (s *mockHTTPServer) GetUserAgent() string { 145 return s.requestUserAgent 146 } 147 148 func (s *mockHTTPServer) Close() { 149 s.server.Close() 150 } 151 func (s *mockHTTPServer) decompressGzipContent(gzipBufferReader io.Reader) ([]byte, error) { 152 gzipReader, err := gzip.NewReader(gzipBufferReader) 153 if err != nil { 154 return nil, err 155 } 156 plainContent, err := ioutil.ReadAll(gzipReader) 157 if err != nil { 158 return nil, err 159 } 160 return plainContent, nil 161 } 162 163 type MockBatch struct { 164 acked bool 165 retryCount int 166 167 events []publisher.Event 168 } 169 170 func (b *MockBatch) Events() []publisher.Event { return b.events } 171 func (b *MockBatch) ACK() { b.acked = true } 172 func (b *MockBatch) Drop() {} 173 func (b *MockBatch) Retry() {} 174 func (b *MockBatch) Cancelled() {} 175 func (b *MockBatch) RetryEvents(events []publisher.Event) { b.retryCount++ } 176 func (b *MockBatch) CancelledEvents(events []publisher.Event) {} 177 178 type testEventProcessor struct { 179 msgValue string 180 } 181 182 func (t *testEventProcessor) Process(events []publisher.Event) []publisher.Event { 183 return createEvent(t.msgValue) 184 } 185 186 func TestCreateLogstashClient(t *testing.T) { 187 s := newMockHTTPServer() 188 defer s.Close() 189 190 cfg := createCentralCfg(s.server.URL, "v7") 191 agent.Initialize(cfg) 192 193 group, err := createTransport(nil) 194 195 assert.NotNil(t, err) 196 assert.Contains(t, err.Error(), "config is nil") 197 assert.NotNil(t, group) 198 assert.Nil(t, group.Clients) 199 assert.False(t, logstashClientCreateCalled) 200 testConfig := DefaultConfig() 201 202 group, err = createTransport(testConfig) 203 assert.NotNil(t, err) 204 assert.Contains(t, err.Error(), "empty array accessing 'hosts'") 205 assert.NotNil(t, group) 206 assert.Nil(t, group.Clients) 207 assert.False(t, logstashClientCreateCalled) 208 209 testConfig.Hosts = []string{ 210 "somehost", 211 "someotherhost", 212 } 213 group, err = createTransport(testConfig) 214 assert.Nil(t, err) 215 assert.NotNil(t, group) 216 assert.NotNil(t, group.Clients) 217 assert.True(t, logstashClientCreateCalled) 218 219 testConfig.Pipelining = 5 220 testConfig.Hosts = []string{ 221 "somehost2", 222 } 223 group, err = createTransport(testConfig) 224 assert.Nil(t, err) 225 assert.NotNil(t, group) 226 assert.True(t, logstashClientCreateCalled) 227 traceabilityClient := group.Clients[0].(*Client) 228 assert.NotNil(t, traceabilityClient) 229 assert.False(t, IsHTTPTransport()) 230 assert.Equal(t, 3, GetMaxRetries()) 231 } 232 233 func TestCreateLogstashClientWithSingleEntry(t *testing.T) { 234 cfg := createCentralCfg("http://localhost:8888", "v7") 235 cfg.SingleURL = "http://localhost:9999" 236 agent.Initialize(cfg) 237 logstashClientCreateCalled = false 238 239 testConfig := DefaultConfig() 240 testConfig.Protocol = "http" 241 testConfig.Hosts = []string{ 242 "somehost", 243 } 244 group, err := createTransport(testConfig) 245 assert.Nil(t, err) 246 assert.NotNil(t, group) 247 assert.NotNil(t, group.Clients) 248 assert.True(t, logstashClientCreateCalled) 249 assert.Equal(t, "tcp", traceCfg.Protocol) 250 transportProxy := os.Getenv("TRACEABILITY_PROXYURL") 251 assert.Equal(t, "sni://"+traceCfg.Hosts[0], transportProxy) 252 253 testConfig.Proxy = ProxyConfig{ 254 URL: "http://localhost:9999", 255 LocalResolve: false, 256 } 257 258 testConfig.Hosts = []string{ 259 "somehost", 260 } 261 group, err = createTransport(testConfig) 262 assert.Nil(t, err) 263 assert.NotNil(t, group) 264 assert.NotNil(t, group.Clients) 265 assert.True(t, logstashClientCreateCalled) 266 assert.Equal(t, "tcp", traceCfg.Protocol) 267 assert.Equal(t, "http://localhost:9999", traceCfg.Proxy.URL) 268 transportProxy = os.Getenv("TRACEABILITY_PROXYURL") 269 assert.Equal(t, "sni://"+traceCfg.Hosts[0], transportProxy) 270 } 271 272 func TestCreateHTTPClient(t *testing.T) { 273 logstashClientCreateCalled = false 274 cfg := createCentralCfg("http://localhost:8888", "v7") 275 agent.Initialize(cfg) 276 277 testConfig := DefaultConfig() 278 testConfig.Protocol = "http" 279 280 testConfig.Hosts = []string{ 281 "somehost:invalidport", 282 } 283 284 group, err := createTransport(testConfig) 285 assert.NotNil(t, err) 286 assert.Contains(t, err.Error(), "invalid port") 287 assert.NotNil(t, group) 288 assert.Nil(t, group.Clients) 289 assert.False(t, logstashClientCreateCalled) 290 291 testConfig.Hosts = []string{ 292 "somehost", 293 } 294 testConfig.Proxy = ProxyConfig{ 295 URL: "bogus\\:bogus", 296 } 297 298 group, err = createTransport(testConfig) 299 assert.NotNil(t, err) 300 assert.NotNil(t, group) 301 assert.Nil(t, group.Clients) 302 assert.False(t, logstashClientCreateCalled) 303 304 testConfig.Proxy = ProxyConfig{} 305 testConfig.CompressionLevel = 20 306 group, err = createTransport(testConfig) 307 assert.NotNil(t, err) 308 assert.Contains(t, err.Error(), "requires value <= 9 accessing 'compression_level'") 309 assert.NotNil(t, group) 310 assert.Nil(t, group.Clients) 311 assert.False(t, logstashClientCreateCalled) 312 313 testConfig.CompressionLevel = 0 314 group, err = createTransport(testConfig) 315 assert.Nil(t, err) 316 assert.NotNil(t, group) 317 assert.Equal(t, 1, len(group.Clients)) 318 traceabilityClient := group.Clients[0].(*Client) 319 assert.NotNil(t, traceabilityClient) 320 assert.False(t, logstashClientCreateCalled) 321 assert.True(t, IsHTTPTransport()) 322 assert.Equal(t, 3, GetMaxRetries()) 323 } 324 325 func TestHTTPTransportWithJSONEncoding(t *testing.T) { 326 s := newMockHTTPServer() 327 defer s.Close() 328 config.AgentTypeName = "TraceabilityAgent" 329 config.AgentVersion = "0.0.1-abc" 330 config.SDKVersion = "0.0.1" 331 332 cfg := createCentralCfg(s.server.URL, "v7") 333 agent.Initialize(cfg) 334 335 url, _ := url.Parse(s.server.URL) 336 testConfig := DefaultConfig() 337 testConfig.Protocol = "http" 338 testConfig.CompressionLevel = 0 339 testConfig.Hosts = []string{url.Hostname() + ":" + url.Port()} 340 341 group, err := createTransport(testConfig) 342 assert.Nil(t, err) 343 assert.NotNil(t, group) 344 traceabilityClient := group.Clients[0].(*Client) 345 batch := createBatch("{\"f1\":\"test\"}") 346 traceabilityClient.Connect() 347 agent.StartAgentStatusUpdate() 348 err = traceabilityClient.Publish(context.Background(), batch) 349 traceabilityClient.Close() 350 351 assert.Nil(t, err) 352 publishedMessages := s.GetMessages() 353 reqUA := s.GetUserAgent() 354 assert.NotEmpty(t, reqUA) 355 assert.NotNil(t, publishedMessages) 356 assert.Equal(t, 1, len(publishedMessages)) 357 event := publishedMessages[0] 358 assert.Nil(t, err) 359 assert.Equal(t, "test", event["f1"]) 360 assert.True(t, batch.acked) 361 } 362 363 func TestHTTPTransportWithOutputProcessor(t *testing.T) { 364 s := newMockHTTPServer() 365 defer s.Close() 366 367 cfg := createCentralCfg(s.server.URL, "v7") 368 agent.Initialize(cfg) 369 370 url, _ := url.Parse(s.server.URL) 371 testConfig := DefaultConfig() 372 testConfig.Protocol = "http" 373 testConfig.CompressionLevel = 0 374 testConfig.Hosts = []string{ 375 url.Hostname() + ":" + url.Port(), 376 } 377 378 eventProcessor := &testEventProcessor{msgValue: "{\"f1\":\"test\"}"} 379 SetOutputEventProcessor(eventProcessor) 380 group, err := createTransport(testConfig) 381 assert.Nil(t, err) 382 traceabilityClient := group.Clients[0].(*Client) 383 batch := createBatch("{\"f0\":\"dummy\"}") 384 385 traceabilityClient.Connect() 386 agent.StartAgentStatusUpdate() 387 err = traceabilityClient.Publish(context.Background(), batch) 388 traceabilityClient.Close() 389 assert.Nil(t, err) 390 391 publishedMessages := s.GetMessages() 392 assert.NotNil(t, publishedMessages) 393 assert.Equal(t, 1, len(publishedMessages)) 394 event := publishedMessages[0] 395 assert.Equal(t, "test", event["f1"]) 396 assert.Nil(t, event["f0"]) 397 assert.True(t, batch.acked) 398 399 SetOutputEventProcessor(nil) 400 } 401 402 func TestHTTPTransportWithGzipEncoding(t *testing.T) { 403 s := newMockHTTPServer() 404 defer s.Close() 405 406 cfg := createCentralCfg(s.server.URL, "v7") 407 agent.Initialize(cfg) 408 409 url, _ := url.Parse(s.server.URL) 410 testConfig := DefaultConfig() 411 testConfig.Protocol = "http" 412 testConfig.CompressionLevel = 3 413 testConfig.Hosts = []string{ 414 url.Hostname() + ":" + url.Port(), 415 } 416 417 group, err := createTransport(testConfig) 418 assert.Nil(t, err) 419 assert.NotNil(t, group) 420 traceabilityClient := group.Clients[0].(*Client) 421 batch := createBatch("{\"f1\":\"test\"}") 422 423 traceabilityClient.Connect() 424 err = traceabilityClient.Publish(context.Background(), batch) 425 assert.Nil(t, err) 426 traceabilityClient.Close() 427 428 publishedMessages := s.GetMessages() 429 assert.NotNil(t, publishedMessages) 430 assert.Equal(t, 1, len(publishedMessages)) 431 432 event := publishedMessages[0] 433 434 assert.Nil(t, err) 435 assert.Equal(t, "test", event["f1"]) 436 assert.True(t, batch.acked) 437 } 438 439 func TestHTTPTransportRetries(t *testing.T) { 440 s := newMockHTTPServer() 441 defer s.Close() 442 443 cfg := createCentralCfg(s.server.URL, "v7") 444 agent.Initialize(cfg) 445 446 url, _ := url.Parse(s.server.URL) 447 testConfig := DefaultConfig() 448 testConfig.Protocol = "http" 449 testConfig.CompressionLevel = 0 450 testConfig.Hosts = []string{ 451 url.Hostname() + ":" + url.Port(), 452 } 453 454 group, err := createTransport(testConfig) 455 assert.Nil(t, err) 456 traceabilityClient := group.Clients[0].(*Client) 457 batch := createBatch("somemessage") 458 459 s.responseStatus = 404 460 traceabilityClient.Connect() 461 err = traceabilityClient.Publish(context.Background(), batch) 462 traceabilityClient.Close() 463 assert.NotNil(t, err) 464 assert.False(t, batch.acked) 465 assert.Equal(t, 1, batch.retryCount) 466 467 s.responseStatus = 500 468 batch = createBatch("somemessage") 469 group, err = createTransport(testConfig) 470 assert.Nil(t, err) 471 472 traceabilityClient = group.Clients[0].(*Client) 473 traceabilityClient.Connect() 474 err = traceabilityClient.Publish(context.Background(), batch) 475 traceabilityClient.Close() 476 assert.NotNil(t, err) 477 assert.False(t, batch.acked) 478 assert.Equal(t, 1, batch.retryCount) 479 publishedMessages := s.GetMessages() 480 assert.Nil(t, publishedMessages) 481 482 SetOutputEventProcessor(nil) 483 }