zotregistry.dev/zot@v1.4.4-0.20240314164342-eec277e14d20/pkg/test/common/utils.go (about)

     1  package common
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math/rand"
     7  	"net/http"
     8  	"net/url"
     9  	"os"
    10  	"path"
    11  	"time"
    12  
    13  	"github.com/phayes/freeport"
    14  	"gopkg.in/resty.v1"
    15  )
    16  
const (
	// BaseURL is a fmt template for the plain-HTTP loopback address of a test
	// server; the single %s verb is filled with the port.
	BaseURL               = "http://127.0.0.1:%s"
	// BaseSecureURL is the HTTPS counterpart of BaseURL (same %s port verb).
	BaseSecureURL         = "https://127.0.0.1:%s"
	// SleepTime is the pause between polling attempts in the WaitTill* helpers.
	SleepTime             = 100 * time.Millisecond
	// AuthorizationAllRepos is the repository pattern "**"; by its name it is
	// presumably a match-all-repositories key for authorization configs in
	// tests — confirm against callers.
	AuthorizationAllRepos = "**"
)
    23  
// isser is the type-parameter constraint used by Index and Contains: an
// element type whose Is method reports whether the element matches a name.
type isser interface {
	Is(string) bool
}
    27  
    28  // Index returns the index of the first occurrence of name in s,
    29  // or -1 if not present.
    30  func Index[E isser](s []E, name string) int {
    31  	for i, v := range s {
    32  		if v.Is(name) {
    33  			return i
    34  		}
    35  	}
    36  
    37  	return -1
    38  }
    39  
    40  // Contains reports whether name is present in s.
    41  func Contains[E isser](s []E, name string) bool {
    42  	return Index(s, name) >= 0
    43  }
    44  
    45  func Location(baseURL string, resp *resty.Response) string {
    46  	// For some API responses, the Location header is set and is supposed to
    47  	// indicate an opaque value. However, it is not clear if this value is an
    48  	// absolute URL (https://server:port/v2/...) or just a path (/v2/...)
    49  	// zot implements the latter as per the spec, but some registries appear to
    50  	// return the former - this needs to be clarified
    51  	loc := resp.Header().Get("Location")
    52  
    53  	uloc, err := url.Parse(loc)
    54  	if err != nil {
    55  		return ""
    56  	}
    57  
    58  	path := uloc.Path
    59  
    60  	return baseURL + path
    61  }
    62  
// Controller is the server lifecycle that ControllerManager drives.
//
// StartServer calls Init synchronously (panicking if it fails) and then calls
// Run on a separate goroutine, so Run presumably blocks while serving —
// confirm against implementations. RunServer treats http.ErrServerClosed
// returned from Run as a normal exit. StopServer calls Shutdown. GetPort is
// not used by the helpers in this file.
type Controller interface {
	Init() error
	Run() error
	Shutdown()
	GetPort() int
}
    69  
// ControllerManager wraps a Controller with start/stop helpers intended for
// tests (StartServer, RunServer, StopServer, StartAndWait). Construct it with
// NewControllerManager.
type ControllerManager struct {
	// controller is the server whose lifecycle the helpers drive.
	controller Controller
}
    73  
    74  func (cm *ControllerManager) RunServer() {
    75  	// Useful to be able to call in the same goroutine for testing purposes
    76  	if err := cm.controller.Run(); !errors.Is(err, http.ErrServerClosed) {
    77  		panic(err)
    78  	}
    79  }
    80  
    81  func (cm *ControllerManager) StartServer() {
    82  	if err := cm.controller.Init(); err != nil {
    83  		panic(err)
    84  	}
    85  
    86  	go func() {
    87  		cm.RunServer()
    88  	}()
    89  }
    90  
    91  func (cm *ControllerManager) StopServer() {
    92  	cm.controller.Shutdown()
    93  }
    94  
    95  func (cm *ControllerManager) WaitServerToBeReady(port string) {
    96  	url := GetBaseURL(port)
    97  	WaitTillServerReady(url)
    98  }
    99  
   100  func (cm *ControllerManager) StartAndWait(port string) {
   101  	cm.StartServer()
   102  
   103  	url := GetBaseURL(port)
   104  	WaitTillServerReady(url)
   105  }
   106  
   107  func NewControllerManager(controller Controller) ControllerManager {
   108  	cm := ControllerManager{
   109  		controller: controller,
   110  	}
   111  
   112  	return cm
   113  }
   114  
   115  func WaitTillServerReady(url string) {
   116  	for {
   117  		_, err := resty.R().Get(url)
   118  		if err == nil {
   119  			break
   120  		}
   121  
   122  		time.Sleep(SleepTime)
   123  	}
   124  }
   125  
   126  func WaitTillTrivyDBDownloadStarted(rootDir string) {
   127  	for {
   128  		if _, err := os.Stat(path.Join(rootDir, "_trivy", "db", "trivy.db")); err == nil {
   129  			break
   130  		}
   131  
   132  		time.Sleep(SleepTime)
   133  	}
   134  }
   135  
   136  func GetFreePort() string {
   137  	port, err := freeport.GetFreePort()
   138  	if err != nil {
   139  		panic(err)
   140  	}
   141  
   142  	return fmt.Sprint(port)
   143  }
   144  
   145  func GetBaseURL(port string) string {
   146  	return fmt.Sprintf(BaseURL, port)
   147  }
   148  
   149  func GetSecureBaseURL(port string) string {
   150  	return fmt.Sprintf(BaseSecureURL, port)
   151  }
   152  
   153  func CustomRedirectPolicy(noOfRedirect int) resty.RedirectPolicy {
   154  	return resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
   155  		if len(via) >= noOfRedirect {
   156  			return fmt.Errorf("stopped after %d redirects", noOfRedirect) //nolint: goerr113
   157  		}
   158  
   159  		for key, val := range via[len(via)-1].Header {
   160  			req.Header[key] = val
   161  		}
   162  
   163  		respCookies := req.Response.Cookies()
   164  		for _, cookie := range respCookies {
   165  			req.AddCookie(cookie)
   166  		}
   167  
   168  		return nil
   169  	})
   170  }
   171  
   172  // Generates a random string with length 10 from lower case & upper case characters and
   173  // a seed that can be logged in tests (if test fails, you can reconstruct random string).
   174  func GenerateRandomString() (string, int64) {
   175  	seed := time.Now().UnixNano()
   176  	//nolint: gosec
   177  	seededRand := rand.New(rand.NewSource(seed))
   178  	charset := "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
   179  
   180  	randomBytes := make([]byte, 10)
   181  	for i := range randomBytes {
   182  		randomBytes[i] = charset[seededRand.Intn(len(charset))]
   183  	}
   184  
   185  	return string(randomBytes), seed
   186  }
   187  
   188  // Generates a random string with length 10 from lower case characters and digits and
   189  // a seed that can be logged in tests (if test fails, you can reconstruct random string).
   190  func GenerateRandomName() (string, int64) {
   191  	seed := time.Now().UnixNano()
   192  	//nolint: gosec
   193  	seededRand := rand.New(rand.NewSource(seed))
   194  	charset := "abcdefghijklmnopqrstuvwxyz" + "0123456789"
   195  
   196  	randomBytes := make([]byte, 10)
   197  	for i := range randomBytes {
   198  		randomBytes[i] = charset[seededRand.Intn(len(charset))]
   199  	}
   200  
   201  	return string(randomBytes), seed
   202  }
   203  
   204  func AccumulateField[T any, R any](list []T, accFunc func(T) R) []R {
   205  	result := make([]R, 0, len(list))
   206  
   207  	for i := range list {
   208  		result = append(result, accFunc(list[i]))
   209  	}
   210  
   211  	return result
   212  }
   213  
   214  func ContainSameElements[T comparable](list1, list2 []T) bool {
   215  	if len(list1) != len(list2) {
   216  		return false
   217  	}
   218  
   219  	count1 := map[T]int{}
   220  	count2 := map[T]int{}
   221  
   222  	for i := range list1 {
   223  		count1[list1[i]]++
   224  		count2[list2[i]]++
   225  	}
   226  
   227  	for key := range count1 {
   228  		if count1[key] != count2[key] {
   229  			return false
   230  		}
   231  	}
   232  
   233  	return true
   234  }