github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/eventstream/eventstreamtest/testing.go (about)

     1  //go:build go1.15
     2  // +build go1.15
     3  
     4  package eventstreamtest
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"reflect"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/aavshr/aws-sdk-go/aws"
    20  	"github.com/aavshr/aws-sdk-go/aws/session"
    21  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    22  	"github.com/aavshr/aws-sdk-go/private/protocol"
    23  	"github.com/aavshr/aws-sdk-go/private/protocol/eventstream"
    24  	"golang.org/x/net/http2"
    25  )
    26  
    27  const (
    28  	errClientDisconnected = "client disconnected"
    29  	errStreamClosed       = "http2: stream closed"
    30  )
    31  
// ServeEventStream is an http.Handler that serves EventStream messages from an
// HTTP server to the client. The messages in Events are sent sequentially to
// the client without delay.
type ServeEventStream struct {
	// T receives the errors detected while serving a bidirectional stream.
	T             *testing.T
	// BiDirectional selects the mode that also reads (and optionally
	// validates) messages sent by the client. When false the server only
	// writes Events.
	BiDirectional bool

	// Events are the messages written to the client, in order.
	Events       []eventstream.Message
	// ClientEvents, when non-empty, are the messages the client is expected
	// to send in bidirectional mode. They are compared in order against the
	// received messages with reflect.DeepEqual.
	ClientEvents []eventstream.Message

	// ForceCloseAfter, when positive, is the timeout of the context used by a
	// bidirectional stream. Reading and writing stop at the next message
	// boundary once it expires.
	ForceCloseAfter time.Duration

	// requestsIdx is the index of the next expected entry in ClientEvents.
	requestsIdx int
}
    45  
    46  func (s ServeEventStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    47  	w.WriteHeader(http.StatusOK)
    48  	w.(http.Flusher).Flush()
    49  
    50  	if s.BiDirectional {
    51  		s.serveBiDirectionalStream(w, r)
    52  	} else {
    53  		s.serveReadOnlyStream(w, r)
    54  	}
    55  }
    56  
    57  func (s *ServeEventStream) serveReadOnlyStream(w http.ResponseWriter, r *http.Request) {
    58  	encoder := eventstream.NewEncoder(flushWriter{w})
    59  
    60  	for _, event := range s.Events {
    61  		encoder.Encode(event)
    62  	}
    63  }
    64  
    65  func (s *ServeEventStream) serveBiDirectionalStream(w http.ResponseWriter, r *http.Request) {
    66  	var wg sync.WaitGroup
    67  
    68  	ctx := context.Background()
    69  	if s.ForceCloseAfter > 0 {
    70  		var cancelFunc func()
    71  		ctx, cancelFunc = context.WithTimeout(context.Background(), s.ForceCloseAfter)
    72  		defer cancelFunc()
    73  	}
    74  
    75  	var (
    76  		err error
    77  		m   sync.Mutex
    78  	)
    79  
    80  	wg.Add(1)
    81  	go func() {
    82  		defer wg.Done()
    83  		readErr := s.readEvents(ctx, r)
    84  		if readErr != nil {
    85  			m.Lock()
    86  			if err == nil {
    87  				err = readErr
    88  			}
    89  			m.Unlock()
    90  		}
    91  	}()
    92  
    93  	writeErr := s.writeEvents(ctx, w)
    94  	if writeErr != nil {
    95  		m.Lock()
    96  		if err != nil {
    97  			err = writeErr
    98  		}
    99  		m.Unlock()
   100  	}
   101  	wg.Wait()
   102  
   103  	if err != nil && isError(err) {
   104  		s.T.Error(err.Error())
   105  	}
   106  }
   107  
   108  func isError(err error) bool {
   109  	switch err.(type) {
   110  	case http2.StreamError:
   111  		return false
   112  	}
   113  
   114  	for _, s := range []string{errClientDisconnected, errStreamClosed} {
   115  		if strings.Contains(err.Error(), s) {
   116  			return false
   117  		}
   118  	}
   119  
   120  	return true
   121  }
   122  
   123  func (s ServeEventStream) readEvents(ctx context.Context, r *http.Request) error {
   124  	signBuffer := make([]byte, 1024)
   125  	messageBuffer := make([]byte, 1024)
   126  	decoder := eventstream.NewDecoder(r.Body)
   127  
   128  	for {
   129  		select {
   130  		case <-ctx.Done():
   131  			return nil
   132  		default:
   133  		}
   134  		// unwrap signing envelope
   135  		signedMessage, err := decoder.Decode(signBuffer)
   136  		if err != nil {
   137  			if err == io.EOF {
   138  				break
   139  			}
   140  			return err
   141  		}
   142  
   143  		// empty payload is expected for the last signing message
   144  		if len(signedMessage.Payload) == 0 {
   145  			break
   146  		}
   147  
   148  		// get service event message from payload
   149  		msg, err := eventstream.Decode(bytes.NewReader(signedMessage.Payload), messageBuffer)
   150  		if err != nil {
   151  			if err == io.EOF {
   152  				break
   153  			}
   154  			return err
   155  		}
   156  
   157  		if len(s.ClientEvents) > 0 {
   158  			i := s.requestsIdx
   159  			s.requestsIdx++
   160  
   161  			if e, a := s.ClientEvents[i], msg; !reflect.DeepEqual(e, a) {
   162  				return fmt.Errorf("expected %v, got %v", e, a)
   163  			}
   164  		}
   165  	}
   166  
   167  	return nil
   168  }
   169  
   170  func (s *ServeEventStream) writeEvents(ctx context.Context, w http.ResponseWriter) error {
   171  	encoder := eventstream.NewEncoder(flushWriter{w})
   172  
   173  	var event eventstream.Message
   174  	pendingEvents := s.Events
   175  
   176  	for len(pendingEvents) > 0 {
   177  		event, pendingEvents = pendingEvents[0], pendingEvents[1:]
   178  		select {
   179  		case <-ctx.Done():
   180  			return nil
   181  		default:
   182  			err := encoder.Encode(event)
   183  			if err != nil {
   184  				if err == io.EOF {
   185  					return nil
   186  				}
   187  				return fmt.Errorf("expected no error encoding event, got %v", err)
   188  			}
   189  		}
   190  	}
   191  
   192  	return nil
   193  }
   194  
   195  // SetupEventStreamSession creates a HTTP server SDK session for communicating
   196  // with that server to be used for EventStream APIs. If HTTP/2 is enabled the
   197  // server/client will only attempt to use HTTP/2.
   198  func SetupEventStreamSession(
   199  	t *testing.T, handler http.Handler, h2 bool,
   200  ) (sess *session.Session, cleanupFn func(), err error) {
   201  	server := httptest.NewUnstartedServer(handler)
   202  
   203  	client := setupServer(server, h2)
   204  
   205  	cleanupFn = func() {
   206  		server.Close()
   207  	}
   208  
   209  	sess, err = session.NewSession(unit.Session.Config, &aws.Config{
   210  		Endpoint:               &server.URL,
   211  		DisableParamValidation: aws.Bool(true),
   212  		HTTPClient:             client,
   213  		//		LogLevel:               aws.LogLevel(aws.LogDebugWithEventStreamBody),
   214  	})
   215  	if err != nil {
   216  		return nil, nil, err
   217  	}
   218  
   219  	return sess, cleanupFn, nil
   220  }
   221  
   222  type flushWriter struct {
   223  	w io.Writer
   224  }
   225  
   226  func (fw flushWriter) Write(p []byte) (n int, err error) {
   227  	n, err = fw.w.Write(p)
   228  	if f, ok := fw.w.(http.Flusher); ok {
   229  		f.Flush()
   230  	}
   231  	return
   232  }
   233  
   234  // MarshalEventPayload marshals a SDK API shape into its associated wire
   235  // protocol payload.
   236  func MarshalEventPayload(
   237  	payloadMarshaler protocol.PayloadMarshaler,
   238  	v interface{},
   239  ) []byte {
   240  	var w bytes.Buffer
   241  	err := payloadMarshaler.MarshalPayload(&w, v)
   242  	if err != nil {
   243  		panic(fmt.Sprintf("failed to marshal event %T, %v, %v", v, v, err))
   244  	}
   245  
   246  	return w.Bytes()
   247  }
   248  
   249  // Prevent circular dependencies on eventstreamapi redefine these here.
   250  const (
   251  	messageTypeHeader    = `:message-type` // Identifies type of message.
   252  	eventMessageType     = `event`
   253  	exceptionMessageType = `exception`
   254  )
   255  
   256  // EventMessageTypeHeader is an event message type header for specifying an
   257  // event is an message type.
   258  var EventMessageTypeHeader = eventstream.Header{
   259  	Name:  messageTypeHeader,
   260  	Value: eventstream.StringValue(eventMessageType),
   261  }
   262  
   263  // EventExceptionTypeHeader is an event exception type header for specifying an
   264  // event is an exception type.
   265  var EventExceptionTypeHeader = eventstream.Header{
   266  	Name:  messageTypeHeader,
   267  	Value: eventstream.StringValue(exceptionMessageType),
   268  }