github.com/aavshr/aws-sdk-go@v1.41.3/private/model/api/eventstream_tmpl_writertests.go (about) 1 //go:build codegen 2 // +build codegen 3 4 package api 5 6 import ( 7 "text/template" 8 ) 9 10 var eventStreamWriterTestTmpl = template.Must( 11 template.New("eventStreamWriterTestTmpl").Funcs(template.FuncMap{ 12 "ValueForType": valueForType, 13 "HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers, 14 "EventHeaderValueForType": setEventHeaderValueForType, 15 "Map": templateMap, 16 "OptionalAddInt": func(do bool, a, b int) int { 17 if !do { 18 return a 19 } 20 return a + b 21 }, 22 "HasNonEventStreamMember": func(s *Shape) bool { 23 for _, ref := range s.MemberRefs { 24 if !ref.Shape.IsEventStream { 25 return true 26 } 27 } 28 return false 29 }, 30 }).Parse(` 31 {{ range $opName, $op := $.Operations }} 32 {{ if $op.EventStreamAPI }} 33 {{ if $op.EventStreamAPI.InputStream }} 34 {{ template "event stream inputStream tests" $op.EventStreamAPI }} 35 {{ end }} 36 {{ end }} 37 {{ end }} 38 39 {{ define "event stream inputStream tests" }} 40 func Test{{ $.Operation.ExportedName }}_Write(t *testing.T) { 41 clientEvents, expectedClientEvents := mock{{ $.Operation.ExportedName }}WriteEvents() 42 43 sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t, 44 &eventstreamtest.ServeEventStream{ 45 T: t, 46 ClientEvents: expectedClientEvents, 47 BiDirectional: true, 48 }, 49 true) 50 defer cleanupFn() 51 52 svc := New(sess) 53 resp, err := svc.{{ $.Operation.ExportedName }}(nil) 54 if err != nil { 55 t.Fatalf("expect no error, got %v", err) 56 } 57 58 stream := resp.GetStream() 59 60 for _, event := range clientEvents { 61 err = stream.Send(context.Background(), event) 62 if err != nil { 63 t.Fatalf("expect no error, got %v", err) 64 } 65 } 66 67 if err := stream.Close(); err != nil { 68 t.Errorf("expect no error, got %v", err) 69 } 70 } 71 72 func Test{{ $.Operation.ExportedName }}_WriteClose(t *testing.T) { 73 sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t, 74 eventstreamtest.ServeEventStream{T: t, BiDirectional: true}, 75 true, 76 ) 77 if err != nil { 78 t.Fatalf("expect no error, %v", err) 79 } 80 defer cleanupFn() 81 82 svc := New(sess) 83 resp, err := svc.{{ $.Operation.ExportedName }}(nil) 84 if err != nil { 85 t.Fatalf("expect no error got, %v", err) 86 } 87 88 // Assert calling Err before close does not close the stream. 89 resp.GetStream().Err() 90 {{ $eventShape := index $.InputStream.Events 0 }} 91 err = resp.GetStream().Send(context.Background(), &{{ $eventShape.Shape.ShapeName }}{}) 92 if err != nil { 93 t.Fatalf("expect no error, got %v", err) 94 } 95 96 resp.GetStream().Close() 97 98 if err := resp.GetStream().Err(); err != nil { 99 t.Errorf("expect no error, %v", err) 100 } 101 } 102 103 func Test{{ $.Operation.ExportedName }}_WriteError(t *testing.T) { 104 sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t, 105 eventstreamtest.ServeEventStream{ 106 T: t, 107 BiDirectional: true, 108 ForceCloseAfter: time.Millisecond * 500, 109 }, 110 true, 111 ) 112 if err != nil { 113 t.Fatalf("expect no error, %v", err) 114 } 115 defer cleanupFn() 116 117 svc := New(sess) 118 resp, err := svc.{{ $.Operation.ExportedName }}(nil) 119 if err != nil { 120 t.Fatalf("expect no error got, %v", err) 121 } 122 123 defer resp.GetStream().Close() 124 {{ $eventShape := index $.InputStream.Events 0 }} 125 for { 126 err = resp.GetStream().Send(context.Background(), &{{ $eventShape.Shape.ShapeName }}{}) 127 if err != nil { 128 if strings.Contains("unable to send event", err.Error()) { 129 t.Errorf("expected stream closed error, got %v", err) 130 } 131 break 132 } 133 } 134 } 135 136 func Test{{ $.Operation.ExportedName }}_ReadWrite(t *testing.T) { 137 expectedServiceEvents, serviceEvents := mock{{ $.Operation.ExportedName }}ReadEvents() 138 clientEvents, expectedClientEvents := mock{{ $.Operation.ExportedName }}WriteEvents() 139 140 sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t, 141 &eventstreamtest.ServeEventStream{ 142 T: t, 143 ClientEvents: expectedClientEvents, 144 Events: serviceEvents, 145 BiDirectional: true, 146 }, 147 true) 148 defer cleanupFn() 149 150 svc := New(sess) 151 resp, err := svc.{{ $.Operation.ExportedName }}(nil) 152 if err != nil { 153 t.Fatalf("expect no error, got %v", err) 154 } 155 156 stream := resp.GetStream() 157 defer stream.Close() 158 159 var wg sync.WaitGroup 160 161 wg.Add(1) 162 go func() { 163 defer wg.Done() 164 var i int 165 for event := range resp.GetStream().Events() { 166 if event == nil { 167 t.Errorf("%d, expect event, got nil", i) 168 } 169 if e, a := expectedServiceEvents[i], event; !reflect.DeepEqual(e, a) { 170 t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a) 171 } 172 i++ 173 } 174 }() 175 176 for _, event := range clientEvents { 177 err = stream.Send(context.Background(), event) 178 if err != nil { 179 t.Errorf("expect no error, got %v", err) 180 } 181 } 182 183 resp.GetStream().Close() 184 185 wg.Wait() 186 187 if err := resp.GetStream().Err(); err != nil { 188 t.Errorf("expect no error, %v", err) 189 } 190 } 191 192 func mock{{ $.Operation.ExportedName }}WriteEvents() ( 193 []{{ $.InputStream.Name }}Event, 194 []eventstream.Message, 195 ) { 196 inputEvents := []{{ $.InputStream.Name }}Event { 197 {{- if eq $.Operation.API.Metadata.Protocol "json" }} 198 {{- template "set event type" $.Operation.InputRef.Shape }} 199 {{- end }} 200 {{- range $_, $event := $.InputStream.Events }} 201 {{- template "set event type" $event.Shape }} 202 {{- end }} 203 } 204 205 var marshalers request.HandlerList 206 marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler) 207 payloadMarshaler := protocol.HandlerPayloadMarshal{ 208 Marshalers: marshalers, 209 } 210 _ = payloadMarshaler 211 212 eventMsgs := []eventstream.Message{ 213 {{- range $idx, $event := $.InputStream.Events }} 214 {{- template "set event message" Map "idx" $idx "parentShape" $event.Shape "eventName" $event.Name }} 215 {{- end }} 216 } 217 218 return inputEvents, eventMsgs 219 } 220 {{ end }} 221 222 {{/* Params: *Shape */}} 223 {{ define "set event type" }} 224 &{{ $.ShapeName }}{ 225 {{- range $memName, $memRef := $.MemberRefs }} 226 {{- if not $memRef.Shape.IsEventStream }} 227 {{ $memName }}: {{ ValueForType $memRef.Shape nil }}, 228 {{- end }} 229 {{- end }} 230 }, 231 {{- end }} 232 233 {{/* Params: idx:int, parentShape:*Shape, eventName:string */}} 234 {{ define "set event message" }} 235 { 236 Headers: eventstream.Headers{ 237 eventstreamtest.EventMessageTypeHeader, 238 {{- range $memName, $memRef := $.parentShape.MemberRefs }} 239 {{- template "set event message header" Map "idx" $.idx "parentShape" $.parentShape "memName" $memName "memRef" $memRef }} 240 {{- end }} 241 { 242 Name: eventstreamapi.EventTypeHeader, 243 Value: eventstream.StringValue("{{ $.eventName }}"), 244 }, 245 }, 246 {{- template "set event message payload" Map "idx" $.idx "parentShape" $.parentShape }} 247 }, 248 {{- end }} 249 250 {{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}} 251 {{ define "set event message header" }} 252 {{- if (and ($.memRef.IsEventPayload) (eq $.memRef.Shape.Type "blob")) }} 253 { 254 Name: ":content-type", 255 Value: eventstream.StringValue("application/octet-stream"), 256 }, 257 {{- else if $.memRef.IsEventHeader }} 258 { 259 Name: "{{ $.memName }}", 260 {{- $shapeValueVar := printf "inputEvents[%d].(%s).%s" $.idx $.parentShape.GoType $.memName }} 261 Value: {{ EventHeaderValueForType $.memRef.Shape $shapeValueVar }}, 262 }, 263 {{- end }} 264 {{- end }} 265 266 {{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}} 267 {{ define "set event message payload" }} 268 {{- $payloadMemName := $.parentShape.PayloadRefName }} 269 {{- if HasNonBlobPayloadMembers $.parentShape }} 270 Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, inputEvents[{{ $.idx }}]), 271 {{- else if $payloadMemName }} 272 {{- $shapeType := (index $.parentShape.MemberRefs $payloadMemName).Shape.Type }} 273 {{- if eq $shapeType "blob" }} 274 Payload: inputEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }}, 275 {{- else if eq $shapeType "string" }} 276 Payload: []byte(*inputEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }}), 277 {{- end }} 278 {{- end }} 279 {{- end }} 280 `))