github.com/khulnasoft-lab/tunnel-db@v0.0.0-20231117205118-74e1113bd007/pkg/db/db_test.go (about)

     1  package db_test
     2  
     3  import (
     4  	"io"
     5  	"os"
     6  	"path/filepath"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/khulnasoft-lab/tunnel-db/pkg/db"
    12  )
    13  
    14  func TestInit(t *testing.T) {
    15  	tests := []struct {
    16  		name   string
    17  		dbPath string
    18  	}{
    19  		{
    20  			name:   "normal db",
    21  			dbPath: "testdata/normal.db",
    22  		},
    23  		{
    24  			name:   "broken db",
    25  			dbPath: "testdata/broken.db",
    26  		},
    27  		{
    28  			name:   "no db",
    29  			dbPath: "",
    30  		},
    31  	}
    32  	for _, tt := range tests {
    33  		t.Run(tt.name, func(t *testing.T) {
    34  			tmpDir := t.TempDir()
    35  
    36  			if tt.dbPath != "" {
    37  				dbPath := db.Path(tmpDir)
    38  				dbDir := filepath.Dir(dbPath)
    39  				err := os.MkdirAll(dbDir, 0700)
    40  				require.NoError(t, err)
    41  
    42  				err = copy(dbPath, tt.dbPath)
    43  				require.NoError(t, err)
    44  			}
    45  
    46  			err := db.Init(tmpDir)
    47  			require.NoError(t, err)
    48  		})
    49  	}
    50  }
    51  
    52  func copy(dstPath, srcPath string) error {
    53  	src, err := os.Open(srcPath)
    54  	if err != nil {
    55  		return err
    56  	}
    57  
    58  	dst, err := os.Create(dstPath)
    59  	if err != nil {
    60  		return err
    61  	}
    62  
    63  	_, err = io.Copy(dst, src)
    64  	return err
    65  }