github.com/khulnasoft-lab/tunnel-db@v0.0.0-20231117205118-74e1113bd007/pkg/db/db_test.go (about) 1 package db_test 2 3 import ( 4 "io" 5 "os" 6 "path/filepath" 7 "testing" 8 9 "github.com/stretchr/testify/require" 10 11 "github.com/khulnasoft-lab/tunnel-db/pkg/db" 12 ) 13 14 func TestInit(t *testing.T) { 15 tests := []struct { 16 name string 17 dbPath string 18 }{ 19 { 20 name: "normal db", 21 dbPath: "testdata/normal.db", 22 }, 23 { 24 name: "broken db", 25 dbPath: "testdata/broken.db", 26 }, 27 { 28 name: "no db", 29 dbPath: "", 30 }, 31 } 32 for _, tt := range tests { 33 t.Run(tt.name, func(t *testing.T) { 34 tmpDir := t.TempDir() 35 36 if tt.dbPath != "" { 37 dbPath := db.Path(tmpDir) 38 dbDir := filepath.Dir(dbPath) 39 err := os.MkdirAll(dbDir, 0700) 40 require.NoError(t, err) 41 42 err = copy(dbPath, tt.dbPath) 43 require.NoError(t, err) 44 } 45 46 err := db.Init(tmpDir) 47 require.NoError(t, err) 48 }) 49 } 50 } 51 52 func copy(dstPath, srcPath string) error { 53 src, err := os.Open(srcPath) 54 if err != nil { 55 return err 56 } 57 58 dst, err := os.Create(dstPath) 59 if err != nil { 60 return err 61 } 62 63 _, err = io.Copy(dst, src) 64 return err 65 }