github.com/jcarley/cli@v0.0.0-20180201210820-966d90434c30/lib/httpclient/client.go (about)

     1  package httpclient
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"net/http"
    14  	"os"
    15  	"runtime"
    16  	"time"
    17  
    18  	"github.com/Sirupsen/logrus"
    19  	"github.com/daticahealth/cli/config"
    20  	"github.com/daticahealth/cli/lib/updater"
    21  	"github.com/daticahealth/cli/models"
    22  )
    23  
// defaultRedirectLimit is the threshold redirectPolicyFunc compares len(via)
// against: a redirect is refused once more than this many requests precede it.
const defaultRedirectLimit = 10
    25  
// TLSHTTPManager issues HTTP requests through a single wrapped http.Client.
// NewTLSHTTPManager hands it out as a models.HTTPManager.
type TLSHTTPManager struct {
	// client is the HTTP client that makeRequest and uploadFile send through.
	client *http.Client
}
    29  
    30  // NewTLSHTTPManager constructs and returns a new instance of HTTPManager
    31  // with TLSv1.2 and redirect support.
    32  func NewTLSHTTPManager(skipVerify bool) models.HTTPManager {
    33  	var tr = &http.Transport{
    34  		TLSClientConfig: &tls.Config{
    35  			MinVersion: tls.VersionTLS12,
    36  		},
    37  	}
    38  	if skipVerify {
    39  		tr.TLSClientConfig.InsecureSkipVerify = true
    40  	}
    41  	return &TLSHTTPManager{
    42  		client: &http.Client{
    43  			Transport:     tr,
    44  			CheckRedirect: redirectPolicyFunc,
    45  		},
    46  	}
    47  }
    48  
    49  func redirectPolicyFunc(req *http.Request, via []*http.Request) error {
    50  	if len(via) == 0 {
    51  		// No redirects
    52  		return nil
    53  	}
    54  
    55  	if len(via) > defaultRedirectLimit {
    56  		return fmt.Errorf("%d consecutive requests(redirects)", len(via))
    57  	}
    58  
    59  	// mutate the subsequent redirect requests with the first Header
    60  	for key, val := range via[0].Header {
    61  		req.Header[key] = val
    62  	}
    63  	return nil
    64  }
    65  
    66  // GetHeaders builds a map of headers for a new request.
    67  func (m *TLSHTTPManager) GetHeaders(sessionToken, version, pod, userID string) map[string][]string {
    68  	b := make([]byte, 32)
    69  	rand.Read(b)
    70  	nonce := base64.StdEncoding.EncodeToString(b)
    71  	timestamp := time.Now().Unix()
    72  	return map[string][]string{
    73  		"Accept":              {"application/json"},
    74  		"Content-Type":        {"application/json"},
    75  		"Authorization":       {fmt.Sprintf("Bearer %s", sessionToken)},
    76  		"X-CLI-Version":       {version},
    77  		"X-Pod-ID":            {pod},
    78  		"X-Request-Nonce":     {nonce},
    79  		"X-Request-Timestamp": {fmt.Sprintf("%d", timestamp)},
    80  		"User-Agent":          {fmt.Sprintf("datica-cli-%s %s %s %s", version, runtime.GOOS, config.ArchString(), userID)},
    81  	}
    82  }
    83  
    84  // ConvertResp takes in a resp from one of the httpclient methods and
    85  // checks if it is a successful request. If not, it is parsed as an error object
    86  // and returned as an error. Otherwise it will be marshalled into the requested
    87  // interface. ALWAYS PASS A POINTER INTO THIS METHOD. If you don't pass a struct
    88  // pointer your original object will be nil or an empty struct.
    89  func (m *TLSHTTPManager) ConvertResp(b []byte, statusCode int, s interface{}) error {
    90  	logrus.Debugf("%d resp: %s", statusCode, string(b))
    91  	if m.isError(statusCode) {
    92  		return m.convertError(b, statusCode)
    93  	}
    94  	if b == nil || len(b) == 0 || s == nil {
    95  		return nil
    96  	}
    97  	return json.Unmarshal(b, s)
    98  }
    99  
   100  // ConvertError takes in a response from one of the httpclient methods and converts it
   101  // to a usable error object.
   102  func (m *TLSHTTPManager) ConvertError(b []byte, statusCode int) (*models.Error, error) {
   103  	if !m.isError(statusCode) {
   104  		return nil, errors.New("tried to convert a non-error response into an error")
   105  	}
   106  	var resp models.Error
   107  	err := json.Unmarshal(b, &resp)
   108  	return &resp, err
   109  }
   110  
   111  // isError checks if an HTTP response code is outside of the "OK" range.
   112  func (m *TLSHTTPManager) isError(statusCode int) bool {
   113  	return statusCode < 200 || statusCode >= 300
   114  }
   115  
   116  // convertError attempts to convert a response into a usable error object.
   117  func (m *TLSHTTPManager) convertError(b []byte, statusCode int) error {
   118  	msg := fmt.Sprintf("(%d)", statusCode)
   119  	if b != nil && len(b) > 0 {
   120  		var errs models.Error
   121  		unmarshalErr := json.Unmarshal(b, &errs)
   122  		if unmarshalErr == nil && errs.Title != "" && errs.Description != "" {
   123  			msg = fmt.Sprintf("(%d) %s: %s", errs.Code, errs.Title, errs.Description)
   124  		} else {
   125  			var reportedErr models.ReportedError
   126  			unmarshalErr = json.Unmarshal(b, &reportedErr)
   127  			if unmarshalErr == nil && reportedErr.Message != "" {
   128  				msg = fmt.Sprintf("(%d) %s", reportedErr.Code, reportedErr.Message)
   129  			} else {
   130  				msg = fmt.Sprintf("(%d) %s", statusCode, string(b))
   131  			}
   132  		}
   133  	}
   134  	return errors.New(msg)
   135  }
   136  
   137  // Get performs a GET request
   138  func (m *TLSHTTPManager) Get(body []byte, url string, headers map[string][]string) ([]byte, int, error) {
   139  	reader := bytes.NewReader(body)
   140  	return m.makeRequest("GET", url, reader, headers)
   141  }
   142  
   143  // Post performs a POST request
   144  func (m *TLSHTTPManager) Post(body []byte, url string, headers map[string][]string) ([]byte, int, error) {
   145  	reader := bytes.NewReader(body)
   146  	return m.makeRequest("POST", url, reader, headers)
   147  }
   148  
   149  // PostFile uploads a file with a POST
   150  func (m *TLSHTTPManager) PostFile(filepath string, url string, headers map[string][]string) ([]byte, int, error) {
   151  	return m.uploadFile("POST", filepath, url, headers)
   152  }
   153  
   154  // PutFile uploads a file with a PUT
   155  func (m *TLSHTTPManager) PutFile(filepath string, url string, headers map[string][]string) ([]byte, int, error) {
   156  	return m.uploadFile("PUT", filepath, url, headers)
   157  }
   158  
   159  func (m *TLSHTTPManager) uploadFile(method, filepath, url string, headers map[string][]string) ([]byte, int, error) {
   160  	logrus.Debugf("%s %s", method, url)
   161  	logrus.Debugf("%+v", headers)
   162  	logrus.Debugf("%s", filepath)
   163  	file, err := os.Open(filepath)
   164  	defer file.Close()
   165  	if err != nil {
   166  		return nil, 0, err
   167  	}
   168  	info, _ := file.Stat()
   169  	req, _ := http.NewRequest(method, url, file)
   170  	req.ContentLength = info.Size()
   171  
   172  	resp, err := m.client.Do(req)
   173  	if err != nil {
   174  		return nil, 0, err
   175  	}
   176  	defer resp.Body.Close()
   177  	respBody, _ := ioutil.ReadAll(resp.Body)
   178  	return respBody, resp.StatusCode, nil
   179  }
   180  
   181  // Put performs a PUT request
   182  func (m *TLSHTTPManager) Put(body []byte, url string, headers map[string][]string) ([]byte, int, error) {
   183  	reader := bytes.NewReader(body)
   184  	return m.makeRequest("PUT", url, reader, headers)
   185  }
   186  
   187  // Delete performs a DELETE request
   188  func (m *TLSHTTPManager) Delete(body []byte, url string, headers map[string][]string) ([]byte, int, error) {
   189  	reader := bytes.NewReader(body)
   190  	return m.makeRequest("DELETE", url, reader, headers)
   191  }
   192  
   193  // MakeRequest is a generic HTTP runner that performs a request and returns
   194  // the result body as a byte array. It's up to the caller to transform them
   195  // into an object.
   196  func (m *TLSHTTPManager) makeRequest(method string, url string, body io.Reader, headers map[string][]string) ([]byte, int, error) {
   197  	logrus.Debugf("%s %s", method, url)
   198  	logrus.Debugf("%+v", headers)
   199  	logrus.Debugf("%s", body)
   200  	req, _ := http.NewRequest(method, url, body)
   201  	req.Header = headers
   202  
   203  	resp, err := m.client.Do(req)
   204  	if err != nil {
   205  		return nil, 0, err
   206  	}
   207  	defer resp.Body.Close()
   208  	respBody, _ := ioutil.ReadAll(resp.Body)
   209  	if resp.StatusCode == 412 {
   210  		updater.AutoUpdater.ForcedUpgrade()
   211  		return nil, 0, fmt.Errorf("A required update has been applied. Please re-run this command.")
   212  	}
   213  	return respBody, resp.StatusCode, nil
   214  }