github.com/pivotal-cf/go-pivnet/v6@v6.0.2/pivnet.go (about) 1 package pivnet 2 3 import ( 4 "crypto/tls" 5 "encoding/json" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "log" 10 "net/http" 11 "net/http/httputil" 12 "net/url" 13 "strings" 14 "time" 15 16 "github.com/pivotal-cf/go-pivnet/v6/download" 17 "github.com/pivotal-cf/go-pivnet/v6/logger" 18 ) 19 20 const ( 21 DefaultHost = "https://network.pivotal.io" 22 apiVersion = "/api/v2" 23 concurrentDownloads = 10 24 ) 25 26 type Client struct { 27 baseURL string 28 token AccessTokenService 29 userAgent string 30 logger logger.Logger 31 usingUAAToken bool 32 33 HTTP *http.Client 34 35 downloader download.Client 36 37 Auth *AuthService 38 EULA *EULAsService 39 ProductFiles *ProductFilesService 40 ArtifactReferences *ArtifactReferencesService 41 FederationToken *FederationTokenService 42 FileGroups *FileGroupsService 43 Releases *ReleasesService 44 Products *ProductsService 45 UserGroups *UserGroupsService 46 SubscriptionGroups *SubscriptionGroupsService 47 ReleaseTypes *ReleaseTypesService 48 ReleaseDependencies *ReleaseDependenciesService 49 DependencySpecifiers *DependencySpecifiersService 50 ReleaseUpgradePaths *ReleaseUpgradePathsService 51 UpgradePathSpecifiers *UpgradePathSpecifiersService 52 PivnetVersions *PivnetVersionsService 53 } 54 55 type AccessTokenOrLegacyToken struct { 56 host string 57 refreshToken string 58 skipSSLValidation bool 59 userAgent string 60 } 61 62 type QueryParameter struct { 63 Key string 64 Value string 65 } 66 67 func (o AccessTokenOrLegacyToken) AccessToken() (string, error) { 68 const legacyAPITokenLength = 20 69 if len(o.refreshToken) > legacyAPITokenLength { 70 baseURL := fmt.Sprintf("%s%s", o.host, apiVersion) 71 tokenFetcher := NewTokenFetcher(baseURL, o.refreshToken, o.skipSSLValidation, o.userAgent) 72 73 accessToken, err := tokenFetcher.GetToken() 74 if err != nil { 75 log.Panicf("Exiting with error: %s", err) 76 return "", err 77 } 78 return accessToken, nil 79 } else { 80 return o.refreshToken, nil 81 } 82 } 83 84 func AuthorizationHeader(accessToken string) (string, error) { 85 const legacyAPITokenLength = 20 86 if len(accessToken) > legacyAPITokenLength { 87 return fmt.Sprintf("Bearer %s", accessToken), nil 88 } else { 89 return fmt.Sprintf("Token %s", accessToken), nil 90 } 91 } 92 93 type ClientConfig struct { 94 Host string 95 UserAgent string 96 SkipSSLValidation bool 97 } 98 99 //go:generate counterfeiter . AccessTokenService 100 type AccessTokenService interface { 101 AccessToken() (string, error) 102 } 103 104 func NewAccessTokenOrLegacyToken(token string, host string, skipSSLValidation bool, userAgentOptional ...string) AccessTokenOrLegacyToken { 105 var userAgent = "" 106 if len(userAgentOptional) > 0 { 107 userAgent = userAgentOptional[0] 108 } 109 return AccessTokenOrLegacyToken{ 110 refreshToken: token, 111 host: host, 112 skipSSLValidation: skipSSLValidation, 113 userAgent: userAgent, 114 } 115 } 116 117 func NewClient( 118 token AccessTokenService, 119 config ClientConfig, 120 logger logger.Logger, 121 ) Client { 122 baseURL := fmt.Sprintf("%s%s", config.Host, apiVersion) 123 124 httpClient := &http.Client{ 125 Timeout: 60 * time.Second, 126 Transport: &http.Transport{ 127 TLSClientConfig: &tls.Config{ 128 InsecureSkipVerify: config.SkipSSLValidation, 129 }, 130 Proxy: http.ProxyFromEnvironment, 131 }, 132 } 133 134 downloadClient := &http.Client{ 135 Timeout: 0, 136 Transport: &http.Transport{ 137 TLSClientConfig: &tls.Config{ 138 InsecureSkipVerify: config.SkipSSLValidation, 139 }, 140 Proxy: http.ProxyFromEnvironment, 141 }, 142 } 143 144 ranger := download.NewRanger(concurrentDownloads) 145 downloader := download.Client{ 146 HTTPClient: downloadClient, 147 Ranger: ranger, 148 Logger: logger, 149 Timeout: 5 * time.Second, 150 } 151 152 client := Client{ 153 baseURL: baseURL, 154 token: token, 155 userAgent: config.UserAgent, 156 logger: logger, 157 downloader: downloader, 158 HTTP: httpClient, 159 } 160 161 client.Auth = &AuthService{client: client} 162 client.EULA = &EULAsService{client: client} 163 client.ProductFiles = &ProductFilesService{client: client} 164 client.ArtifactReferences = &ArtifactReferencesService{client: client} 165 client.FederationToken = &FederationTokenService{client: client} 166 client.FileGroups = &FileGroupsService{client: client} 167 client.Releases = &ReleasesService{client: client, l: logger} 168 client.Products = &ProductsService{client: client, l: logger} 169 client.UserGroups = &UserGroupsService{client: client} 170 client.SubscriptionGroups = &SubscriptionGroupsService{client: client} 171 client.ReleaseTypes = &ReleaseTypesService{client: client} 172 client.ReleaseDependencies = &ReleaseDependenciesService{client: client} 173 client.DependencySpecifiers = &DependencySpecifiersService{client: client} 174 client.ReleaseUpgradePaths = &ReleaseUpgradePathsService{client: client} 175 client.UpgradePathSpecifiers = &UpgradePathSpecifiersService{client: client} 176 client.PivnetVersions = &PivnetVersionsService{client: client} 177 178 return client 179 } 180 181 func (c Client) CreateRequest( 182 requestType string, 183 endpoint string, 184 body io.Reader, 185 ) (*http.Request, error) { 186 u, err := url.Parse(c.baseURL) 187 if err != nil { 188 return nil, err 189 } 190 191 endpoint = c.stripHostPrefix(endpoint) 192 193 u.Path = u.Path + endpoint 194 195 req, err := http.NewRequest(requestType, u.String(), body) 196 if err != nil { 197 return nil, err 198 } 199 200 if !isVersionsEndpoint(endpoint) { 201 accessToken, err := c.token.AccessToken() 202 if err != nil { 203 return nil, err 204 } 205 206 authorizationHeader, err := AuthorizationHeader(accessToken) 207 if err != nil { 208 return nil, fmt.Errorf("could not create authorization header: %s", err) 209 } 210 211 req.Header.Add("Authorization", authorizationHeader) 212 } 213 214 req.Header.Add("Content-Type", "application/json") 215 req.Header.Add("User-Agent", c.userAgent) 216 217 return req, nil 218 } 219 220 func (c Client) MakeRequest( 221 requestType string, 222 endpoint string, 223 expectedStatusCode int, 224 body io.Reader, 225 ) (*http.Response, error) { 226 req, err := c.CreateRequest(requestType, endpoint, body) 227 if err != nil { 228 return nil, err 229 } 230 231 reqBytes, err := httputil.DumpRequestOut(req, true) 232 if err != nil { 233 return nil, err 234 } 235 236 c.logger.Debug("Making request", logger.Data{"request": string(reqBytes)}) 237 238 resp, err := c.HTTP.Do(req) 239 if err != nil { 240 return nil, err 241 } 242 243 c.logger.Debug("Response status code", logger.Data{"status code": resp.StatusCode}) 244 c.logger.Debug("Response headers", logger.Data{"headers": resp.Header}) 245 246 if expectedStatusCode > 0 && resp.StatusCode != expectedStatusCode { 247 return nil, c.handleUnexpectedResponse(resp) 248 } 249 250 return resp, nil 251 } 252 253 func (c Client) MakeRequestWithParams( 254 requestType string, 255 endpoint string, 256 expectedStatusCode int, 257 params []QueryParameter, 258 body io.Reader, 259 ) (*http.Response, error) { 260 req, err := c.CreateRequest(requestType, endpoint, body) 261 if err != nil { 262 return nil, err 263 } 264 265 q := req.URL.Query() 266 for _, param := range params { 267 q.Add(param.Key, param.Value) 268 } 269 req.URL.RawQuery = q.Encode() 270 271 reqBytes, err := httputil.DumpRequestOut(req, true) 272 if err != nil { 273 return nil, err 274 } 275 276 c.logger.Debug("Making request", logger.Data{"request": string(reqBytes)}) 277 278 resp, err := c.HTTP.Do(req) 279 if err != nil { 280 return nil, err 281 } 282 283 c.logger.Debug("Response status code", logger.Data{"status code": resp.StatusCode}) 284 c.logger.Debug("Response headers", logger.Data{"headers": resp.Header}) 285 286 if expectedStatusCode > 0 && resp.StatusCode != expectedStatusCode { 287 return nil, c.handleUnexpectedResponse(resp) 288 } 289 290 return resp, nil 291 } 292 293 func (c Client) stripHostPrefix(downloadLink string) string { 294 if strings.HasPrefix(downloadLink, apiVersion) { 295 return downloadLink 296 } 297 sp := strings.Split(downloadLink, apiVersion) 298 return sp[len(sp)-1] 299 } 300 301 func (c Client) handleUnexpectedResponse(resp *http.Response) error { 302 var pErr pivnetErr 303 304 b, err := ioutil.ReadAll(resp.Body) 305 if err != nil { 306 return err 307 } 308 309 if resp.StatusCode == http.StatusTooManyRequests { 310 return newErrTooManyRequests() 311 } 312 313 // We have to handle 500 differently because it has a different structure 314 if resp.StatusCode == http.StatusInternalServerError { 315 var internalServerError pivnetInternalServerErr 316 err = json.Unmarshal(b, &internalServerError) 317 if err != nil { 318 return err 319 } 320 321 pErr = pivnetErr{ 322 Message: internalServerError.Error, 323 } 324 } else { 325 err = json.Unmarshal(b, &pErr) 326 if err != nil { 327 return fmt.Errorf("could not parse json [%q] \n%s", b, err) 328 } 329 } 330 331 switch resp.StatusCode { 332 case http.StatusUnauthorized: 333 return newErrUnauthorized(pErr.Message) 334 case http.StatusNotFound: 335 return newErrNotFound(pErr.Message) 336 case http.StatusUnavailableForLegalReasons: 337 return newErrUnavailableForLegalReasons(pErr.Message) 338 default: 339 return ErrPivnetOther{ 340 ResponseCode: resp.StatusCode, 341 Message: pErr.Message, 342 Errors: pErr.Errors, 343 } 344 } 345 } 346 347 func isVersionsEndpoint(endpoint string) bool { 348 return endpoint == "/versions" 349 }