github.com/aavshr/aws-sdk-go@v1.41.3/private/model/api/codegentest/service/awsendpointdiscoverytest/endpoint_discovery_test.go (about)

     1  //go:build go1.7
     2  // +build go1.7
     3  
     4  package awsendpointdiscoverytest
     5  
     6  import (
     7  	"strconv"
     8  	"strings"
     9  	"sync"
    10  	"sync/atomic"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/aavshr/aws-sdk-go/aws"
    15  	"github.com/aavshr/aws-sdk-go/aws/endpoints"
    16  	"github.com/aavshr/aws-sdk-go/aws/request"
    17  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    18  )
    19  
    20  func TestEndpointDiscoveryWithCustomEndpoint(t *testing.T) {
    21  	mockEndpointResolver := endpoints.ResolverFunc(func(service string, region string, opts ...func(options *endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
    22  		return endpoints.ResolvedEndpoint{
    23  			URL: "https://mockEndpointForDiscovery",
    24  		}, nil
    25  	})
    26  
    27  	cases := map[string]struct {
    28  		hasDiscoveryEnabled bool
    29  		hasCustomEndpoint   bool
    30  		isOperationRequired bool
    31  		customEndpoint      string
    32  		expectedEndpoint    string
    33  	}{
    34  		"HasCustomEndpoint_RequiredOperation": {
    35  			hasDiscoveryEnabled: true,
    36  			hasCustomEndpoint:   true,
    37  			isOperationRequired: true,
    38  			customEndpoint:      "https://mockCustomEndpoint",
    39  			expectedEndpoint:    "https://mockCustomEndpoint/",
    40  		},
    41  		"HasCustomEndpoint_OptionalOperation": {
    42  			hasDiscoveryEnabled: true,
    43  			hasCustomEndpoint:   true,
    44  			customEndpoint:      "https://mockCustomEndpoint",
    45  			expectedEndpoint:    "https://mockCustomEndpoint/",
    46  		},
    47  		"NoCustomEndpoint_DiscoveryDisabled": {
    48  			expectedEndpoint: "https://mockEndpointForDiscovery/",
    49  		},
    50  	}
    51  
    52  	for name, c := range cases {
    53  		t.Run(name, func(t *testing.T) {
    54  			cfg := &aws.Config{
    55  				EnableEndpointDiscovery: aws.Bool(c.hasDiscoveryEnabled),
    56  				EndpointResolver:        mockEndpointResolver,
    57  			}
    58  			if c.hasCustomEndpoint {
    59  				cfg.Endpoint = aws.String(c.customEndpoint)
    60  			}
    61  
    62  			svc := New(unit.Session, cfg)
    63  			svc.Handlers.Clear()
    64  			// Add a handler to verify no call goes to DescribeEndpoints operation
    65  			svc.Handlers.Send.PushBack(func(r *request.Request) {
    66  				if ne, a := opDescribeEndpoints, r.Operation.Name; strings.EqualFold(ne, a) {
    67  					t.Errorf("expected no call to %q operation", a)
    68  				}
    69  			})
    70  
    71  			var req *request.Request
    72  			if c.isOperationRequired {
    73  				req, _ = svc.TestDiscoveryIdentifiersRequiredRequest(
    74  					&TestDiscoveryIdentifiersRequiredInput{
    75  						Sdk: aws.String("sdk"),
    76  					},
    77  				)
    78  			} else {
    79  				req, _ = svc.TestDiscoveryOptionalRequest(
    80  					&TestDiscoveryOptionalInput{
    81  						Sdk: aws.String("sdk"),
    82  					},
    83  				)
    84  			}
    85  
    86  			req.Handlers.Send.PushBack(func(r *request.Request) {
    87  				if e, a := c.expectedEndpoint, r.HTTPRequest.URL.String(); e != a {
    88  					t.Errorf("expected %q, but received %q", e, a)
    89  				}
    90  			})
    91  			if err := req.Send(); err != nil {
    92  				t.Fatal(err)
    93  			}
    94  		})
    95  	}
    96  }
    97  
    98  func TestEndpointDiscoveryWithAttemptedDiscovery(t *testing.T) {
    99  	mockEndpointResolver := endpoints.ResolverFunc(func(service string, region string, opts ...func(options *endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
   100  		return endpoints.ResolvedEndpoint{
   101  			URL: "https://mockEndpointForDiscovery",
   102  		}, nil
   103  	})
   104  
   105  	cases := map[string]struct {
   106  		hasDiscoveryEnabled bool
   107  		hasCustomEndpoint   bool
   108  		isOperationRequired bool
   109  		customEndpoint      string
   110  		expectedEndpoint    string
   111  	}{
   112  		"NoCustomEndpoint_RequiredOperation": {
   113  			hasDiscoveryEnabled: true,
   114  			isOperationRequired: true,
   115  			expectedEndpoint:    "https://mockEndpointForDiscovery/",
   116  		},
   117  		"NoCustomEndpoint_OptionalOperation": {
   118  			hasDiscoveryEnabled: true,
   119  			expectedEndpoint:    "https://mockEndpointForDiscovery/",
   120  		},
   121  	}
   122  
   123  	for name, c := range cases {
   124  		t.Run(name, func(t *testing.T) {
   125  			cfg := &aws.Config{
   126  				EnableEndpointDiscovery: aws.Bool(c.hasDiscoveryEnabled),
   127  				EndpointResolver:        mockEndpointResolver,
   128  			}
   129  			if c.hasCustomEndpoint {
   130  				cfg.Endpoint = aws.String(c.customEndpoint)
   131  			}
   132  
   133  			svc := New(unit.Session, cfg)
   134  			svc.Handlers.Clear()
   135  			req, _ := svc.TestDiscoveryIdentifiersRequiredRequest(
   136  				&TestDiscoveryIdentifiersRequiredInput{
   137  					Sdk: aws.String("sdk"),
   138  				},
   139  			)
   140  
   141  			svc.Handlers.Send.PushBack(func(r *request.Request) {
   142  				if e, a := opDescribeEndpoints, r.Operation.Name; e != a {
   143  					t.Fatalf("expected operaton to be %q, called %q instead", e, a)
   144  				}
   145  			})
   146  
   147  			req.Handlers.Send.PushBack(func(r *request.Request) {
   148  				if e, a := c.expectedEndpoint, r.HTTPRequest.URL.String(); e != a {
   149  					t.Errorf("expected %q, but received %q", e, a)
   150  				}
   151  			})
   152  
   153  			if err := req.Send(); err != nil {
   154  				t.Fatal(err)
   155  			}
   156  		})
   157  	}
   158  }
   159  
   160  func TestEndpointDiscovery(t *testing.T) {
   161  	svc := New(unit.Session, &aws.Config{
   162  		EnableEndpointDiscovery: aws.Bool(true),
   163  	})
   164  	svc.Handlers.Clear()
   165  	svc.Handlers.Send.PushBack(mockSendDescEndpoint("http://foo"))
   166  
   167  	var descCount int32
   168  	svc.Handlers.Complete.PushBack(func(r *request.Request) {
   169  		if r.Operation.Name != opDescribeEndpoints {
   170  			return
   171  		}
   172  		atomic.AddInt32(&descCount, 1)
   173  	})
   174  
   175  	for i := 0; i < 2; i++ {
   176  		req, _ := svc.TestDiscoveryIdentifiersRequiredRequest(
   177  			&TestDiscoveryIdentifiersRequiredInput{
   178  				Sdk: aws.String("sdk"),
   179  			},
   180  		)
   181  		req.Handlers.Send.PushBack(func(r *request.Request) {
   182  			if e, a := "http://foo", r.HTTPRequest.URL.String(); e != a {
   183  				t.Errorf("expected %q, but received %q", e, a)
   184  			}
   185  		})
   186  		if err := req.Send(); err != nil {
   187  			t.Fatal(err)
   188  		}
   189  	}
   190  
   191  	if e, a := int32(1), atomic.LoadInt32(&descCount); e != a {
   192  		t.Errorf("expect desc endpoint called %d, got %d", e, a)
   193  	}
   194  }
   195  
   196  func TestAsyncEndpointDiscovery(t *testing.T) {
   197  	t.Parallel()
   198  
   199  	svc := New(unit.Session, &aws.Config{
   200  		EnableEndpointDiscovery: aws.Bool(true),
   201  	})
   202  	svc.Handlers.Clear()
   203  
   204  	var firstAsyncReq sync.WaitGroup
   205  	firstAsyncReq.Add(1)
   206  	svc.Handlers.Build.PushBack(func(r *request.Request) {
   207  		if r.Operation.Name == opDescribeEndpoints {
   208  			firstAsyncReq.Wait()
   209  		}
   210  	})
   211  	svc.Handlers.Send.PushBack(mockSendDescEndpoint("http://foo"))
   212  
   213  	req, _ := svc.TestDiscoveryOptionalRequest(&TestDiscoveryOptionalInput{
   214  		Sdk: aws.String("sdk"),
   215  	})
   216  	const clientHost = "awsendpointdiscoverytestservice.mock-region.amazonaws.com"
   217  	req.Handlers.Send.PushBack(func(r *request.Request) {
   218  		if e, a := clientHost, r.HTTPRequest.URL.Host; e != a {
   219  			t.Errorf("expected %q, but received %q", e, a)
   220  		}
   221  	})
   222  	req.Handlers.Complete.PushBack(func(r *request.Request) {
   223  		firstAsyncReq.Done()
   224  	})
   225  	if err := req.Send(); err != nil {
   226  		t.Fatal(err)
   227  	}
   228  
   229  	var cacheUpdated bool
   230  	for s := time.Now().Add(10 * time.Second); s.After(time.Now()); {
   231  		// Wait for the cache to be updated before making second request.
   232  		if svc.endpointCache.Has(req.Operation.Name) {
   233  			cacheUpdated = true
   234  			break
   235  		}
   236  		time.Sleep(10 * time.Millisecond)
   237  	}
   238  	if !cacheUpdated {
   239  		t.Fatalf("expect endpoint cache to be updated, was not")
   240  	}
   241  
   242  	req, _ = svc.TestDiscoveryOptionalRequest(&TestDiscoveryOptionalInput{
   243  		Sdk: aws.String("sdk"),
   244  	})
   245  	req.Handlers.Send.PushBack(func(r *request.Request) {
   246  		if e, a := "http://foo", r.HTTPRequest.URL.String(); e != a {
   247  			t.Errorf("expected %q, but received %q", e, a)
   248  		}
   249  	})
   250  	if err := req.Send(); err != nil {
   251  		t.Fatal(err)
   252  	}
   253  }
   254  
   255  func TestEndpointDiscovery_EndpointScheme(t *testing.T) {
   256  	cases := []struct {
   257  		address         string
   258  		expectedAddress string
   259  		err             string
   260  	}{
   261  		0: {
   262  			address:         "https://foo",
   263  			expectedAddress: "https://foo",
   264  		},
   265  		1: {
   266  			address:         "bar",
   267  			expectedAddress: "https://bar",
   268  		},
   269  	}
   270  
   271  	for i, c := range cases {
   272  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   273  			svc := New(unit.Session, &aws.Config{
   274  				EnableEndpointDiscovery: aws.Bool(true),
   275  			})
   276  			svc.Handlers.Clear()
   277  			svc.Handlers.Send.PushBack(mockSendDescEndpoint(c.address))
   278  
   279  			for i := 0; i < 2; i++ {
   280  				req, _ := svc.TestDiscoveryIdentifiersRequiredRequest(
   281  					&TestDiscoveryIdentifiersRequiredInput{
   282  						Sdk: aws.String("sdk"),
   283  					},
   284  				)
   285  				req.Handlers.Send.PushBack(func(r *request.Request) {
   286  					if len(c.err) == 0 {
   287  						if e, a := c.expectedAddress, r.HTTPRequest.URL.String(); e != a {
   288  							t.Errorf("expected %q, but received %q", e, a)
   289  						}
   290  					}
   291  				})
   292  
   293  				err := req.Send()
   294  				if err != nil && len(c.err) == 0 {
   295  					t.Fatalf("expected no error, got %v", err)
   296  				} else if err == nil && len(c.err) > 0 {
   297  					t.Fatalf("expected error, got none")
   298  				} else if err != nil && len(c.err) > 0 {
   299  					if e, a := c.err, err.Error(); !strings.Contains(a, e) {
   300  						t.Fatalf("expected %v, got %v", c.err, err)
   301  					}
   302  				}
   303  			}
   304  		})
   305  	}
   306  }
   307  
   308  func removeHandlers(h request.Handlers, removeSendHandlers bool) request.Handlers {
   309  	if removeSendHandlers {
   310  		h.Send.Clear()
   311  	}
   312  	h.Unmarshal.Clear()
   313  	h.UnmarshalStream.Clear()
   314  	h.UnmarshalMeta.Clear()
   315  	h.UnmarshalError.Clear()
   316  	h.Validate.Clear()
   317  	h.Complete.Clear()
   318  	h.ValidateResponse.Clear()
   319  	return h
   320  }
   321  
   322  func mockSendDescEndpoint(address string) func(r *request.Request) {
   323  	return func(r *request.Request) {
   324  		if r.Operation.Name != opDescribeEndpoints {
   325  			return
   326  		}
   327  
   328  		out, _ := r.Data.(*DescribeEndpointsOutput)
   329  		out.Endpoints = []*Endpoint{
   330  			{
   331  				Address:              &address,
   332  				CachePeriodInMinutes: aws.Int64(5),
   333  			},
   334  		}
   335  		r.Data = out
   336  	}
   337  }