github.com/alloyci/alloy-runner@v1.0.1-0.20180222164613-925503ccafd6/network/client.go (about) 1 package network 2 3 import ( 4 "bytes" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/hex" 8 "encoding/json" 9 "encoding/pem" 10 "errors" 11 "fmt" 12 "io" 13 "io/ioutil" 14 "mime" 15 "net" 16 "net/http" 17 "net/url" 18 "os" 19 "path/filepath" 20 "strings" 21 "sync" 22 "time" 23 24 "github.com/Sirupsen/logrus" 25 "github.com/jpillora/backoff" 26 27 "gitlab.com/gitlab-org/gitlab-runner/common" 28 ) 29 30 type requestCredentials interface { 31 GetURL() string 32 GetToken() string 33 GetTLSCAFile() string 34 GetTLSCertFile() string 35 GetTLSKeyFile() string 36 } 37 38 var ( 39 dialer = net.Dialer{ 40 Timeout: 30 * time.Second, 41 KeepAlive: 30 * time.Second, 42 } 43 44 backOffDelayMin = 100 * time.Millisecond 45 backOffDelayMax = 60 * time.Second 46 backOffDelayFactor = 2.0 47 backOffDelayJitter = true 48 ) 49 50 type client struct { 51 http.Client 52 url *url.URL 53 caFile string 54 certFile string 55 keyFile string 56 caData []byte 57 skipVerify bool 58 updateTime time.Time 59 lastUpdate string 60 requestBackOffs map[string]*backoff.Backoff 61 lock sync.Mutex 62 } 63 64 type ResponseTLSData struct { 65 CAChain string 66 CertFile string 67 KeyFile string 68 } 69 70 func (n *client) getLastUpdate() string { 71 return n.lastUpdate 72 } 73 74 func (n *client) setLastUpdate(headers http.Header) { 75 if lu := headers.Get("X-GitLab-Last-Update"); len(lu) > 0 { 76 n.lastUpdate = lu 77 } 78 } 79 80 func (n *client) ensureTLSConfig() { 81 // certificate got modified 82 if stat, err := os.Stat(n.caFile); err == nil && n.updateTime.Before(stat.ModTime()) { 83 n.Transport = nil 84 } 85 86 // client certificate got modified 87 if stat, err := os.Stat(n.certFile); err == nil && n.updateTime.Before(stat.ModTime()) { 88 n.Transport = nil 89 } 90 91 // client private key got modified 92 if stat, err := os.Stat(n.keyFile); err == nil && n.updateTime.Before(stat.ModTime()) { 93 n.Transport = nil 94 } 95 96 // create or update transport 97 if n.Transport == nil { 98 n.updateTime = time.Now() 99 n.createTransport() 100 } 101 } 102 103 func (n *client) addTLSCA(tlsConfig *tls.Config) { 104 // load TLS CA certificate 105 if file := n.caFile; file != "" && !n.skipVerify { 106 logrus.Debugln("Trying to load", file, "...") 107 108 data, err := ioutil.ReadFile(file) 109 if err == nil { 110 pool, err := x509.SystemCertPool() 111 if err != nil { 112 logrus.Warningln("Failed to load system CertPool:", err) 113 } 114 if pool == nil { 115 pool = x509.NewCertPool() 116 } 117 if pool.AppendCertsFromPEM(data) { 118 tlsConfig.RootCAs = pool 119 n.caData = data 120 } else { 121 logrus.Errorln("Failed to parse PEM in", n.caFile) 122 } 123 } else { 124 if !os.IsNotExist(err) { 125 logrus.Errorln("Failed to load", n.caFile, err) 126 } 127 } 128 } 129 } 130 131 func (n *client) addTLSAuth(tlsConfig *tls.Config) { 132 // load TLS client keypair 133 if cert, key := n.certFile, n.keyFile; cert != "" && key != "" { 134 logrus.Debugln("Trying to load", cert, "and", key, "pair...") 135 136 certificate, err := tls.LoadX509KeyPair(cert, key) 137 if err == nil { 138 tlsConfig.Certificates = []tls.Certificate{certificate} 139 tlsConfig.BuildNameToCertificate() 140 } else { 141 if !os.IsNotExist(err) { 142 logrus.Errorln("Failed to load", cert, key, err) 143 } 144 } 145 } 146 } 147 148 func (n *client) createTransport() { 149 // create reference TLS config 150 tlsConfig := tls.Config{ 151 MinVersion: tls.VersionTLS10, 152 InsecureSkipVerify: n.skipVerify, 153 } 154 155 n.addTLSCA(&tlsConfig) 156 n.addTLSAuth(&tlsConfig) 157 158 // create transport 159 n.Transport = &http.Transport{ 160 Proxy: http.ProxyFromEnvironment, 161 Dial: func(network, addr string) (net.Conn, error) { 162 logrus.Debugln("Dialing:", network, addr, "...") 163 return dialer.Dial(network, addr) 164 }, 165 TLSClientConfig: &tlsConfig, 166 MaxIdleConns: 100, 167 IdleConnTimeout: 90 * time.Second, 168 TLSHandshakeTimeout: 10 * time.Second, 169 ExpectContinueTimeout: 1 * time.Second, 170 ResponseHeaderTimeout: 10 * time.Minute, 171 } 172 n.Timeout = common.DefaultNetworkClientTimeout 173 } 174 175 func (n *client) getCAChain(tls *tls.ConnectionState) string { 176 if len(n.caData) != 0 { 177 return string(n.caData) 178 } 179 180 if tls == nil { 181 return "" 182 } 183 184 // Don't reorder certificates by putting them directly into the map 185 var certificates []*x509.Certificate 186 seenCertificates := make(map[string]bool, 0) 187 188 for _, verifiedChain := range tls.VerifiedChains { 189 for _, certificate := range verifiedChain { 190 signature := hex.EncodeToString(certificate.Signature) 191 if seenCertificates[signature] { 192 continue 193 } 194 195 seenCertificates[signature] = true 196 certificates = append(certificates, certificate) 197 } 198 } 199 200 out := bytes.NewBuffer(nil) 201 for _, certificate := range certificates { 202 if err := pem.Encode(out, &pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw}); err != nil { 203 logrus.Warn("Failed to encode certificate from chain:", err) 204 } 205 } 206 207 return out.String() 208 } 209 210 func (n *client) ensureBackoff(method, uri string) *backoff.Backoff { 211 n.lock.Lock() 212 defer n.lock.Unlock() 213 214 key := fmt.Sprintf("%s_%s", method, uri) 215 if n.requestBackOffs[key] == nil { 216 n.requestBackOffs[key] = &backoff.Backoff{ 217 Min: backOffDelayMin, 218 Max: backOffDelayMax, 219 Factor: backOffDelayFactor, 220 Jitter: backOffDelayJitter, 221 } 222 } 223 224 return n.requestBackOffs[key] 225 } 226 227 func (n *client) backoffRequired(res *http.Response) bool { 228 return res.StatusCode >= 400 && res.StatusCode < 600 229 } 230 231 func (n *client) doBackoffRequest(req *http.Request) (res *http.Response, err error) { 232 res, err = n.Do(req) 233 if err != nil { 234 err = fmt.Errorf("couldn't execute %v against %s: %v", req.Method, req.URL, err) 235 return 236 } 237 238 backoffDelay := n.ensureBackoff(req.Method, req.RequestURI) 239 if n.backoffRequired(res) { 240 time.Sleep(backoffDelay.Duration()) 241 } else { 242 backoffDelay.Reset() 243 } 244 245 return 246 } 247 248 func (n *client) do(uri, method string, request io.Reader, requestType string, headers http.Header) (res *http.Response, err error) { 249 url, err := n.url.Parse(uri) 250 if err != nil { 251 return 252 } 253 254 req, err := http.NewRequest(method, url.String(), request) 255 if err != nil { 256 err = fmt.Errorf("failed to create NewRequest: %v", err) 257 return 258 } 259 260 if headers != nil { 261 req.Header = headers 262 } 263 264 if request != nil { 265 req.Header.Set("Content-Type", requestType) 266 req.Header.Set("User-Agent", common.AppVersion.UserAgent()) 267 } 268 269 n.ensureTLSConfig() 270 271 res, err = n.doBackoffRequest(req) 272 return 273 } 274 275 func (n *client) doJSON(uri, method string, statusCode int, request interface{}, response interface{}) (int, string, ResponseTLSData) { 276 var body io.Reader 277 278 if request != nil { 279 requestBody, err := json.Marshal(request) 280 if err != nil { 281 return -1, fmt.Sprintf("failed to marshal project object: %v", err), ResponseTLSData{} 282 } 283 body = bytes.NewReader(requestBody) 284 } 285 286 headers := make(http.Header) 287 if response != nil { 288 headers.Set("Accept", "application/json") 289 } 290 291 res, err := n.do(uri, method, body, "application/json", headers) 292 if err != nil { 293 return -1, err.Error(), ResponseTLSData{} 294 } 295 defer res.Body.Close() 296 defer io.Copy(ioutil.Discard, res.Body) 297 298 if res.StatusCode == statusCode { 299 if response != nil { 300 isApplicationJSON, err := isResponseApplicationJSON(res) 301 if !isApplicationJSON { 302 return -1, err.Error(), ResponseTLSData{} 303 } 304 305 d := json.NewDecoder(res.Body) 306 err = d.Decode(response) 307 if err != nil { 308 return -1, fmt.Sprintf("Error decoding json payload %v", err), ResponseTLSData{} 309 } 310 } 311 } 312 313 n.setLastUpdate(res.Header) 314 315 TLSData := ResponseTLSData{ 316 CAChain: n.getCAChain(res.TLS), 317 CertFile: n.certFile, 318 KeyFile: n.keyFile, 319 } 320 321 return res.StatusCode, res.Status, TLSData 322 } 323 324 func isResponseApplicationJSON(res *http.Response) (result bool, err error) { 325 contentType := res.Header.Get("Content-Type") 326 327 mimetype, _, err := mime.ParseMediaType(contentType) 328 if err != nil { 329 return false, fmt.Errorf("Content-Type parsing error: %v", err) 330 } 331 332 if mimetype != "application/json" { 333 return false, fmt.Errorf("Server should return application/json. Got: %v", contentType) 334 } 335 336 return true, nil 337 } 338 339 func fixCIURL(url string) string { 340 url = strings.TrimRight(url, "/") 341 if strings.HasSuffix(url, "/ci") { 342 url = strings.TrimSuffix(url, "/ci") 343 } 344 return url 345 } 346 347 func (n *client) findCertificate(certificate *string, base string, name string) { 348 if *certificate != "" { 349 return 350 } 351 path := filepath.Join(base, name) 352 if _, err := os.Stat(path); err == nil { 353 *certificate = path 354 } 355 } 356 357 func newClient(requestCredentials requestCredentials) (c *client, err error) { 358 url, err := url.Parse(fixCIURL(requestCredentials.GetURL()) + "/api/v4/") 359 if err != nil { 360 return 361 } 362 363 if url.Scheme != "http" && url.Scheme != "https" { 364 err = errors.New("only http or https scheme supported") 365 return 366 } 367 368 c = &client{ 369 url: url, 370 caFile: requestCredentials.GetTLSCAFile(), 371 certFile: requestCredentials.GetTLSCertFile(), 372 keyFile: requestCredentials.GetTLSKeyFile(), 373 requestBackOffs: make(map[string]*backoff.Backoff), 374 } 375 376 host := strings.Split(url.Host, ":")[0] 377 if CertificateDirectory != "" { 378 c.findCertificate(&c.caFile, CertificateDirectory, host+".crt") 379 c.findCertificate(&c.certFile, CertificateDirectory, host+".auth.crt") 380 c.findCertificate(&c.keyFile, CertificateDirectory, host+".auth.key") 381 } 382 383 return 384 }