tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/rpc/client.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"tractor.dev/toolkit-go/duplex/codec"
     8  	"tractor.dev/toolkit-go/duplex/mux"
     9  )
    10  
    11  // RemoteError is an error that has been returned from
    12  // the remote side of the RPC connection.
    13  type RemoteError string
    14  
    15  func (e RemoteError) Error() string {
    16  	return fmt.Sprintf("remote: %s", string(e))
    17  }
    18  
    19  // Client wraps a session and codec to make RPC calls over the session.
    20  type Client struct {
    21  	mux.Session
    22  	codec codec.Codec
    23  }
    24  
    25  // NewClient takes a session and codec to make a client for making RPC calls.
    26  func NewClient(session mux.Session, codec codec.Codec) *Client {
    27  	return &Client{
    28  		Session: session,
    29  		codec:   codec,
    30  	}
    31  }
    32  
    33  // Call makes synchronous calls to the remote selector passing args and putting the reply
    34  // value in reply. Both args and reply can be nil. Args can be a channel of interface{}
    35  // values for asynchronously streaming multiple values from another goroutine, however
    36  // the call will still block until a response is sent. If there is an error making the call
    37  // an error is returned, and if an error is returned by the remote handler a RemoteError
    38  // is returned.
    39  //
    40  // A Response value is also returned for advanced operations. For example, you can check
    41  // if the call is continued, meaning the underlying channel will be kept open for either
    42  // streaming back more results or using the channel as a full duplex byte stream.
    43  func (c *Client) Call(ctx context.Context, selector string, args any, reply ...any) (*Response, error) {
    44  	ch, err := c.Session.Open(ctx)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	// If the context is cancelled before the call completes, call Close() to
    49  	// abort the current operation.
    50  	done := make(chan struct{})
    51  	defer close(done)
    52  	go func() {
    53  		select {
    54  		case <-ctx.Done():
    55  			ch.Close()
    56  		case <-done:
    57  		}
    58  	}()
    59  	resp, err := call(ctx, ch, c.codec, selector, args, reply...)
    60  	if ctxErr := ctx.Err(); ctxErr != nil {
    61  		return resp, ctxErr
    62  	}
    63  	return resp, err
    64  }
    65  
    66  func call(ctx context.Context, ch mux.Channel, cd codec.Codec, selector string, args any, reply ...any) (*Response, error) {
    67  	framer := &FrameCodec{Codec: cd}
    68  	enc := framer.Encoder(ch)
    69  	dec := framer.Decoder(ch)
    70  
    71  	// request
    72  	err := enc.Encode(CallHeader{
    73  		S: selector,
    74  	})
    75  	if err != nil {
    76  		ch.Close()
    77  		return nil, err
    78  	}
    79  
    80  	argCh, isChan := args.(chan interface{})
    81  	switch {
    82  	case isChan:
    83  		for arg := range argCh {
    84  			if err := enc.Encode(arg); err != nil {
    85  				ch.Close()
    86  				return nil, err
    87  			}
    88  		}
    89  	default:
    90  		if err := enc.Encode(args); err != nil {
    91  			ch.Close()
    92  			return nil, err
    93  		}
    94  	}
    95  
    96  	// response
    97  	var header ResponseHeader
    98  	err = dec.Decode(&header)
    99  	if err != nil {
   100  		ch.Close()
   101  		return nil, err
   102  	}
   103  
   104  	if !header.C {
   105  		defer ch.Close()
   106  	}
   107  
   108  	resp := &Response{
   109  		ResponseHeader: header,
   110  		Channel:        ch,
   111  		codec:          framer,
   112  	}
   113  	if len(reply) == 1 {
   114  		resp.Value = reply[0]
   115  	} else if len(reply) > 1 {
   116  		resp.Value = reply
   117  	}
   118  	if resp.Err() != nil {
   119  		return resp, RemoteError(resp.Err().Error())
   120  	}
   121  
   122  	if resp.Value == nil {
   123  		// read into throwaway buffer
   124  		var buf []byte
   125  		dec.Decode(&buf)
   126  	} else {
   127  		for _, r := range reply {
   128  			if err := dec.Decode(r); err != nil {
   129  				return resp, err
   130  			}
   131  		}
   132  	}
   133  
   134  	return resp, nil
   135  }