zotregistry.dev/zot@v1.4.4-0.20240314164342-eec277e14d20/pkg/test/common/fs.go (about)

     1  package common
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/fs"
     9  	"os"
    10  	"path"
    11  	"path/filepath"
    12  	"strings"
    13  	"time"
    14  
    15  	"golang.org/x/crypto/bcrypt"
    16  )
    17  
    18  var ErrNoGoModFileFound = errors.New("no go.mod file found in parent directories")
    19  
    20  func GetProjectRootDir() (string, error) {
    21  	workDir, err := os.Getwd()
    22  	if err != nil {
    23  		return "", err
    24  	}
    25  
    26  	for {
    27  		goModPath := filepath.Join(workDir, "go.mod")
    28  
    29  		_, err := os.Stat(goModPath)
    30  		if err == nil {
    31  			return workDir, nil
    32  		}
    33  
    34  		if workDir == filepath.Dir(workDir) {
    35  			return "", ErrNoGoModFileFound
    36  		}
    37  
    38  		workDir = filepath.Dir(workDir)
    39  	}
    40  }
    41  
    42  func CopyFile(sourceFilePath, destFilePath string) error {
    43  	destFile, err := os.Create(destFilePath)
    44  	if err != nil {
    45  		return err
    46  	}
    47  	defer destFile.Close()
    48  
    49  	sourceFile, err := os.Open(sourceFilePath)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	defer sourceFile.Close()
    54  
    55  	if _, err = io.Copy(destFile, sourceFile); err != nil {
    56  		return err
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func CopyFiles(sourceDir, destDir string) error {
    63  	sourceMeta, err := os.Stat(sourceDir)
    64  	if err != nil {
    65  		return fmt.Errorf("CopyFiles os.Stat failed: %w", err)
    66  	}
    67  
    68  	if err := os.MkdirAll(destDir, sourceMeta.Mode()); err != nil {
    69  		return fmt.Errorf("CopyFiles os.MkdirAll failed: %w", err)
    70  	}
    71  
    72  	files, err := os.ReadDir(sourceDir)
    73  	if err != nil {
    74  		return fmt.Errorf("CopyFiles os.ReadDir failed: %w", err)
    75  	}
    76  
    77  	for _, file := range files {
    78  		sourceFilePath := path.Join(sourceDir, file.Name())
    79  		destFilePath := path.Join(destDir, file.Name())
    80  
    81  		if file.IsDir() {
    82  			if strings.HasPrefix(file.Name(), "_") {
    83  				// Some tests create the trivy related folders under test/_trivy
    84  				continue
    85  			}
    86  
    87  			if err = CopyFiles(sourceFilePath, destFilePath); err != nil {
    88  				return err
    89  			}
    90  		} else {
    91  			sourceFile, err := os.Open(sourceFilePath)
    92  			if err != nil {
    93  				return fmt.Errorf("CopyFiles os.Open failed: %w", err)
    94  			}
    95  			defer sourceFile.Close()
    96  
    97  			destFile, err := os.Create(destFilePath)
    98  			if err != nil {
    99  				return fmt.Errorf("CopyFiles os.Create failed: %w", err)
   100  			}
   101  			defer destFile.Close()
   102  
   103  			if _, err = io.Copy(destFile, sourceFile); err != nil {
   104  				return fmt.Errorf("io.Copy failed: %w", err)
   105  			}
   106  		}
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  func CopyTestKeysAndCerts(destDir string) error {
   113  	files := []string{
   114  		"ca.crt", "ca.key", "client.cert", "client.csr",
   115  		"client.key", "server.cert", "server.csr", "server.key",
   116  	}
   117  
   118  	rootPath, err := GetProjectRootDir()
   119  	if err != nil {
   120  		return err
   121  	}
   122  
   123  	sourceDir := filepath.Join(rootPath, "test/data")
   124  
   125  	sourceMeta, err := os.Stat(sourceDir)
   126  	if err != nil {
   127  		return fmt.Errorf("CopyFiles os.Stat failed: %w", err)
   128  	}
   129  
   130  	if err := os.MkdirAll(destDir, sourceMeta.Mode()); err != nil {
   131  		return err
   132  	}
   133  
   134  	for _, file := range files {
   135  		err = CopyFile(filepath.Join(sourceDir, file), filepath.Join(destDir, file))
   136  		if err != nil {
   137  			return err
   138  		}
   139  	}
   140  
   141  	return nil
   142  }
   143  
   144  func WriteFileWithPermission(path string, data []byte, perm fs.FileMode, overwrite bool) error {
   145  	if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
   146  		return err
   147  	}
   148  	flag := os.O_WRONLY | os.O_CREATE
   149  
   150  	if overwrite {
   151  		flag |= os.O_TRUNC
   152  	} else {
   153  		flag |= os.O_EXCL
   154  	}
   155  
   156  	file, err := os.OpenFile(path, flag, perm)
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	_, err = file.Write(data)
   162  	if err != nil {
   163  		file.Close()
   164  
   165  		return err
   166  	}
   167  
   168  	return file.Close()
   169  }
   170  
   171  func ReadLogFileAndSearchString(logPath string, stringToMatch string, timeout time.Duration) (bool, error) {
   172  	ctx, cancelFunc := context.WithTimeout(context.Background(), timeout)
   173  	defer cancelFunc()
   174  
   175  	for {
   176  		select {
   177  		case <-ctx.Done():
   178  			return false, nil
   179  		default:
   180  			content, err := os.ReadFile(logPath)
   181  			if err != nil {
   182  				return false, err
   183  			}
   184  
   185  			if strings.Contains(string(content), stringToMatch) {
   186  				return true, nil
   187  			}
   188  		}
   189  	}
   190  }
   191  
   192  func ReadLogFileAndCountStringOccurence(logPath string, stringToMatch string,
   193  	timeout time.Duration, count int,
   194  ) (bool, error) {
   195  	ctx, cancelFunc := context.WithTimeout(context.Background(), timeout)
   196  	defer cancelFunc()
   197  
   198  	for {
   199  		select {
   200  		case <-ctx.Done():
   201  			return false, nil
   202  		default:
   203  			content, err := os.ReadFile(logPath)
   204  			if err != nil {
   205  				return false, err
   206  			}
   207  
   208  			if strings.Count(string(content), stringToMatch) >= count {
   209  				return true, nil
   210  			}
   211  		}
   212  	}
   213  }
   214  
   215  func GetCredString(username, password string) string {
   216  	hash, err := bcrypt.GenerateFromPassword([]byte(password), 10)
   217  	if err != nil {
   218  		panic(err)
   219  	}
   220  
   221  	usernameAndHash := fmt.Sprintf("%s:%s\n", username, string(hash))
   222  
   223  	return usernameAndHash
   224  }
   225  
   226  func MakeHtpasswdFileFromString(fileContent string) string {
   227  	htpasswdFile, err := os.CreateTemp("", "htpasswd-")
   228  	if err != nil {
   229  		panic(err)
   230  	}
   231  
   232  	content := []byte(fileContent)
   233  	if err := os.WriteFile(htpasswdFile.Name(), content, 0o600); err != nil { //nolint:gomnd
   234  		panic(err)
   235  	}
   236  
   237  	return htpasswdFile.Name()
   238  }