github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/stub/stub.go (about)

     1  package stub
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"net/http"
     7  	"net/url"
     8  	"runtime"
     9  	"time"
    10  
    11  	"github.com/clubpay/ronykit/kit"
    12  	"github.com/clubpay/ronykit/kit/common"
    13  	"github.com/clubpay/ronykit/kit/utils/reflector"
    14  	"github.com/fasthttp/websocket"
    15  	"github.com/valyala/fasthttp"
    16  )
    17  
    18  var defaultConcurrency = 1024 * runtime.NumCPU()
    19  
// Stub is the client handle returned by New. It carries the resolved
// configuration and the shared resources that the REST and Websocket
// contexts are built from.
type Stub struct {
	// cfg is the configuration produced by New after applying all options.
	cfg config
	// r is the reflector instance handed to every context created from this stub.
	r   *reflector.Reflector

	// httpC is the fasthttp client used by contexts created through REST.
	httpC *fasthttp.Client
}
    26  
    27  func New(hostPort string, opts ...Option) *Stub {
    28  	cfg := config{
    29  		name:         "RonyKIT Client",
    30  		hostPort:     hostPort,
    31  		readTimeout:  time.Minute * 5,
    32  		writeTimeout: time.Minute * 5,
    33  		dialTimeout:  time.Second * 45,
    34  		l:            common.NewNopLogger(),
    35  		codec:        kit.GetMessageCodec(),
    36  	}
    37  	for _, opt := range opts {
    38  		opt(&cfg)
    39  	}
    40  
    41  	httpC := &fasthttp.Client{
    42  		Name:                          cfg.name,
    43  		ReadTimeout:                   cfg.readTimeout,
    44  		WriteTimeout:                  cfg.writeTimeout,
    45  		DisableHeaderNamesNormalizing: true,
    46  		TLSConfig: &tls.Config{
    47  			InsecureSkipVerify: cfg.skipVerifyTLS, //nolint:gosec
    48  		},
    49  	}
    50  
    51  	if cfg.dialFunc != nil {
    52  		httpC.Dial = cfg.dialFunc
    53  	}
    54  
    55  	return &Stub{
    56  		cfg:   cfg,
    57  		r:     reflector.New(),
    58  		httpC: httpC,
    59  	}
    60  }
    61  
    62  func HTTP(rawURL string, opts ...Option) (*RESTCtx, error) {
    63  	u, err := url.ParseRequestURI(rawURL)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	switch u.Scheme {
    69  	default:
    70  		return nil, fmt.Errorf("unsupported scheme: %s", u.Scheme)
    71  	case "http":
    72  	case "https":
    73  		opts = append(opts, Secure())
    74  	}
    75  
    76  	s := New(u.Host, opts...).REST()
    77  	s.SetPath(u.Path)
    78  	for k, v := range u.Query() {
    79  		for _, vv := range v {
    80  			s.AppendQuery(k, vv)
    81  		}
    82  	}
    83  
    84  	return s, nil
    85  }
    86  
    87  func (s *Stub) REST(opt ...RESTOption) *RESTCtx {
    88  	ctx := &RESTCtx{
    89  		c:        s.httpC,
    90  		r:        s.r,
    91  		handlers: map[int]RESTResponseHandler{},
    92  		uri:      fasthttp.AcquireURI(),
    93  		args:     fasthttp.AcquireArgs(),
    94  		req:      fasthttp.AcquireRequest(),
    95  		res:      fasthttp.AcquireResponse(),
    96  		timeout:  s.cfg.readTimeout,
    97  		codec:    s.cfg.codec,
    98  	}
    99  
   100  	if s.cfg.secure {
   101  		ctx.uri.SetScheme("https")
   102  	} else {
   103  		ctx.uri.SetScheme("http")
   104  	}
   105  
   106  	ctx.cfg.tp = s.cfg.tp
   107  	ctx.uri.SetHost(s.cfg.hostPort)
   108  	ctx.DumpRequestTo(s.cfg.dumpReq)
   109  	ctx.DumpResponseTo(s.cfg.dumpRes)
   110  
   111  	for _, o := range opt {
   112  		o(&ctx.cfg)
   113  	}
   114  
   115  	return ctx
   116  }
   117  
   118  func (s *Stub) Websocket(opts ...WebsocketOption) *WebsocketCtx {
   119  	defaultProxy := http.ProxyFromEnvironment
   120  	if s.cfg.proxy != nil {
   121  		defaultProxy = func(req *http.Request) (*url.URL, error) {
   122  			return s.cfg.proxy.ProxyFunc()(req.URL)
   123  		}
   124  	}
   125  
   126  	defaultDialerBuilder := func() *websocket.Dialer {
   127  		return &websocket.Dialer{
   128  			Proxy:            defaultProxy,
   129  			HandshakeTimeout: s.cfg.dialTimeout,
   130  		}
   131  	}
   132  	ctx := &WebsocketCtx{
   133  		cfg: wsConfig{
   134  			autoReconnect:   true,
   135  			pingTime:        time.Second * 30,
   136  			dialTimeout:     s.cfg.dialTimeout,
   137  			writeTimeout:    s.cfg.writeTimeout,
   138  			ratelimitChan:   make(chan struct{}, defaultConcurrency),
   139  			rpcInFactory:    common.SimpleIncomingJSONRPC,
   140  			rpcOutFactory:   common.SimpleOutgoingJSONRPC,
   141  			dialerBuilder:   defaultDialerBuilder,
   142  			tracePropagator: s.cfg.tp,
   143  		},
   144  		r:       s.r,
   145  		l:       s.cfg.l,
   146  		pending: make(map[string]chan kit.IncomingRPCContainer, 1024),
   147  	}
   148  
   149  	for _, o := range opts {
   150  		o(&ctx.cfg)
   151  	}
   152  
   153  	if s.cfg.secure {
   154  		ctx.url = fmt.Sprintf("wss://%s", s.cfg.hostPort)
   155  	} else {
   156  		ctx.url = fmt.Sprintf("ws://%s", s.cfg.hostPort)
   157  	}
   158  
   159  	return ctx
   160  }