github.com/aavshr/aws-sdk-go@v1.41.3/private/model/api/eventstream_tmpl_writertests.go (about)

     1  //go:build codegen
     2  // +build codegen
     3  
     4  package api
     5  
     6  import (
     7  	"text/template"
     8  )
     9  
    10  var eventStreamWriterTestTmpl = template.Must(
    11  	template.New("eventStreamWriterTestTmpl").Funcs(template.FuncMap{
    12  		"ValueForType":             valueForType,
    13  		"HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers,
    14  		"EventHeaderValueForType":  setEventHeaderValueForType,
    15  		"Map":                      templateMap,
    16  		"OptionalAddInt": func(do bool, a, b int) int {
    17  			if !do {
    18  				return a
    19  			}
    20  			return a + b
    21  		},
    22  		"HasNonEventStreamMember": func(s *Shape) bool {
    23  			for _, ref := range s.MemberRefs {
    24  				if !ref.Shape.IsEventStream {
    25  					return true
    26  				}
    27  			}
    28  			return false
    29  		},
    30  	}).Parse(`
    31  {{ range $opName, $op := $.Operations }}
    32  	{{ if $op.EventStreamAPI }}
    33  		{{ if  $op.EventStreamAPI.InputStream }}
    34  			{{ template "event stream inputStream tests" $op.EventStreamAPI }}
    35  		{{ end }}
    36  	{{ end }}
    37  {{ end }}
    38  
    39  {{ define "event stream inputStream tests" }}
    40  	func Test{{ $.Operation.ExportedName }}_Write(t *testing.T) {
    41  		clientEvents, expectedClientEvents := mock{{ $.Operation.ExportedName }}WriteEvents()
    42  
    43  		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
    44  			&eventstreamtest.ServeEventStream{
    45  				T:             t,
    46  				ClientEvents:  expectedClientEvents,
    47  				BiDirectional: true,
    48  			},
    49  			true)
    50  		defer cleanupFn()
    51  
    52  		svc := New(sess)
    53  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
    54  		if err != nil {
    55  			t.Fatalf("expect no error, got %v", err)
    56  		}
    57  
    58  		stream := resp.GetStream()
    59  
    60  		for _, event := range clientEvents {
    61  			err = stream.Send(context.Background(), event)
    62  			if err != nil {
    63  				t.Fatalf("expect no error, got %v", err)
    64  			}
    65  		}
    66  
    67  		if err := stream.Close(); err != nil {
    68  			t.Errorf("expect no error, got %v", err)
    69  		}
    70  	}
    71  
    72  	func Test{{ $.Operation.ExportedName }}_WriteClose(t *testing.T) {
    73  		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
    74  			eventstreamtest.ServeEventStream{T: t, BiDirectional: true},
    75  			true,
    76  		)
    77  		if err != nil {
    78  			t.Fatalf("expect no error, %v", err)
    79  		}
    80  		defer cleanupFn()
    81  
    82  		svc := New(sess)
    83  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
    84  		if err != nil {
    85  			t.Fatalf("expect no error got, %v", err)
    86  		}
    87  
    88  		// Assert calling Err before close does not close the stream.
    89  		resp.GetStream().Err()
    90  		{{ $eventShape := index $.InputStream.Events 0 }}
    91  		err = resp.GetStream().Send(context.Background(), &{{ $eventShape.Shape.ShapeName }}{})
    92  		if err != nil {
    93  			t.Fatalf("expect no error, got %v", err)
    94  		}
    95  
    96  		resp.GetStream().Close()
    97  
    98  		if err := resp.GetStream().Err(); err != nil {
    99  			t.Errorf("expect no error, %v", err)
   100  		}
   101  	}
   102  
   103  	func Test{{ $.Operation.ExportedName }}_WriteError(t *testing.T) {
   104  		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
   105  			eventstreamtest.ServeEventStream{
   106  				T:               t,
   107  				BiDirectional:   true,
   108  				ForceCloseAfter: time.Millisecond * 500,
   109  			},
   110  			true,
   111  		)
   112  		if err != nil {
   113  			t.Fatalf("expect no error, %v", err)
   114  		}
   115  		defer cleanupFn()
   116  
   117  		svc := New(sess)
   118  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
   119  		if err != nil {
   120  			t.Fatalf("expect no error got, %v", err)
   121  		}
   122  
   123  		defer resp.GetStream().Close()
   124  		{{ $eventShape := index $.InputStream.Events 0 }}
   125  		for {
   126  			err = resp.GetStream().Send(context.Background(), &{{ $eventShape.Shape.ShapeName }}{})
   127  			if err != nil {
   128  				if strings.Contains("unable to send event", err.Error()) {
   129  					t.Errorf("expected stream closed error, got %v", err)
   130  				}
   131  				break
   132  			}
   133  		}
   134  	}
   135  
   136  	func Test{{ $.Operation.ExportedName }}_ReadWrite(t *testing.T) {
   137  		expectedServiceEvents, serviceEvents := mock{{ $.Operation.ExportedName }}ReadEvents()
   138  		clientEvents, expectedClientEvents := mock{{ $.Operation.ExportedName }}WriteEvents()
   139  
   140  		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
   141  			&eventstreamtest.ServeEventStream{
   142  				T:             t,
   143  				ClientEvents:  expectedClientEvents,
   144  				Events:        serviceEvents,
   145  				BiDirectional: true,
   146  			},
   147  			true)
   148  		defer cleanupFn()
   149  
   150  		svc := New(sess)
   151  		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
   152  		if err != nil {
   153  			t.Fatalf("expect no error, got %v", err)
   154  		}
   155  
   156  		stream := resp.GetStream()
   157  		defer stream.Close()
   158  
   159  		var wg sync.WaitGroup
   160  
   161  		wg.Add(1)
   162  		go func() {
   163  			defer wg.Done()
   164  			var i int
   165  			for event := range resp.GetStream().Events() {
   166  				if event == nil {
   167  					t.Errorf("%d, expect event, got nil", i)
   168  				}
   169  				if e, a := expectedServiceEvents[i], event; !reflect.DeepEqual(e, a) {
   170  					t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
   171  				}
   172  				i++
   173  			}
   174  		}()
   175  
   176  		for _, event := range clientEvents {
   177  			err = stream.Send(context.Background(), event)
   178  			if err != nil {
   179  				t.Errorf("expect no error, got %v", err)
   180  			}
   181  		}
   182  
   183  		resp.GetStream().Close()
   184  
   185  		wg.Wait()
   186  
   187  		if err := resp.GetStream().Err(); err != nil {
   188  			t.Errorf("expect no error, %v", err)
   189  		}
   190  	}
   191  
   192  	func mock{{ $.Operation.ExportedName }}WriteEvents() (
   193  		[]{{ $.InputStream.Name }}Event,
   194  		[]eventstream.Message,
   195  	) {
   196  		inputEvents := []{{ $.InputStream.Name }}Event {
   197  			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
   198  				{{- template "set event type" $.Operation.InputRef.Shape }}
   199  			{{- end }}
   200  			{{- range $_, $event := $.InputStream.Events }}
   201  				{{- template "set event type" $event.Shape }}
   202  			{{- end }}
   203  		}
   204  
   205  		var marshalers request.HandlerList
   206  		marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
   207  		payloadMarshaler := protocol.HandlerPayloadMarshal{
   208  			Marshalers: marshalers,
   209  		}
   210  		_ = payloadMarshaler
   211  
   212  		eventMsgs := []eventstream.Message{
   213  			{{- range $idx, $event := $.InputStream.Events }}
   214  				{{- template "set event message" Map "idx" $idx "parentShape" $event.Shape "eventName" $event.Name }}
   215  			{{- end }}
   216  		}
   217  
   218  		return inputEvents, eventMsgs
   219  	}
   220  {{ end }}
   221  
   222  {{/* Params: *Shape */}}
   223  {{ define "set event type" }}
   224  	&{{ $.ShapeName }}{
   225  		{{- range $memName, $memRef := $.MemberRefs }}
   226  			{{- if not $memRef.Shape.IsEventStream }}
   227  				{{ $memName }}: {{ ValueForType $memRef.Shape nil }},
   228  			{{- end }}
   229  		{{- end }}
   230  	},
   231  {{- end }}
   232  
   233  {{/* Params: idx:int, parentShape:*Shape, eventName:string */}}
   234  {{ define "set event message" }}
   235  	{
   236  		Headers: eventstream.Headers{
   237  			eventstreamtest.EventMessageTypeHeader,
   238  			{{- range $memName, $memRef := $.parentShape.MemberRefs }}
   239  				{{- template "set event message header" Map "idx" $.idx "parentShape" $.parentShape "memName" $memName "memRef" $memRef }}
   240  			{{- end }}
   241  			{
   242  				Name:  eventstreamapi.EventTypeHeader,
   243  				Value: eventstream.StringValue("{{ $.eventName }}"),
   244  			},
   245  		},
   246  		{{- template "set event message payload" Map "idx" $.idx "parentShape" $.parentShape }}
   247  	},
   248  {{- end }}
   249  
   250  {{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
   251  {{ define "set event message header" }}
   252  	{{- if (and ($.memRef.IsEventPayload) (eq $.memRef.Shape.Type "blob")) }}
   253  		{
   254  			Name: ":content-type",
   255  			Value: eventstream.StringValue("application/octet-stream"),
   256  		},
   257  	{{- else if $.memRef.IsEventHeader }}
   258  		{
   259  			Name: "{{ $.memName }}",
   260  			{{- $shapeValueVar := printf "inputEvents[%d].(%s).%s" $.idx $.parentShape.GoType $.memName }}
   261  			Value: {{ EventHeaderValueForType $.memRef.Shape $shapeValueVar }},
   262  		},
   263  	{{- end }}
   264  {{- end }}
   265  
   266  {{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
   267  {{ define "set event message payload" }}
   268  	{{- $payloadMemName := $.parentShape.PayloadRefName }}
   269  	{{- if HasNonBlobPayloadMembers $.parentShape }}
   270  		Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, inputEvents[{{ $.idx }}]),
   271  	{{- else if $payloadMemName }}
   272  		{{- $shapeType := (index $.parentShape.MemberRefs $payloadMemName).Shape.Type }}
   273  		{{- if eq $shapeType "blob" }}
   274  			Payload: inputEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }},
   275  		{{- else if eq $shapeType "string" }}
   276  			Payload: []byte(*inputEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }}),
   277  		{{- end }}
   278  	{{- end }}
   279  {{- end }}
   280  `))