github.com/pivotal-cf/go-pivnet/v6@v6.0.2/pivnet.go (about)

     1  package pivnet
     2  
     3  import (
     4  	"crypto/tls"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"log"
    10  	"net/http"
    11  	"net/http/httputil"
    12  	"net/url"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/pivotal-cf/go-pivnet/v6/download"
    17  	"github.com/pivotal-cf/go-pivnet/v6/logger"
    18  )
    19  
    20  const (
    21  	DefaultHost         = "https://network.pivotal.io"
    22  	apiVersion          = "/api/v2"
    23  	concurrentDownloads = 10
    24  )
    25  
    26  type Client struct {
    27  	baseURL       string
    28  	token         AccessTokenService
    29  	userAgent     string
    30  	logger        logger.Logger
    31  	usingUAAToken bool
    32  
    33  	HTTP *http.Client
    34  
    35  	downloader download.Client
    36  
    37  	Auth                 *AuthService
    38  	EULA                 *EULAsService
    39  	ProductFiles         *ProductFilesService
    40  	ArtifactReferences   *ArtifactReferencesService
    41  	FederationToken      *FederationTokenService
    42  	FileGroups           *FileGroupsService
    43  	Releases             *ReleasesService
    44  	Products             *ProductsService
    45  	UserGroups           *UserGroupsService
    46  	SubscriptionGroups   *SubscriptionGroupsService
    47  	ReleaseTypes         *ReleaseTypesService
    48  	ReleaseDependencies  *ReleaseDependenciesService
    49  	DependencySpecifiers *DependencySpecifiersService
    50  	ReleaseUpgradePaths   *ReleaseUpgradePathsService
    51  	UpgradePathSpecifiers *UpgradePathSpecifiersService
    52  	PivnetVersions        *PivnetVersionsService
    53  }
    54  
    55  type AccessTokenOrLegacyToken struct {
    56  	host              string
    57  	refreshToken      string
    58  	skipSSLValidation bool
    59  	userAgent         string
    60  }
    61  
    62  type QueryParameter struct {
    63  	Key string
    64  	Value string
    65  }
    66  
    67  func (o AccessTokenOrLegacyToken) AccessToken() (string, error) {
    68  	const legacyAPITokenLength = 20
    69  	if len(o.refreshToken) > legacyAPITokenLength {
    70  		baseURL := fmt.Sprintf("%s%s", o.host, apiVersion)
    71  		tokenFetcher := NewTokenFetcher(baseURL, o.refreshToken, o.skipSSLValidation, o.userAgent)
    72  
    73  		accessToken, err := tokenFetcher.GetToken()
    74  		if err != nil {
    75  			log.Panicf("Exiting with error: %s", err)
    76  			return "", err
    77  		}
    78  		return accessToken, nil
    79  	} else {
    80  		return o.refreshToken, nil
    81  	}
    82  }
    83  
    84  func AuthorizationHeader(accessToken string) (string, error) {
    85  	const legacyAPITokenLength = 20
    86  	if len(accessToken) > legacyAPITokenLength {
    87  		return fmt.Sprintf("Bearer %s", accessToken), nil
    88  	} else {
    89  		return fmt.Sprintf("Token %s", accessToken), nil
    90  	}
    91  }
    92  
    93  type ClientConfig struct {
    94  	Host              string
    95  	UserAgent         string
    96  	SkipSSLValidation bool
    97  }
    98  
// AccessTokenService supplies the token that CreateRequest turns into an
// Authorization header (via AuthorizationHeader) for every endpoint except
// "/versions".
//
//go:generate counterfeiter . AccessTokenService
type AccessTokenService interface {
	// AccessToken returns the token to send, or an error if none is available.
	AccessToken() (string, error)
}
   103  
   104  func NewAccessTokenOrLegacyToken(token string, host string, skipSSLValidation bool, userAgentOptional ...string) AccessTokenOrLegacyToken {
   105  	var userAgent = ""
   106  	if len(userAgentOptional) > 0 {
   107  		userAgent = userAgentOptional[0]
   108  	}
   109  	return AccessTokenOrLegacyToken{
   110  		refreshToken:      token,
   111  		host:              host,
   112  		skipSSLValidation: skipSSLValidation,
   113  		userAgent:         userAgent,
   114  	}
   115  }
   116  
   117  func NewClient(
   118  	token AccessTokenService,
   119  	config ClientConfig,
   120  	logger logger.Logger,
   121  ) Client {
   122  	baseURL := fmt.Sprintf("%s%s", config.Host, apiVersion)
   123  
   124  	httpClient := &http.Client{
   125  		Timeout: 60 * time.Second,
   126  		Transport: &http.Transport{
   127  			TLSClientConfig: &tls.Config{
   128  				InsecureSkipVerify: config.SkipSSLValidation,
   129  			},
   130  			Proxy: http.ProxyFromEnvironment,
   131  		},
   132  	}
   133  
   134  	downloadClient := &http.Client{
   135  		Timeout: 0,
   136  		Transport: &http.Transport{
   137  			TLSClientConfig: &tls.Config{
   138  				InsecureSkipVerify: config.SkipSSLValidation,
   139  			},
   140  			Proxy: http.ProxyFromEnvironment,
   141  		},
   142  	}
   143  
   144  	ranger := download.NewRanger(concurrentDownloads)
   145  	downloader := download.Client{
   146  		HTTPClient: downloadClient,
   147  		Ranger:     ranger,
   148  		Logger:     logger,
   149  		Timeout:    5 * time.Second,
   150  	}
   151  
   152  	client := Client{
   153  		baseURL:    baseURL,
   154  		token:      token,
   155  		userAgent:  config.UserAgent,
   156  		logger:     logger,
   157  		downloader: downloader,
   158  		HTTP:       httpClient,
   159  	}
   160  
   161  	client.Auth = &AuthService{client: client}
   162  	client.EULA = &EULAsService{client: client}
   163  	client.ProductFiles = &ProductFilesService{client: client}
   164  	client.ArtifactReferences = &ArtifactReferencesService{client: client}
   165  	client.FederationToken = &FederationTokenService{client: client}
   166  	client.FileGroups = &FileGroupsService{client: client}
   167  	client.Releases = &ReleasesService{client: client, l: logger}
   168  	client.Products = &ProductsService{client: client, l: logger}
   169  	client.UserGroups = &UserGroupsService{client: client}
   170  	client.SubscriptionGroups = &SubscriptionGroupsService{client: client}
   171  	client.ReleaseTypes = &ReleaseTypesService{client: client}
   172  	client.ReleaseDependencies = &ReleaseDependenciesService{client: client}
   173  	client.DependencySpecifiers = &DependencySpecifiersService{client: client}
   174  	client.ReleaseUpgradePaths = &ReleaseUpgradePathsService{client: client}
   175  	client.UpgradePathSpecifiers = &UpgradePathSpecifiersService{client: client}
   176  	client.PivnetVersions = &PivnetVersionsService{client: client}
   177  
   178  	return client
   179  }
   180  
   181  func (c Client) CreateRequest(
   182  	requestType string,
   183  	endpoint string,
   184  	body io.Reader,
   185  ) (*http.Request, error) {
   186  	u, err := url.Parse(c.baseURL)
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  
   191  	endpoint = c.stripHostPrefix(endpoint)
   192  
   193  	u.Path = u.Path + endpoint
   194  
   195  	req, err := http.NewRequest(requestType, u.String(), body)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  
   200  	if !isVersionsEndpoint(endpoint) {
   201  		accessToken, err := c.token.AccessToken()
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  
   206  		authorizationHeader, err := AuthorizationHeader(accessToken)
   207  		if err != nil {
   208  			return nil, fmt.Errorf("could not create authorization header: %s", err)
   209  		}
   210  
   211  		req.Header.Add("Authorization", authorizationHeader)
   212  	}
   213  
   214  	req.Header.Add("Content-Type", "application/json")
   215  	req.Header.Add("User-Agent", c.userAgent)
   216  
   217  	return req, nil
   218  }
   219  
   220  func (c Client) MakeRequest(
   221  	requestType string,
   222  	endpoint string,
   223  	expectedStatusCode int,
   224  	body io.Reader,
   225  ) (*http.Response, error) {
   226  	req, err := c.CreateRequest(requestType, endpoint, body)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	reqBytes, err := httputil.DumpRequestOut(req, true)
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	c.logger.Debug("Making request", logger.Data{"request": string(reqBytes)})
   237  
   238  	resp, err := c.HTTP.Do(req)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	c.logger.Debug("Response status code", logger.Data{"status code": resp.StatusCode})
   244  	c.logger.Debug("Response headers", logger.Data{"headers": resp.Header})
   245  
   246  	if expectedStatusCode > 0 && resp.StatusCode != expectedStatusCode {
   247  		return nil, c.handleUnexpectedResponse(resp)
   248  	}
   249  
   250  	return resp, nil
   251  }
   252  
   253  func (c Client) MakeRequestWithParams(
   254  	requestType string,
   255  	endpoint string,
   256  	expectedStatusCode int,
   257  	params []QueryParameter,
   258  	body io.Reader,
   259  ) (*http.Response, error) {
   260  	req, err := c.CreateRequest(requestType, endpoint, body)
   261  	if err != nil {
   262  		return nil, err
   263  	}
   264  
   265  	q := req.URL.Query()
   266  	for _, param := range params {
   267  		q.Add(param.Key, param.Value)
   268  	}
   269  	req.URL.RawQuery = q.Encode()
   270  
   271  	reqBytes, err := httputil.DumpRequestOut(req, true)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  
   276  	c.logger.Debug("Making request", logger.Data{"request": string(reqBytes)})
   277  
   278  	resp, err := c.HTTP.Do(req)
   279  	if err != nil {
   280  		return nil, err
   281  	}
   282  
   283  	c.logger.Debug("Response status code", logger.Data{"status code": resp.StatusCode})
   284  	c.logger.Debug("Response headers", logger.Data{"headers": resp.Header})
   285  
   286  	if expectedStatusCode > 0 && resp.StatusCode != expectedStatusCode {
   287  		return nil, c.handleUnexpectedResponse(resp)
   288  	}
   289  
   290  	return resp, nil
   291  }
   292  
   293  func (c Client) stripHostPrefix(downloadLink string) string {
   294  	if strings.HasPrefix(downloadLink, apiVersion) {
   295  		return downloadLink
   296  	}
   297  	sp := strings.Split(downloadLink, apiVersion)
   298  	return sp[len(sp)-1]
   299  }
   300  
   301  func (c Client) handleUnexpectedResponse(resp *http.Response) error {
   302  	var pErr pivnetErr
   303  
   304  	b, err := ioutil.ReadAll(resp.Body)
   305  	if err != nil {
   306  		return err
   307  	}
   308  
   309  	if resp.StatusCode == http.StatusTooManyRequests {
   310  		return newErrTooManyRequests()
   311  	}
   312  
   313  	// We have to handle 500 differently because it has a different structure
   314  	if resp.StatusCode == http.StatusInternalServerError {
   315  		var internalServerError pivnetInternalServerErr
   316  		err = json.Unmarshal(b, &internalServerError)
   317  		if err != nil {
   318  			return err
   319  		}
   320  
   321  		pErr = pivnetErr{
   322  			Message: internalServerError.Error,
   323  		}
   324  	} else {
   325  		err = json.Unmarshal(b, &pErr)
   326  		if err != nil {
   327  			return fmt.Errorf("could not parse json [%q] \n%s", b, err)
   328  		}
   329  	}
   330  
   331  	switch resp.StatusCode {
   332  	case http.StatusUnauthorized:
   333  		return newErrUnauthorized(pErr.Message)
   334  	case http.StatusNotFound:
   335  		return newErrNotFound(pErr.Message)
   336  	case http.StatusUnavailableForLegalReasons:
   337  		return newErrUnavailableForLegalReasons(pErr.Message)
   338  	default:
   339  		return ErrPivnetOther{
   340  			ResponseCode: resp.StatusCode,
   341  			Message:      pErr.Message,
   342  			Errors:       pErr.Errors,
   343  		}
   344  	}
   345  }
   346  
   347  func isVersionsEndpoint(endpoint string) bool {
   348  	return endpoint == "/versions"
   349  }