github.com/anchore/syft@v1.38.2/syft/pkg/cataloger/ai/cataloger_test.go (about)

     1  package ai
     2  
     3  import (
     4  	"os"
     5  	"path/filepath"
     6  	"testing"
     7  
     8  	"github.com/anchore/syft/syft/artifact"
     9  	"github.com/anchore/syft/syft/pkg"
    10  	"github.com/anchore/syft/syft/pkg/cataloger/internal/pkgtest"
    11  )
    12  
    13  func TestGGUFCataloger_Globs(t *testing.T) {
    14  	tests := []struct {
    15  		name     string
    16  		fixture  string
    17  		expected []string
    18  	}{
    19  		{
    20  			name:    "obtain gguf files",
    21  			fixture: "test-fixtures/glob-paths",
    22  			expected: []string{
    23  				"models/model.gguf",
    24  			},
    25  		},
    26  	}
    27  
    28  	for _, test := range tests {
    29  		t.Run(test.name, func(t *testing.T) {
    30  			pkgtest.NewCatalogTester().
    31  				FromDirectory(t, test.fixture).
    32  				ExpectsResolverContentQueries(test.expected).
    33  				TestCataloger(t, NewGGUFCataloger())
    34  		})
    35  	}
    36  }
    37  
    38  func TestGGUFCataloger(t *testing.T) {
    39  	tests := []struct {
    40  		name                  string
    41  		setup                 func(t *testing.T) string
    42  		expectedPackages      []pkg.Package
    43  		expectedRelationships []artifact.Relationship
    44  	}{
    45  		{
    46  			name: "catalog single GGUF file",
    47  			setup: func(t *testing.T) string {
    48  				dir := t.TempDir()
    49  				data := newTestGGUFBuilder().
    50  					withVersion(3).
    51  					withStringKV("general.architecture", "llama").
    52  					withStringKV("general.name", "llama3-8b").
    53  					withStringKV("general.version", "3.0").
    54  					withStringKV("general.license", "Apache-2.0").
    55  					withStringKV("general.quantization", "Q4_K_M").
    56  					withUint64KV("general.parameter_count", 8030000000).
    57  					withStringKV("general.some_random_kv", "foobar").
    58  					build()
    59  
    60  				path := filepath.Join(dir, "llama3-8b.gguf")
    61  				os.WriteFile(path, data, 0644)
    62  				return dir
    63  			},
    64  			expectedPackages: []pkg.Package{
    65  				{
    66  					Name:    "llama3-8b",
    67  					Version: "3.0",
    68  					Type:    pkg.ModelPkg,
    69  					Licenses: pkg.NewLicenseSet(
    70  						pkg.NewLicenseFromFields("Apache-2.0", "", nil),
    71  					),
    72  					Metadata: pkg.GGUFFileHeader{
    73  						Architecture:          "llama",
    74  						Quantization:          "Unknown",
    75  						Parameters:            0,
    76  						GGUFVersion:           3,
    77  						TensorCount:           0,
    78  						MetadataKeyValuesHash: "6e3d368066455ce4",
    79  						RemainingKeyValues: map[string]interface{}{
    80  							"general.some_random_kv": "foobar",
    81  						},
    82  					},
    83  				},
    84  			},
    85  			expectedRelationships: nil,
    86  		},
    87  		{
    88  			name: "catalog GGUF file with minimal metadata",
    89  			setup: func(t *testing.T) string {
    90  				dir := t.TempDir()
    91  				data := newTestGGUFBuilder().
    92  					withVersion(3).
    93  					withStringKV("general.architecture", "gpt2").
    94  					withStringKV("general.name", "gpt2-small").
    95  					withStringKV("gpt2.context_length", "1024").
    96  					withUint32KV("gpt2.embedding_length", 768).
    97  					build()
    98  
    99  				path := filepath.Join(dir, "gpt2-small.gguf")
   100  				os.WriteFile(path, data, 0644)
   101  				return dir
   102  			},
   103  			expectedPackages: []pkg.Package{
   104  				{
   105  					Name:     "gpt2-small",
   106  					Version:  "",
   107  					Type:     pkg.ModelPkg,
   108  					Licenses: pkg.NewLicenseSet(),
   109  					Metadata: pkg.GGUFFileHeader{
   110  						Architecture:          "gpt2",
   111  						Quantization:          "Unknown",
   112  						Parameters:            0,
   113  						GGUFVersion:           3,
   114  						TensorCount:           0,
   115  						MetadataKeyValuesHash: "9dc6f23591062a27",
   116  						RemainingKeyValues: map[string]interface{}{
   117  							"gpt2.context_length":   "1024",
   118  							"gpt2.embedding_length": uint32(768),
   119  						},
   120  					},
   121  				},
   122  			},
   123  			expectedRelationships: nil,
   124  		},
   125  	}
   126  
   127  	for _, tt := range tests {
   128  		t.Run(tt.name, func(t *testing.T) {
   129  			fixtureDir := tt.setup(t)
   130  
   131  			// Use pkgtest to catalog and compare
   132  			pkgtest.NewCatalogTester().
   133  				FromDirectory(t, fixtureDir).
   134  				Expects(tt.expectedPackages, tt.expectedRelationships).
   135  				IgnoreLocationLayer().
   136  				IgnorePackageFields("FoundBy", "Locations").
   137  				TestCataloger(t, NewGGUFCataloger())
   138  		})
   139  	}
   140  }