github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/fanal/image/registry/azure/azure.go (about) 1 package azure 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "os" 8 "strings" 9 10 "github.com/Azure/azure-sdk-for-go/profiles/preview/preview/containerregistry/runtime/containerregistry" 11 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" 12 "github.com/Azure/azure-sdk-for-go/sdk/azidentity" 13 "golang.org/x/xerrors" 14 15 "github.com/devseccon/trivy/pkg/fanal/types" 16 ) 17 18 type Registry struct { 19 domain string 20 } 21 22 const ( 23 azureURL = "azurecr.io" 24 scope = "https://management.azure.com/.default" 25 scheme = "https" 26 ) 27 28 func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) error { 29 if !strings.HasSuffix(domain, azureURL) { 30 return xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern) 31 } 32 r.domain = domain 33 return nil 34 } 35 36 func (r *Registry) GetCredential(ctx context.Context) (string, string, error) { 37 cred, err := azidentity.NewDefaultAzureCredential(nil) 38 if err != nil { 39 return "", "", xerrors.Errorf("unable to generate acr credential error: %w", err) 40 } 41 aadToken, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{scope}}) 42 if err != nil { 43 return "", "", xerrors.Errorf("unable to get an access token: %w", err) 44 } 45 rt, err := refreshToken(ctx, aadToken.Token, r.domain) 46 if err != nil { 47 return "", "", xerrors.Errorf("unable to refresh token: %w", err) 48 } 49 return "00000000-0000-0000-0000-000000000000", *rt.RefreshToken, err 50 } 51 52 func refreshToken(ctx context.Context, accessToken, domain string) (containerregistry.RefreshToken, error) { 53 tenantID := os.Getenv("AZURE_TENANT_ID") 54 if tenantID == "" { 55 return containerregistry.RefreshToken{}, errors.New("missing environment variable AZURE_TENANT_ID") 56 } 57 repoClient := containerregistry.NewRefreshTokensClient(fmt.Sprintf("%s://%s", scheme, domain)) 58 return repoClient.GetFromExchange(ctx, "access_token", domain, tenantID, "", accessToken) 59 }