github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/flag/vulnerability_flags_test.go (about)

     1  package flag_test
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/spf13/viper"
     7  	"github.com/stretchr/testify/assert"
     8  	"go.uber.org/zap"
     9  	"go.uber.org/zap/zaptest/observer"
    10  
    11  	"github.com/devseccon/trivy/pkg/flag"
    12  	"github.com/devseccon/trivy/pkg/log"
    13  	"github.com/devseccon/trivy/pkg/types"
    14  )
    15  
    16  func TestVulnerabilityFlagGroup_ToOptions(t *testing.T) {
    17  	type fields struct {
    18  		vulnType string
    19  	}
    20  	tests := []struct {
    21  		name     string
    22  		args     []string
    23  		fields   fields
    24  		want     flag.VulnerabilityOptions
    25  		wantLogs []string
    26  	}{
    27  		{
    28  			name: "happy path for OS vulnerabilities",
    29  			args: []string{"alpine:latest"},
    30  			fields: fields{
    31  				vulnType: "os",
    32  			},
    33  			want: flag.VulnerabilityOptions{
    34  				VulnType: []string{types.VulnTypeOS},
    35  			},
    36  		},
    37  		{
    38  			name: "happy path for library vulnerabilities",
    39  			args: []string{"alpine:latest"},
    40  			fields: fields{
    41  				vulnType: "library",
    42  			},
    43  			want: flag.VulnerabilityOptions{
    44  				VulnType: []string{types.VulnTypeLibrary},
    45  			},
    46  		},
    47  	}
    48  
    49  	for _, tt := range tests {
    50  		t.Run(tt.name, func(t *testing.T) {
    51  			level := zap.WarnLevel
    52  
    53  			core, obs := observer.New(level)
    54  			log.Logger = zap.New(core).Sugar()
    55  
    56  			viper.Set(flag.VulnTypeFlag.ConfigName, tt.fields.vulnType)
    57  
    58  			// Assert options
    59  			f := &flag.VulnerabilityFlagGroup{
    60  				VulnType: &flag.VulnTypeFlag,
    61  			}
    62  
    63  			got := f.ToOptions()
    64  			assert.Equalf(t, tt.want, got, "ToOptions()")
    65  
    66  			// Assert log messages
    67  			var gotMessages []string
    68  			for _, entry := range obs.AllUntimed() {
    69  				gotMessages = append(gotMessages, entry.Message)
    70  			}
    71  			assert.Equal(t, tt.wantLogs, gotMessages, tt.name)
    72  		})
    73  
    74  	}
    75  }