go.uber.org/yarpc@v1.72.1/api/transport/transporttest/messagepipe.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package transporttest
    22  
    23  import (
    24  	"context"
    25  	"io"
    26  
    27  	"go.uber.org/yarpc/api/transport"
    28  )
    29  
    30  type pipeOptions struct{}
    31  
    32  // MessagePipeOption is an option for the MessagePipe constructor.
    33  type MessagePipeOption interface {
    34  	apply(*pipeOptions)
    35  }
    36  
    37  type messageResult struct {
    38  	msg *transport.StreamMessage
    39  	err error
    40  }
    41  
    42  // MessagePipe creates an in-memory client and server message stream pair.
    43  //
    44  // The third return value is a function that the server side uses to transport
    45  // the end of stream error, if not nil.
    46  // Calling the finish function with nil is valid.
    47  //
    48  //   finish(streamHandler.HandleStream(serverStream))
    49  //
    50  func MessagePipe(ctx context.Context, req *transport.StreamRequest, _ ...MessagePipeOption) (*transport.ClientStream, *transport.ServerStream, func(error), error) {
    51  	c2s := make(chan messageResult)
    52  	s2c := make(chan messageResult)
    53  	clientClosed := make(chan struct{})
    54  	serverClosed := make(chan struct{})
    55  	client, err := transport.NewClientStream(&stream{
    56  		ctx:        ctx,
    57  		req:        req,
    58  		send:       c2s,
    59  		recv:       s2c,
    60  		sendClosed: clientClosed,
    61  		recvClosed: serverClosed,
    62  	})
    63  	if err != nil {
    64  		return nil, nil, nil, err
    65  	}
    66  	server, err := transport.NewServerStream(&stream{
    67  		ctx:        ctx,
    68  		req:        req,
    69  		send:       s2c,
    70  		recv:       c2s,
    71  		sendClosed: serverClosed,
    72  		recvClosed: clientClosed,
    73  	})
    74  	if err != nil {
    75  		return nil, nil, nil, err
    76  	}
    77  
    78  	finish := func(err error) {
    79  		if err == nil {
    80  			return
    81  		}
    82  		// If HandleStream returns an error, we realize this
    83  		// by sending that error through the server to client
    84  		// channel, so it can be picked up by the client's next
    85  		// ReceiveMessage/Recv call.
    86  		select {
    87  		case <-clientClosed:
    88  		case <-serverClosed:
    89  		case <-ctx.Done():
    90  		case s2c <- messageResult{err: err}:
    91  			close(serverClosed)
    92  		}
    93  	}
    94  
    95  	return client, server, finish, nil
    96  }
    97  
    98  type stream struct {
    99  	req        *transport.StreamRequest
   100  	ctx        context.Context
   101  	send       chan<- messageResult
   102  	recv       <-chan messageResult
   103  	sendClosed chan struct{}
   104  	recvClosed chan struct{}
   105  }
   106  
   107  func (s *stream) Context() context.Context {
   108  	return s.ctx
   109  }
   110  
   111  func (s *stream) Request() *transport.StreamRequest {
   112  	return s.req
   113  }
   114  
   115  func (s *stream) SendMessage(ctx context.Context, msg *transport.StreamMessage) error {
   116  	select {
   117  	case <-s.sendClosed:
   118  		return io.EOF
   119  	case <-s.recvClosed:
   120  		return io.EOF
   121  	case <-s.ctx.Done():
   122  		return s.ctx.Err()
   123  	case <-ctx.Done():
   124  		return ctx.Err()
   125  	case s.send <- messageResult{msg: msg}:
   126  		return nil
   127  	}
   128  }
   129  
   130  func (s *stream) ReceiveMessage(ctx context.Context) (*transport.StreamMessage, error) {
   131  	select {
   132  	case <-s.sendClosed:
   133  		return nil, io.EOF
   134  	case <-s.recvClosed:
   135  		return nil, io.EOF
   136  	case <-s.ctx.Done():
   137  		return nil, s.ctx.Err()
   138  	case <-ctx.Done():
   139  		return nil, ctx.Err()
   140  	case res := <-s.recv:
   141  		return res.msg, res.err
   142  	}
   143  }
   144  
   145  func (s *stream) Close(_ context.Context) error {
   146  	close(s.sendClosed)
   147  	return nil
   148  }