github.com/aavshr/aws-sdk-go@v1.41.3/private/model/api/eventstream_tmpl_readertests.go (about)

     1  //go:build codegen
     2  // +build codegen
     3  
     4  package api
     5  
     6  import (
     7  	"text/template"
     8  )
     9  
    10  var eventStreamReaderTestTmpl = template.Must(
    11  	template.New("eventStreamReaderTestTmpl").Funcs(template.FuncMap{
    12  		"ValueForType":             valueForType,
    13  		"HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers,
    14  		"EventHeaderValueForType":  setEventHeaderValueForType,
    15  		"Map":                      templateMap,
    16  		"OptionalAddInt": func(do bool, a, b int) int {
    17  			if !do {
    18  				return a
    19  			}
    20  			return a + b
    21  		},
    22  		"HasNonEventStreamMember": func(s *Shape) bool {
    23  			for _, ref := range s.MemberRefs {
    24  				if !ref.Shape.IsEventStream {
    25  					return true
    26  				}
    27  			}
    28  			return false
    29  		},
    30  	}).Parse(`
    31  {{ range $opName, $op := $.Operations }}
    32  	{{ if $op.EventStreamAPI }}
    33  		{{ if  $op.EventStreamAPI.OutputStream }}
    34  			{{ template "event stream outputStream tests" $op.EventStreamAPI }}
    35  		{{ end }}
    36  	{{ end }}
    37  {{ end }}
    38  
    39  type loopReader struct {
    40  	source *bytes.Reader
    41  }
    42  
    43  func (c *loopReader) Read(p []byte) (int, error) {
    44  	if c.source.Len() == 0 {
    45  		c.source.Seek(0, 0)
    46  	}
    47  
    48  	return c.source.Read(p)
    49  }
    50  
    51  {{ define "event stream outputStream tests" }}
    52  	func Test{{ $.Operation.ExportedName }}_Read(t *testing.T) {
    53  		expectEvents, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
    54  		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
    55  			eventstreamtest.ServeEventStream{
    56  				T:      t,
    57  				Events: eventMsgs,
    58  			},
    59  			true,
    60  		)
    61  		if err != nil {
    62  			t.Fatalf("expect no error, %v", err)
    63  		}
    64  		defer cleanupFn()
    65  
    66  		svc := New(sess)
    67  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
    68  		if err != nil {
    69  			t.Fatalf("expect no error got, %v", err)
    70  		}
    71  		defer resp.GetStream().Close()
    72  
    73  		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
    74  			{{- if HasNonEventStreamMember $.Operation.OutputRef.Shape }}
    75  				expectResp := expectEvents[0].(*{{ $.Operation.OutputRef.Shape.ShapeName }})
    76  				{{- range $name, $ref := $.Operation.OutputRef.Shape.MemberRefs }}
    77  					{{- if not $ref.Shape.IsEventStream }}
    78  						if e, a := expectResp.{{ $name }}, resp.{{ $name }}; !reflect.DeepEqual(e,a) {
    79  							t.Errorf("expect %v, got %v", e, a)
    80  						}
    81  					{{- end }}
    82  				{{- end }}
    83  			{{- end }}
    84  			// Trim off response output type pseudo event so only event messages remain.
    85  			expectEvents = expectEvents[1:]
    86  		{{ end }}
    87  
    88  		var i int
    89  		for event := range resp.GetStream().Events() {
    90  			if event == nil {
    91  				t.Errorf("%d, expect event, got nil", i)
    92  			}
    93  			if e, a := expectEvents[i], event; !reflect.DeepEqual(e, a) {
    94  				t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
    95  			}
    96  			i++
    97  		}
    98  
    99  		if err := resp.GetStream().Err(); err != nil {
   100  			t.Errorf("expect no error, %v", err)
   101  		}
   102  	}
   103  
   104  	func Test{{ $.Operation.ExportedName }}_ReadClose(t *testing.T) {
   105  		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
   106  		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
   107  			eventstreamtest.ServeEventStream{
   108  				T:      t,
   109  				Events: eventMsgs,
   110  			},
   111  			true,
   112  		)
   113  		if err != nil {
   114  			t.Fatalf("expect no error, %v", err)
   115  		}
   116  		defer cleanupFn()
   117  
   118  		svc := New(sess)
   119  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
   120  		if err != nil {
   121  			t.Fatalf("expect no error got, %v", err)
   122  		}
   123  
   124  		{{ if gt (len $.OutputStream.Events) 0 -}}
   125  			// Assert calling Err before close does not close the stream.
   126  			resp.GetStream().Err()
   127  			select {
   128  			case _, ok := <-resp.GetStream().Events():
   129  				if !ok {
   130  					t.Fatalf("expect stream not to be closed, but was")
   131  				}
   132  			default:
   133  			}
   134  		{{- end }}
   135  
   136  		resp.GetStream().Close()
   137  		<-resp.GetStream().Events()
   138  
   139  		if err := resp.GetStream().Err(); err != nil {
   140  			t.Errorf("expect no error, %v", err)
   141  		}
   142  	}
   143  
   144  	func Test{{ $.Operation.ExportedName }}_ReadUnknownEvent(t *testing.T) {
   145  		expectEvents, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
   146  
   147  		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
   148  			eventOffset := 1
   149  		{{- else }}
   150  			var eventOffset int
   151  		{{- end }}
   152  
   153  		unknownEvent := eventstream.Message{
   154  			Headers: eventstream.Headers{
   155  				eventstreamtest.EventMessageTypeHeader,
   156  				{
   157  					Name:  eventstreamapi.EventTypeHeader,
   158  					Value: eventstream.StringValue("UnknownEventName"),
   159  				},
   160  			},
   161  			Payload: []byte("some unknown event"),
   162  		}
   163  
   164  		eventMsgs = append(eventMsgs[:eventOffset],
   165  			append([]eventstream.Message{unknownEvent}, eventMsgs[eventOffset:]...)...)
   166  
   167  		expectEvents = append(expectEvents[:eventOffset],
   168  			append([]{{ $.OutputStream.Name }}Event{
   169  					&{{ $.OutputStream.StreamUnknownEventName }}{
   170  						Type: "UnknownEventName",
   171  						Message: unknownEvent,
   172  					},
   173  				},
   174  				expectEvents[eventOffset:]...)...)
   175  
   176  		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
   177  			eventstreamtest.ServeEventStream{
   178  				T:      t,
   179  				Events: eventMsgs,
   180  			},
   181  			true,
   182  		)
   183  		if err != nil {
   184  			t.Fatalf("expect no error, %v", err)
   185  		}
   186  		defer cleanupFn()
   187  
   188  		svc := New(sess)
   189  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
   190  		if err != nil {
   191  			t.Fatalf("expect no error got, %v", err)
   192  		}
   193  		defer resp.GetStream().Close()
   194  
   195  		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
   196  			// Trim off response output type pseudo event so only event messages remain.
   197  			expectEvents = expectEvents[1:]
   198  		{{ end }}
   199  
   200  		var i int
   201  		for event := range resp.GetStream().Events() {
   202  			if event == nil {
   203  				t.Errorf("%d, expect event, got nil", i)
   204  			}
   205  			if e, a := expectEvents[i], event; !reflect.DeepEqual(e, a) {
   206  				t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
   207  			}
   208  			i++
   209  		}
   210  
   211  		if err := resp.GetStream().Err(); err != nil {
   212  			t.Errorf("expect no error, %v", err)
   213  		}
   214  	}
   215  
   216  	func Benchmark{{ $.Operation.ExportedName }}_Read(b *testing.B) {
   217  		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
   218  		var buf bytes.Buffer
   219  		encoder := eventstream.NewEncoder(&buf)
   220  		for _, msg := range eventMsgs {
   221  			if err := encoder.Encode(msg); err != nil {
   222  				b.Fatalf("failed to encode message, %v", err)
   223  			}
   224  		}
   225  		stream := &loopReader{source: bytes.NewReader(buf.Bytes())}
   226  
   227  		sess := unit.Session
   228  		svc := New(sess, &aws.Config{
   229  			Endpoint:               aws.String("https://example.com"),
   230  			DisableParamValidation: aws.Bool(true),
   231  		})
   232  		svc.Handlers.Send.Swap(corehandlers.SendHandler.Name,
   233  			request.NamedHandler{Name: "mockSend",
   234  				Fn: func(r *request.Request) {
   235  					r.HTTPResponse = &http.Response{
   236  						Status:     "200 OK",
   237  						StatusCode: 200,
   238  						Header:     http.Header{},
   239  						Body:       ioutil.NopCloser(stream),
   240  					}
   241  				},
   242  			},
   243  		)
   244  
   245  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
   246  		if err != nil {
   247  			b.Fatalf("failed to create request, %v", err)
   248  		}
   249  		defer resp.GetStream().Close()
   250  		b.ResetTimer()
   251  
   252  		for i := 0; i < b.N; i++ {
   253  			if err = resp.GetStream().Err(); err != nil {
   254  				b.Fatalf("expect no error, got %v", err)
   255  			}
   256  			event := <-resp.GetStream().Events()
   257  			if event == nil {
   258  				b.Fatalf("expect event, got nil, %v, %d", resp.GetStream().Err(), i)
   259  			}
   260  		}
   261  	}
   262  
   263  	func mock{{ $.Operation.ExportedName }}ReadEvents() (
   264  		[]{{ $.OutputStream.Name }}Event,
   265  		[]eventstream.Message,
   266  	) {
   267  		expectEvents := []{{ $.OutputStream.Name }}Event {
   268  			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
   269  				{{- template "set event type" $.Operation.OutputRef.Shape }}
   270  			{{- end }}
   271  			{{- range $_, $event := $.OutputStream.Events }}
   272  				{{- template "set event type" $event.Shape }}
   273  			{{- end }}
   274  		}
   275  
   276  		var marshalers request.HandlerList
   277  		marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
   278  		payloadMarshaler := protocol.HandlerPayloadMarshal{
   279  			Marshalers: marshalers,
   280  		}
   281  		_ = payloadMarshaler
   282  
   283  		eventMsgs := []eventstream.Message{
   284  			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
   285  				{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
   286  			{{- end }}
   287  			{{- range $idx, $event := $.OutputStream.Events }}
   288  				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") $idx 1 }}
   289  				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $event.Shape "eventName" $event.Name }}
   290  			{{- end }}
   291  		}
   292  
   293  		return expectEvents, eventMsgs
   294  	}
   295  
   296  	{{- if $.OutputStream.Exceptions }}
   297  		func Test{{ $.Operation.ExportedName }}_ReadException(t *testing.T) {
   298  			expectEvents := []{{ $.OutputStream.Name }}Event {
   299  				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
   300  					{{- template "set event type" $.Operation.OutputRef.Shape }}
   301  				{{- end }}
   302  
   303  				{{- $exception := index $.OutputStream.Exceptions 0 }}
   304  				{{- template "set event type" $exception.Shape }}
   305  			}
   306  
   307  			var marshalers request.HandlerList
   308  			marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
   309  			payloadMarshaler := protocol.HandlerPayloadMarshal{
   310  				Marshalers: marshalers,
   311  			}
   312  
   313  			eventMsgs := []eventstream.Message{
   314  				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
   315  					{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
   316  				{{- end }}
   317  
   318  				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") 0 1 }}
   319  				{{- $exception := index $.OutputStream.Exceptions 0 }}
   320  				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $exception.Shape "eventName" $exception.Name }}
   321  			}
   322  
   323  			sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
   324  				eventstreamtest.ServeEventStream{
   325  					T:      t,
   326  					Events: eventMsgs,
   327  				},
   328  				true,
   329  			)
   330  			if err != nil {
   331  				t.Fatalf("expect no error, %v", err)
   332  			}
   333  			defer cleanupFn()
   334  
   335  			svc := New(sess)
   336  			resp, err := svc.{{ $.Operation.ExportedName }}(nil)
   337  			if err != nil {
   338  				t.Fatalf("expect no error got, %v", err)
   339  			}
   340  
   341  			defer resp.GetStream().Close()
   342  
   343  			<-resp.GetStream().Events()
   344  
   345  			err = resp.GetStream().Err()
   346  			if err == nil {
   347  				t.Fatalf("expect err, got none")
   348  			}
   349  
   350  			expectErr := {{ ValueForType $exception.Shape nil }}
   351  			aerr, ok := err.(awserr.Error)
   352  			if !ok {
   353  				t.Errorf("expect exception, got %T, %#v", err, err)
   354  			}
   355  			if e, a := expectErr.Code(), aerr.Code(); e != a {
   356  				t.Errorf("expect %v, got %v", e, a)
   357  			}
   358  			if e, a := expectErr.Message(), aerr.Message(); e != a {
   359  				t.Errorf("expect %v, got %v", e, a)
   360  			}
   361  
   362  			if e, a := expectErr, aerr; !reflect.DeepEqual(e, a) {
   363  				t.Errorf("expect error %+#v, got %+#v", e, a)
   364  			}
   365  		}
   366  
   367  		{{- range $_, $exception := $.OutputStream.Exceptions }}
   368  			var _ awserr.Error = (*{{ $exception.Shape.ShapeName }})(nil)
   369  		{{- end }}
   370  
   371  	{{ end }}
   372  {{ end }}
   373  
   374  {{/* Params: *Shape */}}
   375  {{ define "set event type" }}
   376  	&{{ $.ShapeName }}{
   377  		{{- if $.Exception }}
   378  			RespMetadata: protocol.ResponseMetadata{
   379  				StatusCode: 200,
   380  			},
   381  		{{- end }}
   382  		{{- range $memName, $memRef := $.MemberRefs }}
   383  			{{- if not $memRef.Shape.IsEventStream }}
   384  				{{ $memName }}: {{ ValueForType $memRef.Shape nil }},
   385  			{{- end }}
   386  		{{- end }}
   387  	},
   388  {{- end }}
   389  
   390  {{/* Params: idx:int, parentShape:*Shape, eventName:string */}}
   391  {{ define "set event message" }}
   392  	{
   393  		Headers: eventstream.Headers{
   394  			{{- if $.parentShape.Exception }}
   395  				eventstreamtest.EventExceptionTypeHeader,
   396  				{
   397  					Name:  eventstreamapi.ExceptionTypeHeader,
   398  					Value: eventstream.StringValue("{{ $.eventName }}"),
   399  				},
   400  			{{- else }}
   401  				eventstreamtest.EventMessageTypeHeader,
   402  				{
   403  					Name:  eventstreamapi.EventTypeHeader,
   404  					Value: eventstream.StringValue("{{ $.eventName }}"),
   405  				},
   406  			{{- end }}
   407  			{{- range $memName, $memRef := $.parentShape.MemberRefs }}
   408  				{{- template "set event message header" Map "idx" $.idx "parentShape" $.parentShape "memName" $memName "memRef" $memRef }}
   409  			{{- end }}
   410  		},
   411  		{{- template "set event message payload" Map "idx" $.idx "parentShape" $.parentShape }}
   412  	},
   413  {{- end }}
   414  
   415  {{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
   416  {{ define "set event message header" }}
   417  	{{- if $.memRef.IsEventHeader }}
   418  		{
   419  			Name: "{{ $.memName }}",
   420  			{{- $shapeValueVar := printf "expectEvents[%d].(%s).%s" $.idx $.parentShape.GoType $.memName }}
   421  			Value: {{ EventHeaderValueForType $.memRef.Shape $shapeValueVar }},
   422  		},
   423  	{{- end }}
   424  {{- end }}
   425  
   426  {{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
   427  {{ define "set event message payload" }}
   428  	{{- $payloadMemName := $.parentShape.PayloadRefName }}
   429  	{{- if HasNonBlobPayloadMembers $.parentShape }}
   430  		Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, expectEvents[{{ $.idx }}]),
   431  	{{- else if $payloadMemName }}
   432  		{{- $shapeType := (index $.parentShape.MemberRefs $payloadMemName).Shape.Type }}
   433  		{{- if eq $shapeType "blob" }}
   434  			Payload: expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }},
   435  		{{- else if eq $shapeType "string" }}
   436  			Payload: []byte(*expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }}),
   437  		{{- end }}
   438  	{{- end }}
   439  {{- end }}
   440  `))