github.com/aavshr/aws-sdk-go@v1.41.3/private/model/cli/gen-protocol-tests/main.go (about)

     1  //go:build codegen
     2  // +build codegen
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"fmt"
    10  	"net/url"
    11  	"os"
    12  	"os/exec"
    13  	"reflect"
    14  	"regexp"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"text/template"
    19  
    20  	"github.com/aavshr/aws-sdk-go/private/model/api"
    21  	"github.com/aavshr/aws-sdk-go/private/util"
    22  )
    23  
    24  // TestSuiteTypeInput input test
    25  // TestSuiteTypeInput output test
    26  const (
    27  	TestSuiteTypeInput = iota
    28  	TestSuiteTypeOutput
    29  )
    30  
    31  type testSuite struct {
    32  	*api.API
    33  	Description    string
    34  	ClientEndpoint string
    35  	Cases          []testCase
    36  	Type           uint
    37  	title          string
    38  }
    39  
    40  func (s *testSuite) UnmarshalJSON(p []byte) error {
    41  	type stub testSuite
    42  
    43  	var v stub
    44  	if err := json.Unmarshal(p, &v); err != nil {
    45  		return err
    46  	}
    47  
    48  	if len(v.ClientEndpoint) == 0 {
    49  		v.ClientEndpoint = "https://test"
    50  	}
    51  	for i := 0; i < len(v.Cases); i++ {
    52  		if len(v.Cases[i].InputTest.Host) == 0 {
    53  			v.Cases[i].InputTest.Host = "test"
    54  		}
    55  		if len(v.Cases[i].InputTest.URI) == 0 {
    56  			v.Cases[i].InputTest.URI = "/"
    57  		}
    58  	}
    59  
    60  	*s = testSuite(v)
    61  	return nil
    62  }
    63  
    64  type testCase struct {
    65  	TestSuite  *testSuite
    66  	Given      *api.Operation
    67  	Params     interface{}     `json:",omitempty"`
    68  	Data       interface{}     `json:"result,omitempty"`
    69  	InputTest  testExpectation `json:"serialized"`
    70  	OutputTest testExpectation `json:"response"`
    71  }
    72  
    73  type testExpectation struct {
    74  	Body          string
    75  	Host          string
    76  	URI           string
    77  	Headers       map[string]string
    78  	ForbidHeaders []string
    79  	JSONValues    map[string]string
    80  	StatusCode    uint `json:"status_code"`
    81  }
    82  
// preamble is written once into each generated file, right after the import
// block of the first suite (see generateTestSuite).
//
// The blank-identifier declarations reference every imported package, so
// each import counts as used even when no generated test touches it.
// init swaps protocol.RandReader for awstesting.ZeroReader. This presumably
// makes any randomly generated request values deterministic in the tests
// (TODO confirm).
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
var _ = aws.String
var _ = fmt.Println
var _ = reflect.Value{}

func init() {
	protocol.RandReader = &awstesting.ZeroReader{}
}
`
   102  
   103  var reStripSpace = regexp.MustCompile(`\s(\w)`)
   104  
   105  var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
   106  
   107  func removeImports(code string) string {
   108  	return reImportRemoval.ReplaceAllString(code, "")
   109  }
   110  
// extraImports lists the packages addImports inserts, in this order, at the
// start of the generated code's import block, ahead of the block's original
// imports. The "" entry is emitted as a blank line, separating the standard
// library group from the SDK packages. These are the packages the preamble's
// blank-identifier declarations reference.
var extraImports = []string{
	"bytes",
	"encoding/json",
	"encoding/xml",
	"fmt",
	"io",
	"io/ioutil",
	"net/http",
	"testing",
	"time",
	"reflect",
	"net/url",
	"",
	"github.com/aavshr/aws-sdk-go/awstesting",
	"github.com/aavshr/aws-sdk-go/awstesting/unit",
	"github.com/aavshr/aws-sdk-go/private/protocol",
	"github.com/aavshr/aws-sdk-go/private/protocol/xml/xmlutil",
	"github.com/aavshr/aws-sdk-go/private/util",
}
   130  
   131  func addImports(code string) string {
   132  	importNames := make([]string, len(extraImports))
   133  	for i, n := range extraImports {
   134  		if n != "" {
   135  			importNames[i] = fmt.Sprintf("%q", n)
   136  		}
   137  	}
   138  
   139  	str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)")
   140  	return str
   141  }
   142  
   143  func (t *testSuite) TestSuite() string {
   144  	var buf bytes.Buffer
   145  
   146  	t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
   147  		return strings.ToUpper(x[1:])
   148  	})
   149  	t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
   150  
   151  	for idx, c := range t.Cases {
   152  		c.TestSuite = t
   153  		buf.WriteString(c.TestCase(idx) + "\n")
   154  	}
   155  	return buf.String()
   156  }
   157  
// tplInputTestCase renders one input (request serialization) test function
// and is executed with a tplInputTestCaseData value.
//
// The generated test builds the operation's input from .ParamsString and
// then assigns each .JSONValues entry to its field. With no params it passes
// nil. It then builds and signs the request and asserts, in order:
//   - the serialized body, using .BodyAssertions (only when the fixture has
//     an expected body), plus body length against ContentLength;
//   - the request URL;
//   - each expected header value, skipping Content-Length;
//   - the absence of forbidden headers. A forbidden Content-Length is also
//     checked against r.ContentLength.
//
// NOTE(review): header names and values are placed inside Go string
// literals without escaping, so a fixture value containing `"` or `\` would
// produce invalid generated code.
var tplInputTestCase = template.Must(template.New("inputcase").
	Funcs(template.FuncMap{
		"stringsEqualFold": strings.EqualFold,
	}).
	Parse(`
func Test{{ .OpName }}(t *testing.T) {
	svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("{{ .TestCase.TestSuite.ClientEndpoint  }}")})
	{{ if ne .ParamsString "" }}input := {{ .ParamsString }}
	{{ range $k, $v := .JSONValues -}}
	input.{{ $k }} = {{ $v }} 
	{{ end -}}
	req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }}
	r := req.HTTPRequest

	// build request
	req.Build()
	if req.Error != nil {
		t.Errorf("expect no error, got %v", req.Error)
	}
	req.Sign()
	if req.Error != nil {
		t.Errorf("expect no error, got %v", req.Error)
	}

	{{- if ne .TestCase.InputTest.Body "" }}

		// assert body
		if r.Body == nil {
			t.Errorf("expect body not to be nil")
		}
		{{ .BodyAssertions }}

		if e, a := int64(len(body)), r.ContentLength; e != a {
			t.Errorf("expect serialized body length to match %v ContentLength, got %v", e, a)
		}
	{{- end }}

	// assert URL
	awstesting.AssertURL(t, "https://{{ .TestCase.InputTest.Host }}{{ .TestCase.InputTest.URI }}", r.URL.String())

	{{- if .TestCase.InputTest.Headers }}

		// assert headers
		{{- range $k, $v := .TestCase.InputTest.Headers }}
			{{- if not (stringsEqualFold $k "Content-Length") }}
				if e, a := "{{ $v }}", r.Header.Get("{{ $k }}"); e != a {
					t.Errorf("expect {{ $k }} %v header value, got %v", e, a)
				}
			{{- end }}
		{{- end }}
	{{- end }}


	{{- if .TestCase.InputTest.ForbidHeaders }}

		// assert exclude headers
		{{- range $_, $k := .TestCase.InputTest.ForbidHeaders }}
			{{- if eq $k "Content-Length" }}
				if v := r.ContentLength; v > 0 {
					t.Errorf("expect no content-length, got %v", v)
				}
			{{- end }}
			if v := r.Header.Get("{{ $k }}"); v != "" {
				t.Errorf("expect not to have {{ $k }} header, got with value %v", v)
			}
		{{- end }}
	{{- end }}
}
`))
   227  
   228  type tplInputTestCaseData struct {
   229  	TestCase             *testCase
   230  	JSONValues           map[string]string
   231  	OpName, ParamsString string
   232  }
   233  
   234  func (t tplInputTestCaseData) BodyAssertions() string {
   235  	code := &bytes.Buffer{}
   236  	protocol := t.TestCase.TestSuite.API.Metadata.Protocol
   237  
   238  	// Extract the body bytes
   239  	fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)")
   240  
   241  	// Generate the body verification code
   242  	expectedBody := util.Trim(t.TestCase.InputTest.Body)
   243  	switch protocol {
   244  	case "ec2", "query":
   245  		fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))",
   246  			expectedBody)
   247  	case "rest-xml":
   248  		if strings.HasPrefix(expectedBody, "<") {
   249  			fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)))",
   250  				expectedBody)
   251  		} else {
   252  			code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))"))
   253  		}
   254  	case "json", "jsonrpc", "rest-json":
   255  		if strings.HasPrefix(expectedBody, "{") {
   256  			fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))",
   257  				expectedBody)
   258  		} else {
   259  			code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))"))
   260  		}
   261  	default:
   262  		code.WriteString(fmtAssertEqual(expectedBody, "util.Trim(string(body))"))
   263  	}
   264  
   265  	return code.String()
   266  }
   267  
   268  func fmtAssertEqual(e, a string) string {
   269  	const format = `if e, a := %s, %s; e != a {
   270  		t.Errorf("expect %%v, got %%v", e, a)
   271  	}
   272  	`
   273  
   274  	return fmt.Sprintf(format, e, a)
   275  }
   276  
   277  func fmtAssertNil(v string) string {
   278  	const format = `if e := %s; e != nil {
   279  		t.Errorf("expect nil, got %%v", e)
   280  	}
   281  	`
   282  
   283  	return fmt.Sprintf(format, v)
   284  }
   285  
// tplOutputTestCase renders one output (response unmarshaling) test function
// and is executed with a tplOutputTestCaseData value.
//
// The generated test builds a 200 response from .Body and the fixture's
// response headers. It runs the request's UnmarshalMeta and Unmarshal
// handlers, then checks that out is non-nil and applies .Assertions.
//
// NOTE(review): unlike tplInputTestCase, the endpoint is hard-coded to
// "https://test" instead of the suite's ClientEndpoint. The status code is
// always 200 (testExpectation.StatusCode is not used). Header values are
// inserted into string literals without escaping. Confirm these are intended.
var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
	svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})

	buf := bytes.NewReader([]byte({{ .Body }}))
	req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
	req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}

	// set headers
	{{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}")
	{{ end }}

	// unmarshal response
	req.Handlers.UnmarshalMeta.Run(req)
	req.Handlers.Unmarshal.Run(req)
	if req.Error != nil {
		t.Errorf("expect not error, got %v", req.Error)
	}

	// assert response
	if out == nil {
		t.Errorf("expect not to be nil")
	}
	{{ .Assertions }}
}
`))
   312  
   313  type tplOutputTestCaseData struct {
   314  	TestCase                 *testCase
   315  	Body, OpName, Assertions string
   316  }
   317  
   318  func (i *testCase) TestCase(idx int) string {
   319  	var buf bytes.Buffer
   320  
   321  	opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
   322  
   323  	if i.TestSuite.Type == TestSuiteTypeInput { // input test
   324  		// query test should sort body as form encoded values
   325  		switch i.TestSuite.API.Metadata.Protocol {
   326  		case "query", "ec2":
   327  			m, _ := url.ParseQuery(i.InputTest.Body)
   328  			i.InputTest.Body = m.Encode()
   329  		case "rest-xml":
   330  			// Nothing to do
   331  		case "json", "rest-json":
   332  			// Nothing to do
   333  		}
   334  
   335  		jsonValues := buildJSONValues(i.Given.InputRef.Shape)
   336  		var params interface{}
   337  		if m, ok := i.Params.(map[string]interface{}); ok {
   338  			paramsMap := map[string]interface{}{}
   339  			for k, v := range m {
   340  				if _, ok := jsonValues[k]; !ok {
   341  					paramsMap[k] = v
   342  				} else {
   343  					if i.InputTest.JSONValues == nil {
   344  						i.InputTest.JSONValues = map[string]string{}
   345  					}
   346  					i.InputTest.JSONValues[k] = serializeJSONValue(v.(map[string]interface{}))
   347  				}
   348  			}
   349  			params = paramsMap
   350  		} else {
   351  			params = i.Params
   352  		}
   353  		input := tplInputTestCaseData{
   354  			TestCase:     i,
   355  			OpName:       strings.ToUpper(opName[0:1]) + opName[1:],
   356  			ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false),
   357  			JSONValues:   i.InputTest.JSONValues,
   358  		}
   359  
   360  		if err := tplInputTestCase.Execute(&buf, input); err != nil {
   361  			panic(err)
   362  		}
   363  	} else if i.TestSuite.Type == TestSuiteTypeOutput {
   364  		output := tplOutputTestCaseData{
   365  			TestCase:   i,
   366  			Body:       fmt.Sprintf("%q", i.OutputTest.Body),
   367  			OpName:     strings.ToUpper(opName[0:1]) + opName[1:],
   368  			Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
   369  		}
   370  
   371  		if err := tplOutputTestCase.Execute(&buf, output); err != nil {
   372  			panic(err)
   373  		}
   374  	}
   375  
   376  	return buf.String()
   377  }
   378  
   379  func serializeJSONValue(m map[string]interface{}) string {
   380  	str := "aws.JSONValue"
   381  	str += walkMap(m)
   382  	return str
   383  }
   384  
   385  func walkMap(m map[string]interface{}) string {
   386  	str := "{"
   387  	for k, v := range m {
   388  		str += fmt.Sprintf("%q:", k)
   389  		switch v.(type) {
   390  		case bool:
   391  			str += fmt.Sprintf("%t,\n", v.(bool))
   392  		case string:
   393  			str += fmt.Sprintf("%q,\n", v.(string))
   394  		case int:
   395  			str += fmt.Sprintf("%d,\n", v.(int))
   396  		case float64:
   397  			str += fmt.Sprintf("%f,\n", v.(float64))
   398  		case map[string]interface{}:
   399  			str += walkMap(v.(map[string]interface{}))
   400  		}
   401  	}
   402  	str += "}"
   403  	return str
   404  }
   405  
   406  func buildJSONValues(shape *api.Shape) map[string]struct{} {
   407  	keys := map[string]struct{}{}
   408  	for key, field := range shape.MemberRefs {
   409  		if field.JSONValue {
   410  			keys[key] = struct{}{}
   411  		}
   412  	}
   413  	return keys
   414  }
   415  
   416  // generateTestSuite generates a protocol test suite for a given configuration
   417  // JSON protocol test file.
   418  func generateTestSuite(filename string) string {
   419  	inout := "Input"
   420  	if strings.Contains(filename, "output/") {
   421  		inout = "Output"
   422  	}
   423  
   424  	var suites []testSuite
   425  	f, err := os.Open(filename)
   426  	if err != nil {
   427  		panic(err)
   428  	}
   429  
   430  	err = json.NewDecoder(f).Decode(&suites)
   431  	if err != nil {
   432  		panic(err)
   433  	}
   434  
   435  	var buf bytes.Buffer
   436  	buf.WriteString("// Code generated by models/protocol_tests/generate.go. DO NOT EDIT.\n\n")
   437  	buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
   438  
   439  	var innerBuf bytes.Buffer
   440  	innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
   441  
   442  	for i, suite := range suites {
   443  		svcPrefix := inout + "Service" + strconv.Itoa(i+1)
   444  		suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
   445  		suite.API.Operations = map[string]*api.Operation{}
   446  		for idx, c := range suite.Cases {
   447  			c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
   448  			suite.API.Operations[c.Given.ExportedName] = c.Given
   449  		}
   450  
   451  		suite.Type = getType(inout)
   452  		suite.API.NoInitMethods = true       // don't generate init methods
   453  		suite.API.NoStringerMethods = true   // don't generate stringer methods
   454  		suite.API.NoConstServiceNames = true // don't generate service names
   455  		suite.API.Setup()
   456  		suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
   457  		suite.API.Metadata.EndpointsID = suite.API.Metadata.EndpointPrefix
   458  
   459  		// Sort in order for deterministic test generation
   460  		names := make([]string, 0, len(suite.API.Shapes))
   461  		for n := range suite.API.Shapes {
   462  			names = append(names, n)
   463  		}
   464  		sort.Strings(names)
   465  		for _, name := range names {
   466  			s := suite.API.Shapes[name]
   467  			s.Rename(svcPrefix + "TestShape" + name)
   468  		}
   469  
   470  		svcCode := addImports(suite.API.ServiceGoCode())
   471  		if i == 0 {
   472  			importMatch := reImportRemoval.FindStringSubmatch(svcCode)
   473  			buf.WriteString(importMatch[0] + "\n\n")
   474  			buf.WriteString(preamble + "\n\n")
   475  		}
   476  		svcCode = removeImports(svcCode)
   477  		svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
   478  		svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1)
   479  		svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1)
   480  		buf.WriteString(svcCode + "\n\n")
   481  
   482  		apiCode := removeImports(suite.API.APIGoCode())
   483  		apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
   484  		apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
   485  		apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
   486  		buf.WriteString(apiCode + "\n\n")
   487  
   488  		innerBuf.WriteString(suite.TestSuite() + "\n")
   489  	}
   490  
   491  	return buf.String() + innerBuf.String()
   492  }
   493  
   494  // findMember searches the shape for the member with the matching key name.
   495  func findMember(shape *api.Shape, key string) string {
   496  	for actualKey := range shape.MemberRefs {
   497  		if strings.EqualFold(key, actualKey) {
   498  			return actualKey
   499  		}
   500  	}
   501  	return ""
   502  }
   503  
   504  // GenerateAssertions builds assertions for a shape based on its type.
   505  //
   506  // The shape's recursive values also will have assertions generated for them.
   507  func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string {
   508  	if shape == nil {
   509  		return ""
   510  	}
   511  	switch t := out.(type) {
   512  	case map[string]interface{}:
   513  		keys := util.SortedKeys(t)
   514  
   515  		code := ""
   516  		if shape.Type == "map" {
   517  			for _, k := range keys {
   518  				v := t[k]
   519  				s := shape.ValueRef.Shape
   520  				code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]")
   521  			}
   522  		} else if shape.Type == "jsonvalue" {
   523  			code += fmt.Sprintf("reflect.DeepEqual(%s, map[string]interface{}%s)\n", prefix, walkMap(out.(map[string]interface{})))
   524  		} else {
   525  			for _, k := range keys {
   526  				v := t[k]
   527  				m := findMember(shape, k)
   528  				s := shape.MemberRefs[m].Shape
   529  				code += GenerateAssertions(v, s, prefix+"."+m+"")
   530  			}
   531  		}
   532  		return code
   533  	case []interface{}:
   534  		code := ""
   535  		for i, v := range t {
   536  			s := shape.MemberRef.Shape
   537  			code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]")
   538  		}
   539  		return code
   540  	default:
   541  		switch shape.Type {
   542  		case "timestamp":
   543  			return fmtAssertEqual(
   544  				fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out),
   545  				fmt.Sprintf("%s.UTC().String()", prefix),
   546  			)
   547  		case "blob":
   548  			return fmtAssertEqual(
   549  				fmt.Sprintf("%#v", out),
   550  				fmt.Sprintf("string(%s)", prefix),
   551  			)
   552  		case "integer", "long":
   553  			return fmtAssertEqual(
   554  				fmt.Sprintf("int64(%#v)", out),
   555  				fmt.Sprintf("*%s", prefix),
   556  			)
   557  		default:
   558  			if !reflect.ValueOf(out).IsValid() {
   559  				return fmtAssertNil(prefix)
   560  			}
   561  			return fmtAssertEqual(
   562  				fmt.Sprintf("%#v", out),
   563  				fmt.Sprintf("*%s", prefix),
   564  			)
   565  		}
   566  	}
   567  }
   568  
   569  func getType(t string) uint {
   570  	switch t {
   571  	case "Input":
   572  		return TestSuiteTypeInput
   573  	case "Output":
   574  		return TestSuiteTypeOutput
   575  	default:
   576  		panic("Invalid type for test suite")
   577  	}
   578  }
   579  
   580  func main() {
   581  	if len(os.Getenv("AWS_SDK_CODEGEN_DEBUG")) != 0 {
   582  		api.LogDebug(os.Stdout)
   583  	}
   584  
   585  	fmt.Println("Generating test suite", os.Args[1:])
   586  	out := generateTestSuite(os.Args[1])
   587  	if len(os.Args) == 3 {
   588  		f, err := os.Create(os.Args[2])
   589  		defer f.Close()
   590  		if err != nil {
   591  			panic(err)
   592  		}
   593  		f.WriteString(util.GoFmt(out))
   594  		f.Close()
   595  
   596  		c := exec.Command("gofmt", "-s", "-w", os.Args[2])
   597  		if err := c.Run(); err != nil {
   598  			panic(err)
   599  		}
   600  	} else {
   601  		fmt.Println(out)
   602  	}
   603  }