github.com/anycable/anycable-go@v1.5.1/rpc/http.go (about)

     1  package rpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"time"
    13  
    14  	"github.com/anycable/anycable-go/logger"
    15  	pb "github.com/anycable/anycable-go/protos"
    16  	"github.com/anycable/anycable-go/utils"
    17  	"github.com/sony/gobreaker"
    18  	"google.golang.org/grpc/codes"
    19  	"google.golang.org/grpc/metadata"
    20  	"google.golang.org/grpc/status"
    21  )
    22  
    23  type httpClientHelper struct {
    24  	service *HTTPService
    25  }
    26  
    27  func NewHTTPClientHelper(s *HTTPService) *httpClientHelper {
    28  	return &httpClientHelper{service: s}
    29  }
    30  
    31  func (h *httpClientHelper) Ready() error {
    32  	cbState := h.service.cb.State()
    33  
    34  	if cbState == gobreaker.StateOpen {
    35  		return errors.New("http rpc is temporarily unavailable")
    36  	}
    37  
    38  	return nil
    39  }
    40  
    41  func (h *httpClientHelper) SupportsActiveConns() bool {
    42  	return false
    43  }
    44  
    45  func (h *httpClientHelper) ActiveConns() int {
    46  	return 0
    47  }
    48  
    49  func (h *httpClientHelper) Close() {
    50  	h.service.client.CloseIdleConnections()
    51  }
    52  
    53  type HTTPService struct {
    54  	conf    *Config
    55  	client  *http.Client
    56  	baseURL *url.URL
    57  
    58  	cb *gobreaker.TwoStepCircuitBreaker
    59  }
    60  
    61  func NewHTTPDialer(c *Config) (Dialer, error) {
    62  	service, err := NewHTTPService(c)
    63  
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	helper := NewHTTPClientHelper(service)
    69  
    70  	return NewInprocessServiceDialer(service, helper), nil
    71  }
    72  
    73  func NewHTTPService(c *Config) (*HTTPService, error) {
    74  	tlsConfig, error := c.TLSConfig()
    75  	if error != nil {
    76  		return nil, error
    77  	}
    78  
    79  	client := &http.Client{
    80  		Transport: &http.Transport{TLSClientConfig: tlsConfig},
    81  	}
    82  
    83  	baseURL, err := url.Parse(c.Host)
    84  
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	cb := gobreaker.NewTwoStepCircuitBreaker(gobreaker.Settings{
    90  		Name:        "httrpc",
    91  		MaxRequests: 5,
    92  		Interval:    10 * time.Second,
    93  		Timeout:     5 * time.Second,
    94  		ReadyToTrip: func(counts gobreaker.Counts) bool {
    95  			failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
    96  			return counts.Requests >= 10 && failureRatio >= 0.8
    97  		},
    98  	})
    99  
   100  	return &HTTPService{conf: c, client: client, baseURL: baseURL, cb: cb}, nil
   101  }
   102  
   103  func (s *HTTPService) Connect(ctx context.Context, r *pb.ConnectionRequest) (*pb.ConnectionResponse, error) {
   104  	rawResponse, err := s.performRequest(ctx, "connect", utils.ToJSON(r))
   105  
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	var response pb.ConnectionResponse
   111  
   112  	err = json.Unmarshal(rawResponse, &response)
   113  
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	return &response, nil
   119  }
   120  
   121  func (s *HTTPService) Disconnect(ctx context.Context, r *pb.DisconnectRequest) (*pb.DisconnectResponse, error) {
   122  	rawResponse, err := s.performRequest(ctx, "disconnect", utils.ToJSON(r))
   123  
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	var response pb.DisconnectResponse
   129  
   130  	err = json.Unmarshal(rawResponse, &response)
   131  
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	return &response, nil
   137  }
   138  
   139  func (s *HTTPService) Command(ctx context.Context, r *pb.CommandMessage) (*pb.CommandResponse, error) {
   140  	rawResponse, err := s.performRequest(ctx, "command", utils.ToJSON(r))
   141  
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  
   146  	var response pb.CommandResponse
   147  
   148  	err = json.Unmarshal(rawResponse, &response)
   149  
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	return &response, nil
   155  }
   156  
   157  func (s *HTTPService) performRequest(ctx context.Context, path string, payload []byte) ([]byte, error) {
   158  	cbCallback, err := s.cb.Allow()
   159  
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	url := s.baseURL.JoinPath(path).String()
   165  
   166  	// We use timeouts to detect request queueing at the HTTP RPC side and report ResourceExhausted errors
   167  	// (so adaptive concurrency control can be applied)
   168  	ctx, cancel := context.WithTimeout(ctx, time.Duration(s.conf.RequestTimeout)*time.Millisecond)
   169  	defer cancel()
   170  
   171  	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	req.Header.Set("Content-Type", "application/json")
   177  
   178  	if s.conf.Secret != "" {
   179  		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.conf.Secret))
   180  	}
   181  
   182  	if md, ok := metadata.FromIncomingContext(ctx); ok {
   183  		// Set headers from metadata
   184  		for k, v := range md {
   185  			req.Header.Set(fmt.Sprintf("x-anycable-meta-%s", k), v[0])
   186  		}
   187  	}
   188  
   189  	res, err := s.client.Do(req)
   190  
   191  	if err != nil {
   192  		if ctx.Err() != nil {
   193  			return nil, status.Error(codes.DeadlineExceeded, "request timeout")
   194  		}
   195  
   196  		cbCallback(false)
   197  		return nil, status.Error(codes.Unavailable, err.Error())
   198  	}
   199  
   200  	cbCallback(true)
   201  
   202  	defer res.Body.Close()
   203  
   204  	if res.StatusCode == http.StatusUnauthorized {
   205  		return nil, status.Error(codes.Unauthenticated, "http returned 401")
   206  	}
   207  
   208  	if res.StatusCode == http.StatusBadRequest || res.StatusCode == http.StatusUnprocessableEntity {
   209  		reason, rerr := io.ReadAll(res.Body)
   210  		if rerr != nil {
   211  			return nil, status.Error(codes.InvalidArgument, "unprocessable entity")
   212  		}
   213  
   214  		return nil, status.Error(codes.InvalidArgument, logger.CompactValue(reason).String())
   215  	}
   216  
   217  	if res.StatusCode != http.StatusOK {
   218  		reason, rerr := io.ReadAll(res.Body)
   219  		if rerr != nil {
   220  			return nil, status.Error(codes.Unknown, "internal error")
   221  		}
   222  
   223  		return nil, status.Error(codes.Unknown, logger.CompactValue(reason).String())
   224  	}
   225  
   226  	// Finally, the response is successful, let's read the body
   227  	rawRequest, err := io.ReadAll(res.Body)
   228  
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  
   233  	return rawRequest, nil
   234  }