github.com/google/osv-scalibr@v0.4.1/guidedremediation/guidedremediation_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package guidedremediation_test
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/json"
    20  	"os"
    21  	"path/filepath"
    22  	"runtime"
    23  	"testing"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"github.com/google/go-cmp/cmp/cmpopts"
    27  	"github.com/google/osv-scalibr/clients/clienttest"
    28  	"github.com/google/osv-scalibr/clients/datasource"
    29  	"github.com/google/osv-scalibr/guidedremediation"
    30  	"github.com/google/osv-scalibr/guidedremediation/internal/vulnenrichertest"
    31  	"github.com/google/osv-scalibr/guidedremediation/options"
    32  	"github.com/google/osv-scalibr/guidedremediation/result"
    33  	"github.com/google/osv-scalibr/guidedremediation/strategy"
    34  	"github.com/google/osv-scalibr/guidedremediation/upgrade"
    35  )
    36  
    37  func TestFixOverride(t *testing.T) {
    38  	for _, tt := range []struct {
    39  		name              string
    40  		universeDir       string
    41  		manifest          string
    42  		wantManifestPath  string
    43  		wantResultPath    string
    44  		remOpts           options.RemediationOptions
    45  		maxUpgrades       int
    46  		noIntroduce       bool
    47  		noMavenNewDepMgmt bool
    48  	}{
    49  		{
    50  			name:             "basic",
    51  			universeDir:      "testdata/maven",
    52  			manifest:         "testdata/maven/basic/pom.xml",
    53  			wantManifestPath: "testdata/maven/basic/want.pom.xml",
    54  			wantResultPath:   "testdata/maven/basic/result.json",
    55  			remOpts:          options.DefaultRemediationOptions(),
    56  		},
    57  		{
    58  			name:             "patch choice",
    59  			universeDir:      "testdata/maven",
    60  			manifest:         "testdata/maven/patchchoice/pom.xml",
    61  			wantManifestPath: "testdata/maven/patchchoice/want.pom.xml",
    62  			wantResultPath:   "testdata/maven/patchchoice/result.json",
    63  			remOpts:          options.DefaultRemediationOptions(),
    64  		},
    65  		{
    66  			name:             "max upgrades",
    67  			universeDir:      "testdata/maven",
    68  			manifest:         "testdata/maven/patchchoice/pom.xml",
    69  			wantManifestPath: "testdata/maven/maxupgrades/want.pom.xml",
    70  			wantResultPath:   "testdata/maven/maxupgrades/result.json",
    71  			remOpts:          options.DefaultRemediationOptions(),
    72  			maxUpgrades:      2,
    73  		},
    74  		{
    75  			name:        "no introduce",
    76  			universeDir: "testdata/maven",
    77  			// Using same testdata as maxUpgrades because the result happens to be the same.
    78  			manifest:         "testdata/maven/patchchoice/pom.xml",
    79  			wantManifestPath: "testdata/maven/maxupgrades/want.pom.xml",
    80  			wantResultPath:   "testdata/maven/maxupgrades/result.json",
    81  			remOpts:          options.DefaultRemediationOptions(),
    82  			noIntroduce:      true,
    83  		},
    84  		{
    85  			name:              "no new dependency management",
    86  			universeDir:       "testdata/maven",
    87  			manifest:          "testdata/maven/patchchoice/pom.xml",
    88  			wantManifestPath:  "testdata/maven/nodepmgmt/want.pom.xml",
    89  			wantResultPath:    "testdata/maven/nodepmgmt/result.json",
    90  			remOpts:           options.DefaultRemediationOptions(),
    91  			noMavenNewDepMgmt: true,
    92  		},
    93  	} {
    94  		t.Run(tt.name, func(t *testing.T) {
    95  			// mavenClient is not used in the test, but is required to be non-nil for pom.xml manifests.
    96  			mavenClient, _ := datasource.NewDefaultMavenRegistryAPIClient(t.Context(), "")
    97  			client := clienttest.NewMockResolutionClient(t, filepath.Join(tt.universeDir, "universe.yaml"))
    98  			enricher := vulnenrichertest.NewMockVulnerabilityEnricher(t, filepath.Join(tt.universeDir, "vulnerabilities.json"))
    99  
   100  			tmpDir := t.TempDir()
   101  			manifestPath := filepath.Join(tmpDir, "pom.xml")
   102  			data, err := os.ReadFile(tt.manifest)
   103  			if err != nil {
   104  				t.Fatalf("failed reading manifest for copy: %v", err)
   105  			}
   106  			if err := os.WriteFile(manifestPath, data, 0644); err != nil {
   107  				t.Fatalf("failed copying manifest: %v", err)
   108  			}
   109  
   110  			opts := options.FixVulnsOptions{
   111  				Manifest:           manifestPath,
   112  				Strategy:           strategy.StrategyOverride,
   113  				VulnEnricher:       enricher,
   114  				ResolveClient:      client,
   115  				MavenClient:        mavenClient,
   116  				RemediationOptions: tt.remOpts,
   117  				MaxUpgrades:        tt.maxUpgrades,
   118  				NoIntroduce:        tt.noIntroduce,
   119  				NoMavenNewDepMgmt:  tt.noMavenNewDepMgmt,
   120  			}
   121  
   122  			gotRes, err := guidedremediation.FixVulns(opts)
   123  			if err != nil {
   124  				t.Fatalf("error fixing vulns: %v", err)
   125  			}
   126  			var wantRes result.Result
   127  			f, err := os.Open(tt.wantResultPath)
   128  			if err != nil {
   129  				t.Fatalf("failed opening result file: %v", err)
   130  			}
   131  			defer f.Close()
   132  			if err := json.NewDecoder(f).Decode(&wantRes); err != nil {
   133  				t.Fatalf("failed decoding result file: %v", err)
   134  			}
   135  			diffOpts := []cmp.Option{
   136  				cmpopts.IgnoreFields(result.Result{}, "Path"),
   137  				cmpopts.IgnoreFields(result.PackageUpdate{}, "Type"),
   138  			}
   139  			if diff := cmp.Diff(wantRes, gotRes, diffOpts...); diff != "" {
   140  				t.Errorf("FixVulns() result mismatch (-want +got):\n%s", diff)
   141  			}
   142  
   143  			wantManifest, err := os.ReadFile(tt.wantManifestPath)
   144  			if err != nil {
   145  				t.Fatalf("failed reading want manifest for comparison: %v", err)
   146  			}
   147  			gotManifest, err := os.ReadFile(manifestPath)
   148  			if err != nil {
   149  				t.Fatalf("failed reading got manifest for comparison: %v", err)
   150  			}
   151  			if runtime.GOOS == "windows" {
   152  				wantManifest = bytes.ReplaceAll(wantManifest, []byte("\r\n"), []byte("\n"))
   153  				gotManifest = bytes.ReplaceAll(gotManifest, []byte("\r\n"), []byte("\n"))
   154  			}
   155  
   156  			if diff := cmp.Diff(wantManifest, gotManifest); diff != "" {
   157  				t.Errorf("FixVulns() manifest mismatch (-want +got):\n%s", diff)
   158  			}
   159  		})
   160  	}
   161  }
   162  
   163  func TestFixRelax(t *testing.T) {
   164  	for _, tt := range []struct {
   165  		name             string
   166  		universeDir      string
   167  		manifest         string
   168  		lockfile         string
   169  		wantManifestPath string
   170  		wantResultPath   string
   171  		remOpts          options.RemediationOptions
   172  		maxUpgrades      int
   173  		noIntroduce      bool
   174  	}{
   175  		{
   176  			name:             "npm basic",
   177  			universeDir:      "testdata/npm",
   178  			manifest:         "testdata/npm/basicrelax/package.json",
   179  			lockfile:         "testdata/npm/basicrelax/package-lock.json",
   180  			wantManifestPath: "testdata/npm/basicrelax/want.package.json",
   181  			wantResultPath:   "testdata/npm/basicrelax/result.json",
   182  			remOpts:          options.DefaultRemediationOptions(),
   183  		},
   184  		{
   185  			name:             "python requirements",
   186  			universeDir:      "testdata/python",
   187  			manifest:         "testdata/python/relax/requirements/requirements.txt",
   188  			wantManifestPath: "testdata/python/relax/requirements/want.requirements.txt",
   189  			wantResultPath:   "testdata/python/relax/requirements/result.json",
   190  			remOpts:          options.DefaultRemediationOptions(),
   191  		},
   192  		{
   193  			name:             "python poetry",
   194  			universeDir:      "testdata/python",
   195  			manifest:         "testdata/python/relax/poetry/pyproject.toml",
   196  			wantManifestPath: "testdata/python/relax/poetry/want.pyproject.toml",
   197  			wantResultPath:   "testdata/python/relax/poetry/result.json",
   198  			remOpts:          options.DefaultRemediationOptions(),
   199  		},
   200  		{
   201  			name:             "python pipfile",
   202  			universeDir:      "testdata/python",
   203  			manifest:         "testdata/python/relax/pipfile/Pipfile",
   204  			wantManifestPath: "testdata/python/relax/pipfile/want.Pipfile",
   205  			wantResultPath:   "testdata/python/relax/pipfile/result.json",
   206  			remOpts:          options.DefaultRemediationOptions(),
   207  		},
   208  	} {
   209  		t.Run(tt.name, func(t *testing.T) {
   210  			client := clienttest.NewMockResolutionClient(t, filepath.Join(tt.universeDir, "universe.yaml"))
   211  			enricher := vulnenrichertest.NewMockVulnerabilityEnricher(t, filepath.Join(tt.universeDir, "vulnerabilities.json"))
   212  
   213  			tmpDir := t.TempDir()
   214  			manifestPath := filepath.Join(tmpDir, filepath.Base(tt.manifest))
   215  			data, err := os.ReadFile(tt.manifest)
   216  			if err != nil {
   217  				t.Fatalf("failed reading manifest for copy: %v", err)
   218  			}
   219  			if err := os.WriteFile(manifestPath, data, 0644); err != nil {
   220  				t.Fatalf("failed copying manifest: %v", err)
   221  			}
   222  
   223  			var lockfilePath string
   224  			if tt.lockfile != "" {
   225  				lockfilePath = filepath.Join(tmpDir, filepath.Base(tt.lockfile))
   226  				data, err := os.ReadFile(tt.lockfile)
   227  				if err != nil {
   228  					t.Fatalf("failed reading lockfile for copy: %v", err)
   229  				}
   230  				if err := os.WriteFile(lockfilePath, data, 0644); err != nil {
   231  					t.Fatalf("failed copying lockfile: %v", err)
   232  				}
   233  			}
   234  
   235  			opts := options.FixVulnsOptions{
   236  				Manifest:           manifestPath,
   237  				Lockfile:           lockfilePath,
   238  				Strategy:           strategy.StrategyRelax,
   239  				VulnEnricher:       enricher,
   240  				ResolveClient:      client,
   241  				RemediationOptions: tt.remOpts,
   242  				MaxUpgrades:        tt.maxUpgrades,
   243  				NoIntroduce:        tt.noIntroduce,
   244  			}
   245  
   246  			gotRes, err := guidedremediation.FixVulns(opts)
   247  			if err != nil {
   248  				t.Fatalf("error fixing vulns: %v", err)
   249  			}
   250  			var wantRes result.Result
   251  			f, err := os.Open(tt.wantResultPath)
   252  			if err != nil {
   253  				t.Fatalf("failed opening result file: %v", err)
   254  			}
   255  			defer f.Close()
   256  			if err := json.NewDecoder(f).Decode(&wantRes); err != nil {
   257  				t.Fatalf("failed decoding result file: %v", err)
   258  			}
   259  			diffOpts := []cmp.Option{
   260  				cmpopts.IgnoreFields(result.Result{}, "Path"),
   261  				cmpopts.IgnoreFields(result.PackageUpdate{}, "Type"),
   262  			}
   263  			if diff := cmp.Diff(wantRes, gotRes, diffOpts...); diff != "" {
   264  				t.Errorf("FixVulns() result mismatch (-want +got):\n%s", diff)
   265  			}
   266  
   267  			wantManifest, err := os.ReadFile(tt.wantManifestPath)
   268  			if err != nil {
   269  				t.Fatalf("failed reading want manifest for comparison: %v", err)
   270  			}
   271  			gotManifest, err := os.ReadFile(manifestPath)
   272  			if err != nil {
   273  				t.Fatalf("failed reading got manifest for comparison: %v", err)
   274  			}
   275  
   276  			if diff := cmp.Diff(wantManifest, gotManifest); diff != "" {
   277  				t.Errorf("FixVulns() manifest mismatch (-want +got):\n%s", diff)
   278  			}
   279  		})
   280  	}
   281  }
   282  
   283  func TestFixInPlace(t *testing.T) {
   284  	// Set up a test registry, since the lockfile writer needs to talk to the registry to get the package metadata.
   285  	srv := clienttest.NewMockHTTPServer(t)
   286  	srv.SetResponse(t, "/baz/1.0.1", []byte(`{
   287  	"name": "baz",
   288  	"version": "1.0.1",
   289  	"dist": {
   290  		"integrity": "sha512-aaaaaaaaaaaa",
   291  		"tarball": "https://registry.npmjs.org/baz/-/baz-1.0.1.tgz"
   292  	}
   293  }
   294  `))
   295  	srv.SetResponse(t, "/baz/2.0.1", []byte(`{
   296  	"name": "baz",
   297  	"version": "2.0.1",
   298  	"dist": {
   299  		"integrity": "sha512-bbbbbbbbbbbb",
   300  		"tarball": "https://registry.npmjs.org/baz/-/baz-2.0.1.tgz"
   301  	}
   302  }
   303  `))
   304  	for _, tt := range []struct {
   305  		name             string
   306  		universeDir      string
   307  		lockfile         string
   308  		wantLockfilePath string
   309  		wantResultPath   string
   310  		remOpts          options.RemediationOptions
   311  		maxUpgrades      int
   312  		noIntroduce      bool
   313  	}{
   314  		{
   315  			name:             "basic",
   316  			universeDir:      "testdata/npm",
   317  			lockfile:         "testdata/npm/basic/package-lock.json",
   318  			wantLockfilePath: "testdata/npm/basic/want.package-lock.json",
   319  			wantResultPath:   "testdata/npm/basic/result.json",
   320  			remOpts:          options.DefaultRemediationOptions(),
   321  		},
   322  	} {
   323  		t.Run(tt.name, func(t *testing.T) {
   324  			client := clienttest.NewMockResolutionClient(t, filepath.Join(tt.universeDir, "universe.yaml"))
   325  			enricher := vulnenrichertest.NewMockVulnerabilityEnricher(t, filepath.Join(tt.universeDir, "vulnerabilities.json"))
   326  
   327  			tmpDir := t.TempDir()
   328  			lockfilePath := filepath.Join(tmpDir, "package-lock.json")
   329  			data, err := os.ReadFile(tt.lockfile)
   330  			if err != nil {
   331  				t.Fatalf("failed reading lockfile for copy: %v", err)
   332  			}
   333  			if err := os.WriteFile(lockfilePath, data, 0644); err != nil {
   334  				t.Fatalf("failed copying lockfile: %v", err)
   335  			}
   336  
   337  			// make a npmrc to talk to test registry
   338  			if err := os.WriteFile(filepath.Join(tmpDir, ".npmrc"), []byte("registry="+srv.URL+"\n"), 0644); err != nil {
   339  				t.Fatalf("failed creating npmrc: %v", err)
   340  			}
   341  
   342  			opts := options.FixVulnsOptions{
   343  				Lockfile:           lockfilePath,
   344  				Strategy:           strategy.StrategyInPlace,
   345  				VulnEnricher:       enricher,
   346  				ResolveClient:      client,
   347  				RemediationOptions: tt.remOpts,
   348  				MaxUpgrades:        tt.maxUpgrades,
   349  				NoIntroduce:        tt.noIntroduce,
   350  			}
   351  
   352  			gotRes, err := guidedremediation.FixVulns(opts)
   353  			if err != nil {
   354  				t.Fatalf("error fixing vulns: %v", err)
   355  			}
   356  			var wantRes result.Result
   357  			f, err := os.Open(tt.wantResultPath)
   358  			if err != nil {
   359  				t.Fatalf("failed opening result file: %v", err)
   360  			}
   361  			defer f.Close()
   362  			if err := json.NewDecoder(f).Decode(&wantRes); err != nil {
   363  				t.Fatalf("failed decoding result file: %v", err)
   364  			}
   365  			diffOpts := []cmp.Option{
   366  				cmpopts.IgnoreFields(result.Result{}, "Path"),
   367  				cmpopts.IgnoreFields(result.PackageUpdate{}, "Type"),
   368  			}
   369  			if diff := cmp.Diff(wantRes, gotRes, diffOpts...); diff != "" {
   370  				t.Errorf("FixVulns() result mismatch (-want +got):\n%s", diff)
   371  			}
   372  
   373  			wantLockfile, err := os.ReadFile(tt.wantLockfilePath)
   374  			if err != nil {
   375  				t.Fatalf("failed reading want lockfile for comparison: %v", err)
   376  			}
   377  			gotLockfile, err := os.ReadFile(lockfilePath)
   378  			if err != nil {
   379  				t.Fatalf("failed reading got lockfile for comparison: %v", err)
   380  			}
   381  
   382  			if diff := cmp.Diff(wantLockfile, gotLockfile); diff != "" {
   383  				t.Errorf("FixVulns() lockfile mismatch (-want +got):\n%s", diff)
   384  			}
   385  		})
   386  	}
   387  }
   388  
   389  func TestUpdate(t *testing.T) {
   390  	for _, tt := range []struct {
   391  		name             string
   392  		universeDir      string
   393  		manifest         string
   394  		parentManifest   string
   395  		wantManifestPath string
   396  		wantResultPath   string
   397  		config           upgrade.Config
   398  		ignoreDev        bool
   399  	}{
   400  		{
   401  			name:             "basic",
   402  			universeDir:      "testdata/maven",
   403  			manifest:         "testdata/maven/update/pom.xml",
   404  			parentManifest:   "testdata/maven/update/parent.xml",
   405  			wantManifestPath: "testdata/maven/update/want.basic.pom.xml",
   406  			wantResultPath:   "testdata/maven/update/want.basic.json",
   407  		},
   408  		{
   409  			name:             "upgrade config",
   410  			universeDir:      "testdata/maven",
   411  			manifest:         "testdata/maven/update/pom.xml",
   412  			parentManifest:   "testdata/maven/update/parent.xml",
   413  			wantManifestPath: "testdata/maven/update/want.config.pom.xml",
   414  			wantResultPath:   "testdata/maven/update/want.config.json",
   415  			config: upgrade.Config{
   416  				"pkg:e": upgrade.Minor,
   417  			},
   418  		},
   419  		{
   420  			name:             "ignore dev",
   421  			universeDir:      "testdata/maven",
   422  			manifest:         "testdata/maven/update/pom.xml",
   423  			parentManifest:   "testdata/maven/update/parent.xml",
   424  			wantManifestPath: "testdata/maven/update/want.dev.pom.xml",
   425  			wantResultPath:   "testdata/maven/update/want.dev.json",
   426  			ignoreDev:        true,
   427  		},
   428  	} {
   429  		t.Run(tt.name, func(t *testing.T) {
   430  			// mavenClient is not used in the test, but is required to be non-nil for pom.xml manifests.
   431  			mavenClient, _ := datasource.NewDefaultMavenRegistryAPIClient(t.Context(), "")
   432  			client := clienttest.NewMockResolutionClient(t, filepath.Join(tt.universeDir, "universe.yaml"))
   433  
   434  			tmpDir := t.TempDir()
   435  			manifestPath := filepath.Join(tmpDir, "pom.xml")
   436  			data, err := os.ReadFile(tt.manifest)
   437  			if err != nil {
   438  				t.Fatalf("failed reading manifest for copy: %v", err)
   439  			}
   440  			if err := os.WriteFile(manifestPath, data, 0644); err != nil {
   441  				t.Fatalf("failed copying manifest: %v", err)
   442  			}
   443  
   444  			parentPath := filepath.Join(tmpDir, "parent.xml")
   445  			data, err = os.ReadFile(tt.parentManifest)
   446  			if err != nil {
   447  				t.Fatalf("failed reading manifest for copy: %v", err)
   448  			}
   449  			if err = os.WriteFile(parentPath, data, 0644); err != nil {
   450  				t.Fatalf("failed copying manifest: %v", err)
   451  			}
   452  
   453  			opts := options.UpdateOptions{
   454  				Manifest:      manifestPath,
   455  				ResolveClient: client,
   456  				MavenClient:   mavenClient,
   457  				UpgradeConfig: tt.config,
   458  				IgnoreDev:     tt.ignoreDev,
   459  			}
   460  
   461  			gotRes, err := guidedremediation.Update(opts)
   462  			if err != nil {
   463  				t.Fatalf("failed to update: %v", err)
   464  			}
   465  
   466  			wantManifest, err := os.ReadFile(tt.wantManifestPath)
   467  			if err != nil {
   468  				t.Fatalf("failed reading want manifest for comparison: %v", err)
   469  			}
   470  			gotManifest, err := os.ReadFile(manifestPath)
   471  			if err != nil {
   472  				t.Fatalf("failed reading got manifest for comparison: %v", err)
   473  			}
   474  			if runtime.GOOS == "windows" {
   475  				wantManifest = bytes.ReplaceAll(wantManifest, []byte("\r\n"), []byte("\n"))
   476  				gotManifest = bytes.ReplaceAll(gotManifest, []byte("\r\n"), []byte("\n"))
   477  			}
   478  			if diff := cmp.Diff(wantManifest, gotManifest); diff != "" {
   479  				t.Errorf("Update() manifest mismatch (-want +got):\n%s", diff)
   480  			}
   481  
   482  			var wantRes result.Result
   483  			f, err := os.Open(tt.wantResultPath)
   484  			if err != nil {
   485  				t.Fatalf("failed opening result file: %v", err)
   486  			}
   487  			defer f.Close()
   488  			if err := json.NewDecoder(f).Decode(&wantRes); err != nil {
   489  				t.Fatalf("failed decoding result file: %v", err)
   490  			}
   491  			diffOpts := []cmp.Option{
   492  				cmpopts.IgnoreFields(result.Result{}, "Path"),
   493  				cmpopts.IgnoreFields(result.PackageUpdate{}, "Type"),
   494  			}
   495  			if diff := cmp.Diff(wantRes, gotRes, diffOpts...); diff != "" {
   496  				t.Errorf("Update() result mismatch (-want +got):\n%s", diff)
   497  			}
   498  		})
   499  	}
   500  }