github.com/aavshr/aws-sdk-go@v1.41.3/private/model/api/smoke.go (about)

     1  //go:build codegen
     2  // +build codegen
     3  
     4  package api
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"fmt"
    10  	"os"
    11  	"sort"
    12  	"text/template"
    13  )
    14  
    15  // SmokeTestSuite defines the test suite for smoke tests.
    16  type SmokeTestSuite struct {
    17  	Version       int             `json:"version"`
    18  	DefaultRegion string          `json:"defaultRegion"`
    19  	TestCases     []SmokeTestCase `json:"testCases"`
    20  }
    21  
    22  // SmokeTestCase provides the definition for a integration smoke test case.
    23  type SmokeTestCase struct {
    24  	OpName    string                 `json:"operationName"`
    25  	Input     map[string]interface{} `json:"input"`
    26  	ExpectErr bool                   `json:"errorExpectedFromService"`
    27  }
    28  
    29  var smokeTestsCustomizations = map[string]func(*SmokeTestSuite) error{
    30  	"sts": stsSmokeTestCustomization,
    31  }
    32  
    33  func stsSmokeTestCustomization(suite *SmokeTestSuite) error {
    34  	const getSessionTokenOp = "GetSessionToken"
    35  	const getCallerIdentityOp = "GetCallerIdentity"
    36  
    37  	opTestMap := make(map[string][]SmokeTestCase)
    38  	for _, testCase := range suite.TestCases {
    39  		opTestMap[testCase.OpName] = append(opTestMap[testCase.OpName], testCase)
    40  	}
    41  
    42  	if _, ok := opTestMap[getSessionTokenOp]; ok {
    43  		delete(opTestMap, getSessionTokenOp)
    44  	}
    45  
    46  	if _, ok := opTestMap[getCallerIdentityOp]; !ok {
    47  		opTestMap[getCallerIdentityOp] = append(opTestMap[getCallerIdentityOp], SmokeTestCase{
    48  			OpName:    getCallerIdentityOp,
    49  			Input:     map[string]interface{}{},
    50  			ExpectErr: false,
    51  		})
    52  	}
    53  
    54  	var testCases []SmokeTestCase
    55  
    56  	var keys []string
    57  	for name := range opTestMap {
    58  		keys = append(keys, name)
    59  	}
    60  	sort.Strings(keys)
    61  	for _, name := range keys {
    62  		testCases = append(testCases, opTestMap[name]...)
    63  	}
    64  
    65  	suite.TestCases = testCases
    66  
    67  	return nil
    68  }
    69  
    70  // BuildInputShape returns the Go code as a string for initializing the test
    71  // case's input shape.
    72  func (c SmokeTestCase) BuildInputShape(ref *ShapeRef) string {
    73  	b := NewShapeValueBuilder()
    74  	return fmt.Sprintf("&%s{\n%s\n}",
    75  		b.GoType(ref, true),
    76  		b.BuildShape(ref, c.Input, false),
    77  	)
    78  }
    79  
    80  // AttachSmokeTests attaches the smoke test cases to the API model.
    81  func (a *API) AttachSmokeTests(filename string) error {
    82  	f, err := os.Open(filename)
    83  	if err != nil {
    84  		return fmt.Errorf("failed to open smoke tests %s, err: %v", filename, err)
    85  	}
    86  	defer f.Close()
    87  
    88  	if err := json.NewDecoder(f).Decode(&a.SmokeTests); err != nil {
    89  		return fmt.Errorf("failed to decode smoke tests %s, err: %v", filename, err)
    90  	}
    91  
    92  	if v := a.SmokeTests.Version; v != 1 {
    93  		return fmt.Errorf("invalid smoke test version, %d", v)
    94  	}
    95  
    96  	if fn, ok := smokeTestsCustomizations[a.PackageName()]; ok {
    97  		if err := fn(&a.SmokeTests); err != nil {
    98  			return err
    99  		}
   100  	}
   101  
   102  	return nil
   103  }
   104  
   105  // APISmokeTestsGoCode returns the Go Code string for the smoke tests.
   106  func (a *API) APISmokeTestsGoCode() string {
   107  	w := bytes.NewBuffer(nil)
   108  
   109  	a.resetImports()
   110  	a.AddImport("context")
   111  	a.AddImport("testing")
   112  	a.AddImport("time")
   113  	a.AddSDKImport("aws")
   114  	a.AddSDKImport("aws/request")
   115  	a.AddSDKImport("aws/awserr")
   116  	a.AddSDKImport("aws/request")
   117  	a.AddSDKImport("awstesting/integration")
   118  	a.AddImport(a.ImportPath())
   119  
   120  	smokeTests := struct {
   121  		API *API
   122  		SmokeTestSuite
   123  	}{
   124  		API:            a,
   125  		SmokeTestSuite: a.SmokeTests,
   126  	}
   127  
   128  	if err := smokeTestTmpl.Execute(w, smokeTests); err != nil {
   129  		panic(fmt.Sprintf("failed to create smoke tests, %v", err))
   130  	}
   131  
   132  	ignoreImports := `
   133  	var _ aws.Config
   134  	var _ awserr.Error
   135  	var _ request.Request
   136  	`
   137  
   138  	return a.importsGoCode() + ignoreImports + w.String()
   139  }
   140  
   141  var smokeTestTmpl = template.Must(template.New(`smokeTestTmpl`).Parse(`
   142  {{- range $i, $testCase := $.TestCases }}
   143  	{{- $op := index $.API.Operations $testCase.OpName }}
   144  	func TestInteg_{{ printf "%02d" $i }}_{{ $op.ExportedName }}(t *testing.T) {
   145  		ctx, cancelFn := context.WithTimeout(context.Background(), 5 *time.Second)
   146  		defer cancelFn()
   147  	
   148  		sess := integration.SessionWithDefaultRegion("{{ $.DefaultRegion }}")
   149  		svc := {{ $.API.PackageName }}.New(sess)
   150  		params := {{ $testCase.BuildInputShape $op.InputRef }}
   151  		_, err := svc.{{ $op.ExportedName }}WithContext(ctx, params, func(r *request.Request) {
   152  			r.Handlers.Validate.RemoveByName("core.ValidateParametersHandler")
   153  		})
   154  		{{- if $testCase.ExpectErr }}
   155  			if err == nil {
   156  				t.Fatalf("expect request to fail")
   157  			}
   158  			aerr, ok := err.(awserr.RequestFailure)
   159  			if !ok {
   160  				t.Fatalf("expect awserr, was %T", err)
   161  			}
   162  			if len(aerr.Code()) == 0 {
   163  				t.Errorf("expect non-empty error code")
   164  			}
   165  			if len(aerr.Message()) == 0 {
   166  				t.Errorf("expect non-empty error message")
   167  			}
   168  			if v := aerr.Code(); v == request.ErrCodeSerialization {
   169  				t.Errorf("expect API error code got serialization failure")
   170  			}
   171  		{{- else }}
   172  			if err != nil {
   173  				t.Errorf("expect no error, got %v", err)
   174  			}
   175  		{{- end }}
   176  	}
   177  {{- end }}
   178  `))