github.com/aavshr/aws-sdk-go@v1.41.3/private/model/cli/gen-protocol-tests/main.go (about) 1 //go:build codegen 2 // +build codegen 3 4 package main 5 6 import ( 7 "bytes" 8 "encoding/json" 9 "fmt" 10 "net/url" 11 "os" 12 "os/exec" 13 "reflect" 14 "regexp" 15 "sort" 16 "strconv" 17 "strings" 18 "text/template" 19 20 "github.com/aavshr/aws-sdk-go/private/model/api" 21 "github.com/aavshr/aws-sdk-go/private/util" 22 ) 23 24 // TestSuiteTypeInput input test 25 // TestSuiteTypeInput output test 26 const ( 27 TestSuiteTypeInput = iota 28 TestSuiteTypeOutput 29 ) 30 31 type testSuite struct { 32 *api.API 33 Description string 34 ClientEndpoint string 35 Cases []testCase 36 Type uint 37 title string 38 } 39 40 func (s *testSuite) UnmarshalJSON(p []byte) error { 41 type stub testSuite 42 43 var v stub 44 if err := json.Unmarshal(p, &v); err != nil { 45 return err 46 } 47 48 if len(v.ClientEndpoint) == 0 { 49 v.ClientEndpoint = "https://test" 50 } 51 for i := 0; i < len(v.Cases); i++ { 52 if len(v.Cases[i].InputTest.Host) == 0 { 53 v.Cases[i].InputTest.Host = "test" 54 } 55 if len(v.Cases[i].InputTest.URI) == 0 { 56 v.Cases[i].InputTest.URI = "/" 57 } 58 } 59 60 *s = testSuite(v) 61 return nil 62 } 63 64 type testCase struct { 65 TestSuite *testSuite 66 Given *api.Operation 67 Params interface{} `json:",omitempty"` 68 Data interface{} `json:"result,omitempty"` 69 InputTest testExpectation `json:"serialized"` 70 OutputTest testExpectation `json:"response"` 71 } 72 73 type testExpectation struct { 74 Body string 75 Host string 76 URI string 77 Headers map[string]string 78 ForbidHeaders []string 79 JSONValues map[string]string 80 StatusCode uint `json:"status_code"` 81 } 82 83 const preamble = ` 84 var _ bytes.Buffer // always import bytes 85 var _ http.Request 86 var _ json.Marshaler 87 var _ time.Time 88 var _ xmlutil.XMLNode 89 var _ xml.Attr 90 var _ = ioutil.Discard 91 var _ = util.Trim("") 92 var _ = url.Values{} 93 var _ = io.EOF 94 var _ = aws.String 95 var _ = fmt.Println 96 var _ = reflect.Value{} 97 98 func init() { 99 protocol.RandReader = &awstesting.ZeroReader{} 100 } 101 ` 102 103 var reStripSpace = regexp.MustCompile(`\s(\w)`) 104 105 var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`) 106 107 func removeImports(code string) string { 108 return reImportRemoval.ReplaceAllString(code, "") 109 } 110 111 var extraImports = []string{ 112 "bytes", 113 "encoding/json", 114 "encoding/xml", 115 "fmt", 116 "io", 117 "io/ioutil", 118 "net/http", 119 "testing", 120 "time", 121 "reflect", 122 "net/url", 123 "", 124 "github.com/aavshr/aws-sdk-go/awstesting", 125 "github.com/aavshr/aws-sdk-go/awstesting/unit", 126 "github.com/aavshr/aws-sdk-go/private/protocol", 127 "github.com/aavshr/aws-sdk-go/private/protocol/xml/xmlutil", 128 "github.com/aavshr/aws-sdk-go/private/util", 129 } 130 131 func addImports(code string) string { 132 importNames := make([]string, len(extraImports)) 133 for i, n := range extraImports { 134 if n != "" { 135 importNames[i] = fmt.Sprintf("%q", n) 136 } 137 } 138 139 str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)") 140 return str 141 } 142 143 func (t *testSuite) TestSuite() string { 144 var buf bytes.Buffer 145 146 t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string { 147 return strings.ToUpper(x[1:]) 148 }) 149 t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "") 150 151 for idx, c := range t.Cases { 152 c.TestSuite = t 153 buf.WriteString(c.TestCase(idx) + "\n") 154 } 155 return buf.String() 156 } 157 158 var tplInputTestCase = template.Must(template.New("inputcase"). 159 Funcs(template.FuncMap{ 160 "stringsEqualFold": strings.EqualFold, 161 }). 162 Parse(` 163 func Test{{ .OpName }}(t *testing.T) { 164 svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("{{ .TestCase.TestSuite.ClientEndpoint }}")}) 165 {{ if ne .ParamsString "" }}input := {{ .ParamsString }} 166 {{ range $k, $v := .JSONValues -}} 167 input.{{ $k }} = {{ $v }} 168 {{ end -}} 169 req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }} 170 r := req.HTTPRequest 171 172 // build request 173 req.Build() 174 if req.Error != nil { 175 t.Errorf("expect no error, got %v", req.Error) 176 } 177 req.Sign() 178 if req.Error != nil { 179 t.Errorf("expect no error, got %v", req.Error) 180 } 181 182 {{- if ne .TestCase.InputTest.Body "" }} 183 184 // assert body 185 if r.Body == nil { 186 t.Errorf("expect body not to be nil") 187 } 188 {{ .BodyAssertions }} 189 190 if e, a := int64(len(body)), r.ContentLength; e != a { 191 t.Errorf("expect serialized body length to match %v ContentLength, got %v", e, a) 192 } 193 {{- end }} 194 195 // assert URL 196 awstesting.AssertURL(t, "https://{{ .TestCase.InputTest.Host }}{{ .TestCase.InputTest.URI }}", r.URL.String()) 197 198 {{- if .TestCase.InputTest.Headers }} 199 200 // assert headers 201 {{- range $k, $v := .TestCase.InputTest.Headers }} 202 {{- if not (stringsEqualFold $k "Content-Length") }} 203 if e, a := "{{ $v }}", r.Header.Get("{{ $k }}"); e != a { 204 t.Errorf("expect {{ $k }} %v header value, got %v", e, a) 205 } 206 {{- end }} 207 {{- end }} 208 {{- end }} 209 210 211 {{- if .TestCase.InputTest.ForbidHeaders }} 212 213 // assert exclude headers 214 {{- range $_, $k := .TestCase.InputTest.ForbidHeaders }} 215 {{- if eq $k "Content-Length" }} 216 if v := r.ContentLength; v > 0 { 217 t.Errorf("expect no content-length, got %v", v) 218 } 219 {{- end }} 220 if v := r.Header.Get("{{ $k }}"); v != "" { 221 t.Errorf("expect not to have {{ $k }} header, got with value %v", v) 222 } 223 {{- end }} 224 {{- end }} 225 } 226 `)) 227 228 type tplInputTestCaseData struct { 229 TestCase *testCase 230 JSONValues map[string]string 231 OpName, ParamsString string 232 } 233 234 func (t tplInputTestCaseData) BodyAssertions() string { 235 code := &bytes.Buffer{} 236 protocol := t.TestCase.TestSuite.API.Metadata.Protocol 237 238 // Extract the body bytes 239 fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)") 240 241 // Generate the body verification code 242 expectedBody := util.Trim(t.TestCase.InputTest.Body) 243 switch protocol { 244 case "ec2", "query": 245 fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))", 246 expectedBody) 247 case "rest-xml": 248 if strings.HasPrefix(expectedBody, "<") { 249 fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)))", 250 expectedBody) 251 } else { 252 code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) 253 } 254 case "json", "jsonrpc", "rest-json": 255 if strings.HasPrefix(expectedBody, "{") { 256 fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))", 257 expectedBody) 258 } else { 259 code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) 260 } 261 default: 262 code.WriteString(fmtAssertEqual(expectedBody, "util.Trim(string(body))")) 263 } 264 265 return code.String() 266 } 267 268 func fmtAssertEqual(e, a string) string { 269 const format = `if e, a := %s, %s; e != a { 270 t.Errorf("expect %%v, got %%v", e, a) 271 } 272 ` 273 274 return fmt.Sprintf(format, e, a) 275 } 276 277 func fmtAssertNil(v string) string { 278 const format = `if e := %s; e != nil { 279 t.Errorf("expect nil, got %%v", e) 280 } 281 ` 282 283 return fmt.Sprintf(format, v) 284 } 285 286 var tplOutputTestCase = template.Must(template.New("outputcase").Parse(` 287 func Test{{ .OpName }}(t *testing.T) { 288 svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")}) 289 290 buf := bytes.NewReader([]byte({{ .Body }})) 291 req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil) 292 req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}} 293 294 // set headers 295 {{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}") 296 {{ end }} 297 298 // unmarshal response 299 req.Handlers.UnmarshalMeta.Run(req) 300 req.Handlers.Unmarshal.Run(req) 301 if req.Error != nil { 302 t.Errorf("expect not error, got %v", req.Error) 303 } 304 305 // assert response 306 if out == nil { 307 t.Errorf("expect not to be nil") 308 } 309 {{ .Assertions }} 310 } 311 `)) 312 313 type tplOutputTestCaseData struct { 314 TestCase *testCase 315 Body, OpName, Assertions string 316 } 317 318 func (i *testCase) TestCase(idx int) string { 319 var buf bytes.Buffer 320 321 opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1) 322 323 if i.TestSuite.Type == TestSuiteTypeInput { // input test 324 // query test should sort body as form encoded values 325 switch i.TestSuite.API.Metadata.Protocol { 326 case "query", "ec2": 327 m, _ := url.ParseQuery(i.InputTest.Body) 328 i.InputTest.Body = m.Encode() 329 case "rest-xml": 330 // Nothing to do 331 case "json", "rest-json": 332 // Nothing to do 333 } 334 335 jsonValues := buildJSONValues(i.Given.InputRef.Shape) 336 var params interface{} 337 if m, ok := i.Params.(map[string]interface{}); ok { 338 paramsMap := map[string]interface{}{} 339 for k, v := range m { 340 if _, ok := jsonValues[k]; !ok { 341 paramsMap[k] = v 342 } else { 343 if i.InputTest.JSONValues == nil { 344 i.InputTest.JSONValues = map[string]string{} 345 } 346 i.InputTest.JSONValues[k] = serializeJSONValue(v.(map[string]interface{})) 347 } 348 } 349 params = paramsMap 350 } else { 351 params = i.Params 352 } 353 input := tplInputTestCaseData{ 354 TestCase: i, 355 OpName: strings.ToUpper(opName[0:1]) + opName[1:], 356 ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false), 357 JSONValues: i.InputTest.JSONValues, 358 } 359 360 if err := tplInputTestCase.Execute(&buf, input); err != nil { 361 panic(err) 362 } 363 } else if i.TestSuite.Type == TestSuiteTypeOutput { 364 output := tplOutputTestCaseData{ 365 TestCase: i, 366 Body: fmt.Sprintf("%q", i.OutputTest.Body), 367 OpName: strings.ToUpper(opName[0:1]) + opName[1:], 368 Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"), 369 } 370 371 if err := tplOutputTestCase.Execute(&buf, output); err != nil { 372 panic(err) 373 } 374 } 375 376 return buf.String() 377 } 378 379 func serializeJSONValue(m map[string]interface{}) string { 380 str := "aws.JSONValue" 381 str += walkMap(m) 382 return str 383 } 384 385 func walkMap(m map[string]interface{}) string { 386 str := "{" 387 for k, v := range m { 388 str += fmt.Sprintf("%q:", k) 389 switch v.(type) { 390 case bool: 391 str += fmt.Sprintf("%t,\n", v.(bool)) 392 case string: 393 str += fmt.Sprintf("%q,\n", v.(string)) 394 case int: 395 str += fmt.Sprintf("%d,\n", v.(int)) 396 case float64: 397 str += fmt.Sprintf("%f,\n", v.(float64)) 398 case map[string]interface{}: 399 str += walkMap(v.(map[string]interface{})) 400 } 401 } 402 str += "}" 403 return str 404 } 405 406 func buildJSONValues(shape *api.Shape) map[string]struct{} { 407 keys := map[string]struct{}{} 408 for key, field := range shape.MemberRefs { 409 if field.JSONValue { 410 keys[key] = struct{}{} 411 } 412 } 413 return keys 414 } 415 416 // generateTestSuite generates a protocol test suite for a given configuration 417 // JSON protocol test file. 418 func generateTestSuite(filename string) string { 419 inout := "Input" 420 if strings.Contains(filename, "output/") { 421 inout = "Output" 422 } 423 424 var suites []testSuite 425 f, err := os.Open(filename) 426 if err != nil { 427 panic(err) 428 } 429 430 err = json.NewDecoder(f).Decode(&suites) 431 if err != nil { 432 panic(err) 433 } 434 435 var buf bytes.Buffer 436 buf.WriteString("// Code generated by models/protocol_tests/generate.go. DO NOT EDIT.\n\n") 437 buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n") 438 439 var innerBuf bytes.Buffer 440 innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n") 441 442 for i, suite := range suites { 443 svcPrefix := inout + "Service" + strconv.Itoa(i+1) 444 suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest" 445 suite.API.Operations = map[string]*api.Operation{} 446 for idx, c := range suite.Cases { 447 c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1) 448 suite.API.Operations[c.Given.ExportedName] = c.Given 449 } 450 451 suite.Type = getType(inout) 452 suite.API.NoInitMethods = true // don't generate init methods 453 suite.API.NoStringerMethods = true // don't generate stringer methods 454 suite.API.NoConstServiceNames = true // don't generate service names 455 suite.API.Setup() 456 suite.API.Metadata.EndpointPrefix = suite.API.PackageName() 457 suite.API.Metadata.EndpointsID = suite.API.Metadata.EndpointPrefix 458 459 // Sort in order for deterministic test generation 460 names := make([]string, 0, len(suite.API.Shapes)) 461 for n := range suite.API.Shapes { 462 names = append(names, n) 463 } 464 sort.Strings(names) 465 for _, name := range names { 466 s := suite.API.Shapes[name] 467 s.Rename(svcPrefix + "TestShape" + name) 468 } 469 470 svcCode := addImports(suite.API.ServiceGoCode()) 471 if i == 0 { 472 importMatch := reImportRemoval.FindStringSubmatch(svcCode) 473 buf.WriteString(importMatch[0] + "\n\n") 474 buf.WriteString(preamble + "\n\n") 475 } 476 svcCode = removeImports(svcCode) 477 svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1) 478 svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1) 479 svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1) 480 buf.WriteString(svcCode + "\n\n") 481 482 apiCode := removeImports(suite.API.APIGoCode()) 483 apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1) 484 apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1) 485 apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1) 486 buf.WriteString(apiCode + "\n\n") 487 488 innerBuf.WriteString(suite.TestSuite() + "\n") 489 } 490 491 return buf.String() + innerBuf.String() 492 } 493 494 // findMember searches the shape for the member with the matching key name. 495 func findMember(shape *api.Shape, key string) string { 496 for actualKey := range shape.MemberRefs { 497 if strings.EqualFold(key, actualKey) { 498 return actualKey 499 } 500 } 501 return "" 502 } 503 504 // GenerateAssertions builds assertions for a shape based on its type. 505 // 506 // The shape's recursive values also will have assertions generated for them. 507 func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string { 508 if shape == nil { 509 return "" 510 } 511 switch t := out.(type) { 512 case map[string]interface{}: 513 keys := util.SortedKeys(t) 514 515 code := "" 516 if shape.Type == "map" { 517 for _, k := range keys { 518 v := t[k] 519 s := shape.ValueRef.Shape 520 code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]") 521 } 522 } else if shape.Type == "jsonvalue" { 523 code += fmt.Sprintf("reflect.DeepEqual(%s, map[string]interface{}%s)\n", prefix, walkMap(out.(map[string]interface{}))) 524 } else { 525 for _, k := range keys { 526 v := t[k] 527 m := findMember(shape, k) 528 s := shape.MemberRefs[m].Shape 529 code += GenerateAssertions(v, s, prefix+"."+m+"") 530 } 531 } 532 return code 533 case []interface{}: 534 code := "" 535 for i, v := range t { 536 s := shape.MemberRef.Shape 537 code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]") 538 } 539 return code 540 default: 541 switch shape.Type { 542 case "timestamp": 543 return fmtAssertEqual( 544 fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out), 545 fmt.Sprintf("%s.UTC().String()", prefix), 546 ) 547 case "blob": 548 return fmtAssertEqual( 549 fmt.Sprintf("%#v", out), 550 fmt.Sprintf("string(%s)", prefix), 551 ) 552 case "integer", "long": 553 return fmtAssertEqual( 554 fmt.Sprintf("int64(%#v)", out), 555 fmt.Sprintf("*%s", prefix), 556 ) 557 default: 558 if !reflect.ValueOf(out).IsValid() { 559 return fmtAssertNil(prefix) 560 } 561 return fmtAssertEqual( 562 fmt.Sprintf("%#v", out), 563 fmt.Sprintf("*%s", prefix), 564 ) 565 } 566 } 567 } 568 569 func getType(t string) uint { 570 switch t { 571 case "Input": 572 return TestSuiteTypeInput 573 case "Output": 574 return TestSuiteTypeOutput 575 default: 576 panic("Invalid type for test suite") 577 } 578 } 579 580 func main() { 581 if len(os.Getenv("AWS_SDK_CODEGEN_DEBUG")) != 0 { 582 api.LogDebug(os.Stdout) 583 } 584 585 fmt.Println("Generating test suite", os.Args[1:]) 586 out := generateTestSuite(os.Args[1]) 587 if len(os.Args) == 3 { 588 f, err := os.Create(os.Args[2]) 589 defer f.Close() 590 if err != nil { 591 panic(err) 592 } 593 f.WriteString(util.GoFmt(out)) 594 f.Close() 595 596 c := exec.Command("gofmt", "-s", "-w", os.Args[2]) 597 if err := c.Run(); err != nil { 598 panic(err) 599 } 600 } else { 601 fmt.Println(out) 602 } 603 }