github.com/anycable/anycable-go@v1.5.1/rpc/http.go (about) 1 package rpc 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "io" 10 "net/http" 11 "net/url" 12 "time" 13 14 "github.com/anycable/anycable-go/logger" 15 pb "github.com/anycable/anycable-go/protos" 16 "github.com/anycable/anycable-go/utils" 17 "github.com/sony/gobreaker" 18 "google.golang.org/grpc/codes" 19 "google.golang.org/grpc/metadata" 20 "google.golang.org/grpc/status" 21 ) 22 23 type httpClientHelper struct { 24 service *HTTPService 25 } 26 27 func NewHTTPClientHelper(s *HTTPService) *httpClientHelper { 28 return &httpClientHelper{service: s} 29 } 30 31 func (h *httpClientHelper) Ready() error { 32 cbState := h.service.cb.State() 33 34 if cbState == gobreaker.StateOpen { 35 return errors.New("http rpc is temporarily unavailable") 36 } 37 38 return nil 39 } 40 41 func (h *httpClientHelper) SupportsActiveConns() bool { 42 return false 43 } 44 45 func (h *httpClientHelper) ActiveConns() int { 46 return 0 47 } 48 49 func (h *httpClientHelper) Close() { 50 h.service.client.CloseIdleConnections() 51 } 52 53 type HTTPService struct { 54 conf *Config 55 client *http.Client 56 baseURL *url.URL 57 58 cb *gobreaker.TwoStepCircuitBreaker 59 } 60 61 func NewHTTPDialer(c *Config) (Dialer, error) { 62 service, err := NewHTTPService(c) 63 64 if err != nil { 65 return nil, err 66 } 67 68 helper := NewHTTPClientHelper(service) 69 70 return NewInprocessServiceDialer(service, helper), nil 71 } 72 73 func NewHTTPService(c *Config) (*HTTPService, error) { 74 tlsConfig, error := c.TLSConfig() 75 if error != nil { 76 return nil, error 77 } 78 79 client := &http.Client{ 80 Transport: &http.Transport{TLSClientConfig: tlsConfig}, 81 } 82 83 baseURL, err := url.Parse(c.Host) 84 85 if err != nil { 86 return nil, err 87 } 88 89 cb := gobreaker.NewTwoStepCircuitBreaker(gobreaker.Settings{ 90 Name: "httrpc", 91 MaxRequests: 5, 92 Interval: 10 * time.Second, 93 Timeout: 5 * time.Second, 94 ReadyToTrip: func(counts gobreaker.Counts) bool { 95 failureRatio := float64(counts.TotalFailures) / float64(counts.Requests) 96 return counts.Requests >= 10 && failureRatio >= 0.8 97 }, 98 }) 99 100 return &HTTPService{conf: c, client: client, baseURL: baseURL, cb: cb}, nil 101 } 102 103 func (s *HTTPService) Connect(ctx context.Context, r *pb.ConnectionRequest) (*pb.ConnectionResponse, error) { 104 rawResponse, err := s.performRequest(ctx, "connect", utils.ToJSON(r)) 105 106 if err != nil { 107 return nil, err 108 } 109 110 var response pb.ConnectionResponse 111 112 err = json.Unmarshal(rawResponse, &response) 113 114 if err != nil { 115 return nil, err 116 } 117 118 return &response, nil 119 } 120 121 func (s *HTTPService) Disconnect(ctx context.Context, r *pb.DisconnectRequest) (*pb.DisconnectResponse, error) { 122 rawResponse, err := s.performRequest(ctx, "disconnect", utils.ToJSON(r)) 123 124 if err != nil { 125 return nil, err 126 } 127 128 var response pb.DisconnectResponse 129 130 err = json.Unmarshal(rawResponse, &response) 131 132 if err != nil { 133 return nil, err 134 } 135 136 return &response, nil 137 } 138 139 func (s *HTTPService) Command(ctx context.Context, r *pb.CommandMessage) (*pb.CommandResponse, error) { 140 rawResponse, err := s.performRequest(ctx, "command", utils.ToJSON(r)) 141 142 if err != nil { 143 return nil, err 144 } 145 146 var response pb.CommandResponse 147 148 err = json.Unmarshal(rawResponse, &response) 149 150 if err != nil { 151 return nil, err 152 } 153 154 return &response, nil 155 } 156 157 func (s *HTTPService) performRequest(ctx context.Context, path string, payload []byte) ([]byte, error) { 158 cbCallback, err := s.cb.Allow() 159 160 if err != nil { 161 return nil, err 162 } 163 164 url := s.baseURL.JoinPath(path).String() 165 166 // We use timeouts to detect request queueing at the HTTP RPC side and report ResourceExhausted errors 167 // (so adaptive concurrency control can be applied) 168 ctx, cancel := context.WithTimeout(ctx, time.Duration(s.conf.RequestTimeout)*time.Millisecond) 169 defer cancel() 170 171 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload)) 172 if err != nil { 173 return nil, err 174 } 175 176 req.Header.Set("Content-Type", "application/json") 177 178 if s.conf.Secret != "" { 179 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.conf.Secret)) 180 } 181 182 if md, ok := metadata.FromIncomingContext(ctx); ok { 183 // Set headers from metadata 184 for k, v := range md { 185 req.Header.Set(fmt.Sprintf("x-anycable-meta-%s", k), v[0]) 186 } 187 } 188 189 res, err := s.client.Do(req) 190 191 if err != nil { 192 if ctx.Err() != nil { 193 return nil, status.Error(codes.DeadlineExceeded, "request timeout") 194 } 195 196 cbCallback(false) 197 return nil, status.Error(codes.Unavailable, err.Error()) 198 } 199 200 cbCallback(true) 201 202 defer res.Body.Close() 203 204 if res.StatusCode == http.StatusUnauthorized { 205 return nil, status.Error(codes.Unauthenticated, "http returned 401") 206 } 207 208 if res.StatusCode == http.StatusBadRequest || res.StatusCode == http.StatusUnprocessableEntity { 209 reason, rerr := io.ReadAll(res.Body) 210 if rerr != nil { 211 return nil, status.Error(codes.InvalidArgument, "unprocessable entity") 212 } 213 214 return nil, status.Error(codes.InvalidArgument, logger.CompactValue(reason).String()) 215 } 216 217 if res.StatusCode != http.StatusOK { 218 reason, rerr := io.ReadAll(res.Body) 219 if rerr != nil { 220 return nil, status.Error(codes.Unknown, "internal error") 221 } 222 223 return nil, status.Error(codes.Unknown, logger.CompactValue(reason).String()) 224 } 225 226 // Finally, the response is successful, let's read the body 227 rawRequest, err := io.ReadAll(res.Body) 228 229 if err != nil { 230 return nil, err 231 } 232 233 return rawRequest, nil 234 }