github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/eventstream/eventstreamtest/testing.go (about) 1 //go:build go1.15 2 // +build go1.15 3 4 package eventstreamtest 5 6 import ( 7 "bytes" 8 "context" 9 "fmt" 10 "io" 11 "net/http" 12 "net/http/httptest" 13 "reflect" 14 "strings" 15 "sync" 16 "testing" 17 "time" 18 19 "github.com/aavshr/aws-sdk-go/aws" 20 "github.com/aavshr/aws-sdk-go/aws/session" 21 "github.com/aavshr/aws-sdk-go/awstesting/unit" 22 "github.com/aavshr/aws-sdk-go/private/protocol" 23 "github.com/aavshr/aws-sdk-go/private/protocol/eventstream" 24 "golang.org/x/net/http2" 25 ) 26 27 const ( 28 errClientDisconnected = "client disconnected" 29 errStreamClosed = "http2: stream closed" 30 ) 31 32 // ServeEventStream provides serving EventStream messages from a HTTP server to 33 // the client. The events are sent sequentially to the client without delay. 34 type ServeEventStream struct { 35 T *testing.T 36 BiDirectional bool 37 38 Events []eventstream.Message 39 ClientEvents []eventstream.Message 40 41 ForceCloseAfter time.Duration 42 43 requestsIdx int 44 } 45 46 func (s ServeEventStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { 47 w.WriteHeader(http.StatusOK) 48 w.(http.Flusher).Flush() 49 50 if s.BiDirectional { 51 s.serveBiDirectionalStream(w, r) 52 } else { 53 s.serveReadOnlyStream(w, r) 54 } 55 } 56 57 func (s *ServeEventStream) serveReadOnlyStream(w http.ResponseWriter, r *http.Request) { 58 encoder := eventstream.NewEncoder(flushWriter{w}) 59 60 for _, event := range s.Events { 61 encoder.Encode(event) 62 } 63 } 64 65 func (s *ServeEventStream) serveBiDirectionalStream(w http.ResponseWriter, r *http.Request) { 66 var wg sync.WaitGroup 67 68 ctx := context.Background() 69 if s.ForceCloseAfter > 0 { 70 var cancelFunc func() 71 ctx, cancelFunc = context.WithTimeout(context.Background(), s.ForceCloseAfter) 72 defer cancelFunc() 73 } 74 75 var ( 76 err error 77 m sync.Mutex 78 ) 79 80 wg.Add(1) 81 go func() { 82 defer wg.Done() 83 readErr := s.readEvents(ctx, r) 84 if readErr != nil { 85 m.Lock() 86 if err == nil { 87 err = readErr 88 } 89 m.Unlock() 90 } 91 }() 92 93 writeErr := s.writeEvents(ctx, w) 94 if writeErr != nil { 95 m.Lock() 96 if err != nil { 97 err = writeErr 98 } 99 m.Unlock() 100 } 101 wg.Wait() 102 103 if err != nil && isError(err) { 104 s.T.Error(err.Error()) 105 } 106 } 107 108 func isError(err error) bool { 109 switch err.(type) { 110 case http2.StreamError: 111 return false 112 } 113 114 for _, s := range []string{errClientDisconnected, errStreamClosed} { 115 if strings.Contains(err.Error(), s) { 116 return false 117 } 118 } 119 120 return true 121 } 122 123 func (s ServeEventStream) readEvents(ctx context.Context, r *http.Request) error { 124 signBuffer := make([]byte, 1024) 125 messageBuffer := make([]byte, 1024) 126 decoder := eventstream.NewDecoder(r.Body) 127 128 for { 129 select { 130 case <-ctx.Done(): 131 return nil 132 default: 133 } 134 // unwrap signing envelope 135 signedMessage, err := decoder.Decode(signBuffer) 136 if err != nil { 137 if err == io.EOF { 138 break 139 } 140 return err 141 } 142 143 // empty payload is expected for the last signing message 144 if len(signedMessage.Payload) == 0 { 145 break 146 } 147 148 // get service event message from payload 149 msg, err := eventstream.Decode(bytes.NewReader(signedMessage.Payload), messageBuffer) 150 if err != nil { 151 if err == io.EOF { 152 break 153 } 154 return err 155 } 156 157 if len(s.ClientEvents) > 0 { 158 i := s.requestsIdx 159 s.requestsIdx++ 160 161 if e, a := s.ClientEvents[i], msg; !reflect.DeepEqual(e, a) { 162 return fmt.Errorf("expected %v, got %v", e, a) 163 } 164 } 165 } 166 167 return nil 168 } 169 170 func (s *ServeEventStream) writeEvents(ctx context.Context, w http.ResponseWriter) error { 171 encoder := eventstream.NewEncoder(flushWriter{w}) 172 173 var event eventstream.Message 174 pendingEvents := s.Events 175 176 for len(pendingEvents) > 0 { 177 event, pendingEvents = pendingEvents[0], pendingEvents[1:] 178 select { 179 case <-ctx.Done(): 180 return nil 181 default: 182 err := encoder.Encode(event) 183 if err != nil { 184 if err == io.EOF { 185 return nil 186 } 187 return fmt.Errorf("expected no error encoding event, got %v", err) 188 } 189 } 190 } 191 192 return nil 193 } 194 195 // SetupEventStreamSession creates a HTTP server SDK session for communicating 196 // with that server to be used for EventStream APIs. If HTTP/2 is enabled the 197 // server/client will only attempt to use HTTP/2. 198 func SetupEventStreamSession( 199 t *testing.T, handler http.Handler, h2 bool, 200 ) (sess *session.Session, cleanupFn func(), err error) { 201 server := httptest.NewUnstartedServer(handler) 202 203 client := setupServer(server, h2) 204 205 cleanupFn = func() { 206 server.Close() 207 } 208 209 sess, err = session.NewSession(unit.Session.Config, &aws.Config{ 210 Endpoint: &server.URL, 211 DisableParamValidation: aws.Bool(true), 212 HTTPClient: client, 213 // LogLevel: aws.LogLevel(aws.LogDebugWithEventStreamBody), 214 }) 215 if err != nil { 216 return nil, nil, err 217 } 218 219 return sess, cleanupFn, nil 220 } 221 222 type flushWriter struct { 223 w io.Writer 224 } 225 226 func (fw flushWriter) Write(p []byte) (n int, err error) { 227 n, err = fw.w.Write(p) 228 if f, ok := fw.w.(http.Flusher); ok { 229 f.Flush() 230 } 231 return 232 } 233 234 // MarshalEventPayload marshals a SDK API shape into its associated wire 235 // protocol payload. 236 func MarshalEventPayload( 237 payloadMarshaler protocol.PayloadMarshaler, 238 v interface{}, 239 ) []byte { 240 var w bytes.Buffer 241 err := payloadMarshaler.MarshalPayload(&w, v) 242 if err != nil { 243 panic(fmt.Sprintf("failed to marshal event %T, %v, %v", v, v, err)) 244 } 245 246 return w.Bytes() 247 } 248 249 // Prevent circular dependencies on eventstreamapi redefine these here. 250 const ( 251 messageTypeHeader = `:message-type` // Identifies type of message. 252 eventMessageType = `event` 253 exceptionMessageType = `exception` 254 ) 255 256 // EventMessageTypeHeader is an event message type header for specifying an 257 // event is an message type. 258 var EventMessageTypeHeader = eventstream.Header{ 259 Name: messageTypeHeader, 260 Value: eventstream.StringValue(eventMessageType), 261 } 262 263 // EventExceptionTypeHeader is an event exception type header for specifying an 264 // event is an exception type. 265 var EventExceptionTypeHeader = eventstream.Header{ 266 Name: messageTypeHeader, 267 Value: eventstream.StringValue(exceptionMessageType), 268 }