github.com/anchore/syft@v1.38.2/syft/pkg/cataloger/ai/package_test.go (about)

     1  package ai
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/require"
     7  
     8  	"github.com/anchore/syft/syft/file"
     9  	"github.com/anchore/syft/syft/pkg"
    10  	"github.com/anchore/syft/syft/pkg/cataloger/internal/pkgtest"
    11  )
    12  
    13  func TestNewGGUFPackage(t *testing.T) {
    14  	tests := []struct {
    15  		name     string
    16  		metadata *pkg.GGUFFileHeader
    17  		input    struct {
    18  			modelName string
    19  			version   string
    20  			license   string
    21  			locations []file.Location
    22  		}
    23  		expected pkg.Package
    24  	}{
    25  		{
    26  			name: "complete GGUF package with all fields",
    27  			input: struct {
    28  				modelName string
    29  				version   string
    30  				license   string
    31  				locations []file.Location
    32  			}{
    33  				modelName: "llama3-8b",
    34  				version:   "3.0",
    35  				license:   "Apache-2.0",
    36  				locations: []file.Location{file.NewLocation("/models/llama3-8b.gguf")},
    37  			},
    38  			metadata: &pkg.GGUFFileHeader{
    39  				Architecture: "llama",
    40  				Quantization: "Q4_K_M",
    41  				Parameters:   8030000000,
    42  				GGUFVersion:  3,
    43  				TensorCount:  291,
    44  				RemainingKeyValues: map[string]any{
    45  					"general.random_kv": "foobar",
    46  				},
    47  			},
    48  			expected: pkg.Package{
    49  				Name:    "llama3-8b",
    50  				Version: "3.0",
    51  				Type:    pkg.ModelPkg,
    52  				Licenses: pkg.NewLicenseSet(
    53  					pkg.NewLicenseFromFields("Apache-2.0", "", nil),
    54  				),
    55  				Metadata: pkg.GGUFFileHeader{
    56  					Architecture: "llama",
    57  					Quantization: "Q4_K_M",
    58  					Parameters:   8030000000,
    59  					GGUFVersion:  3,
    60  					TensorCount:  291,
    61  					RemainingKeyValues: map[string]any{
    62  						"general.random_kv": "foobar",
    63  					},
    64  				},
    65  				Locations: file.NewLocationSet(file.NewLocation("/models/llama3-8b.gguf")),
    66  			},
    67  		},
    68  		{
    69  			name: "minimal GGUF package",
    70  			input: struct {
    71  				modelName string
    72  				version   string
    73  				license   string
    74  				locations []file.Location
    75  			}{
    76  				modelName: "gpt2-small",
    77  				version:   "1.0",
    78  				license:   "MIT",
    79  				locations: []file.Location{file.NewLocation("/models/simple.gguf")},
    80  			},
    81  			metadata: &pkg.GGUFFileHeader{
    82  				Architecture: "gpt2",
    83  				GGUFVersion:  3,
    84  				TensorCount:  50,
    85  			},
    86  			expected: pkg.Package{
    87  				Name:    "gpt2-small",
    88  				Version: "1.0",
    89  				Type:    pkg.ModelPkg,
    90  				Licenses: pkg.NewLicenseSet(
    91  					pkg.NewLicenseFromFields("MIT", "", nil),
    92  				),
    93  				Metadata: pkg.GGUFFileHeader{
    94  					Architecture: "gpt2",
    95  					GGUFVersion:  3,
    96  					TensorCount:  50,
    97  				},
    98  				Locations: file.NewLocationSet(file.NewLocation("/models/simple.gguf")),
    99  			},
   100  		},
   101  	}
   102  
   103  	for _, tt := range tests {
   104  		t.Run(tt.name, func(t *testing.T) {
   105  			actual := newGGUFPackage(
   106  				tt.metadata,
   107  				tt.input.modelName,
   108  				tt.input.version,
   109  				tt.input.license,
   110  				tt.input.locations...,
   111  			)
   112  
   113  			// Verify metadata type
   114  			_, ok := actual.Metadata.(pkg.GGUFFileHeader)
   115  			require.True(t, ok, "metadata should be GGUFFileHeader")
   116  
   117  			// Use AssertPackagesEqual for comprehensive comparison
   118  			pkgtest.AssertPackagesEqual(t, tt.expected, actual)
   119  		})
   120  	}
   121  }