github.com/Axway/agent-sdk@v1.1.101/pkg/util/util.go (about)

     1  package util
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rsa"
     6  	"crypto/sha256"
     7  	"crypto/x509"
     8  	"encoding/base64"
     9  	"encoding/json"
    10  	"encoding/pem"
    11  	"errors"
    12  	"flag"
    13  	"fmt"
    14  	"hash/fnv"
    15  	"io/fs"
    16  	"net"
    17  	"net/http"
    18  	"net/url"
    19  	"os"
    20  	"path/filepath"
    21  	"reflect"
    22  	"regexp"
    23  	"sort"
    24  	"strconv"
    25  	"strings"
    26  	"time"
    27  	"unicode"
    28  
    29  	"github.com/Axway/agent-sdk/pkg/util/log"
    30  	"github.com/golang-jwt/jwt"
    31  	"github.com/sirupsen/logrus"
    32  )
    33  
const (
	// AmplifyCentral is the product name string for Amplify Central.
	AmplifyCentral             = "Amplify Central"
	// CentralHealthCheckEndpoint is the identifier "central".
	// NOTE(review): presumably the name under which the Central health check
	// is registered - confirm against callers.
	CentralHealthCheckEndpoint = "central"
)
    39  
    40  // ComputeHash - get the hash of the byte array sent in
    41  func ComputeHash(data interface{}) (uint64, error) {
    42  	dataB, err := json.Marshal(data)
    43  	if err != nil {
    44  		return 0, fmt.Errorf("could not marshal data to bytes")
    45  	}
    46  
    47  	h := fnv.New64a()
    48  	h.Write(dataB)
    49  	return h.Sum64(), nil
    50  }
    51  
    52  // MaskValue - mask sensitive information with * (asterisk).  Length of sensitiveData to match returning maskedValue
    53  func MaskValue(sensitiveData string) string {
    54  	var maskedValue string
    55  	for i := 0; i < len(sensitiveData); i++ {
    56  		maskedValue += "*"
    57  	}
    58  	return maskedValue
    59  }
    60  
    61  // PrintDataInterface - prints contents of the interface only if in debug mode
    62  func PrintDataInterface(data interface{}) {
    63  	if log.GetLevel() == logrus.DebugLevel {
    64  		PrettyPrint(data)
    65  	}
    66  }
    67  
    68  // PrettyPrint - print the contents of the obj
    69  func PrettyPrint(data interface{}) {
    70  	var p []byte
    71  	//    var err := error
    72  	p, err := json.MarshalIndent(data, "", "\t")
    73  	if err != nil {
    74  		fmt.Println(err)
    75  		return
    76  	}
    77  	fmt.Printf("%s \n", p)
    78  }
    79  
    80  // GetProxyURL - need to provide my own function (instead of http.ProxyURL()) to handle empty url. Returning nil
    81  // means "no proxy"
    82  func GetProxyURL(fixedURL *url.URL) func(*http.Request) (*url.URL, error) {
    83  	return func(*http.Request) (*url.URL, error) {
    84  		if fixedURL == nil || fixedURL.Host == "" {
    85  			return nil, nil
    86  		}
    87  		return fixedURL, nil
    88  	}
    89  }
    90  
    91  // GetURLHostName - return the host name of the passed in URL
    92  func GetURLHostName(urlString string) string {
    93  	host, err := url.Parse(urlString)
    94  	if err != nil {
    95  		fmt.Println(err)
    96  		return ""
    97  	}
    98  	return host.Hostname()
    99  }
   100  
   101  // ParsePort - parse port from URL
   102  func ParsePort(url *url.URL) int {
   103  	port := 0
   104  	if url == nil {
   105  		return port
   106  	}
   107  
   108  	if url.Port() == "" {
   109  		port, _ = net.LookupPort("tcp", url.Scheme)
   110  	} else {
   111  		port, _ = strconv.Atoi(url.Port())
   112  	}
   113  	return port
   114  }
   115  
   116  // ParseAddr - parse host:port from URL
   117  func ParseAddr(url *url.URL) string {
   118  	if url == nil {
   119  		return ""
   120  	}
   121  
   122  	host, port, err := net.SplitHostPort(url.Host)
   123  	if err != nil {
   124  		return fmt.Sprintf("%s:%d", url.Host, ParsePort(url))
   125  	}
   126  	return fmt.Sprintf("%s:%s", host, port)
   127  }
   128  
   129  // StringSliceContains - does the given string slice contain the specified string?
   130  func StringSliceContains(items []string, s string) bool {
   131  	for _, item := range items {
   132  		if item == s {
   133  			return true
   134  		}
   135  	}
   136  	return false
   137  }
   138  
   139  // RemoveDuplicateValuesFromStringSlice - remove duplicate values from a string slice
   140  func RemoveDuplicateValuesFromStringSlice(strSlice []string) []string {
   141  	keys := make(map[string]bool)
   142  	list := []string{}
   143  
   144  	// If the key(values of the slice) is not equal
   145  	// to the already present value in new slice (list)
   146  	// then we append it. else we jump on another element.
   147  	for _, entry := range strSlice {
   148  		if _, value := keys[entry]; !value {
   149  			keys[entry] = true
   150  			list = append(list, entry)
   151  		}
   152  	}
   153  	return list
   154  }
   155  
   156  // IsItemInSlice - Returns true if the given item is in the string slice, strSlice should be sorted
   157  func IsItemInSlice(strSlice []string, item string) bool {
   158  	if len(strSlice) == 0 {
   159  		return false
   160  	}
   161  	if len(strSlice) == 1 {
   162  		return strSlice[0] == item
   163  	}
   164  	midPoint := len(strSlice) / 2
   165  	if item == strSlice[midPoint] {
   166  		return true
   167  	}
   168  	if item < strSlice[midPoint] {
   169  		return IsItemInSlice(strSlice[:midPoint], item)
   170  	}
   171  	return IsItemInSlice(strSlice[midPoint:], item)
   172  }
   173  
   174  // ConvertTimeToMillis - convert to milliseconds
   175  func ConvertTimeToMillis(tm time.Time) int64 {
   176  	return tm.UnixNano() / 1e6
   177  }
   178  
   179  // IsNotTest determines if a test is running or not
   180  func IsNotTest() bool {
   181  	return flag.Lookup("test.v") == nil
   182  }
   183  
   184  // RemoveUnquotedSpaces - Remove all whitespace not between matching quotes
   185  func RemoveUnquotedSpaces(s string) (string, error) {
   186  	rs := make([]rune, 0, len(s))
   187  	const out = rune(0)
   188  	var quote rune = out
   189  	var escape = false
   190  	for _, r := range s {
   191  		if !escape {
   192  			if r == '`' || r == '"' || r == '\'' {
   193  				if quote == out {
   194  					// start unescaped quote
   195  					quote = r
   196  				} else if quote == r {
   197  					// end unescaped quote
   198  					quote = out
   199  				}
   200  			}
   201  		}
   202  		// backslash (\) is the escape character
   203  		// except when it is the second backslash of a pair
   204  		escape = !escape && r == '\\'
   205  		if quote != out || !unicode.IsSpace(r) {
   206  			// between matching unescaped quotes
   207  			// or not whitespace
   208  			rs = append(rs, r)
   209  		}
   210  	}
   211  	if quote != out {
   212  		err := fmt.Errorf("unmatched unescaped quote: %q", quote)
   213  		return "", err
   214  	}
   215  	return string(rs), nil
   216  }
   217  
   218  // CreateDirIfNotExist - Creates the directory with same permission as parent
   219  func CreateDirIfNotExist(dirPath string) {
   220  	_, err := os.Stat(dirPath)
   221  	if os.IsNotExist(err) {
   222  		dataInfo := getParentDirInfo(dirPath)
   223  		os.MkdirAll(dirPath, dataInfo.Mode().Perm())
   224  	}
   225  }
   226  
   227  func getParentDirInfo(dirPath string) fs.FileInfo {
   228  	parent := filepath.Dir(dirPath)
   229  	dataInfo, err := os.Stat(parent)
   230  	if os.IsNotExist(err) {
   231  		return getParentDirInfo(parent)
   232  	}
   233  	return dataInfo
   234  }
   235  
   236  // MergeMapStringInterface - merges the provided maps.
   237  // If duplicate keys are found across the maps, then the keys in map n will be overwritten in keys in map n+1
   238  func MergeMapStringInterface(m ...map[string]interface{}) map[string]interface{} {
   239  	attrs := make(map[string]interface{})
   240  
   241  	for _, item := range m {
   242  		for k, v := range item {
   243  			attrs[k] = v
   244  		}
   245  	}
   246  
   247  	return attrs
   248  }
   249  
   250  // MergeMapStringString - merges the provided maps.
   251  // If duplicate keys are found across the maps, then the keys in map n will be overwritten in keys in map n+1.
   252  func MergeMapStringString(m ...map[string]string) map[string]string {
   253  	attrs := make(map[string]string)
   254  
   255  	for _, item := range m {
   256  		for k, v := range item {
   257  			attrs[k] = v
   258  		}
   259  	}
   260  
   261  	return attrs
   262  }
   263  
   264  // CheckEmptyMapStringString creates a new empty map if the provided map is nil
   265  func CheckEmptyMapStringString(m map[string]string) map[string]string {
   266  	if m == nil {
   267  		return make(map[string]string)
   268  	}
   269  
   270  	return m
   271  }
   272  
   273  // MapStringStringToMapStringInterface converts a map[string]string to map[string]interface{}
   274  func MapStringStringToMapStringInterface(m map[string]string) map[string]interface{} {
   275  	newMap := make(map[string]interface{})
   276  
   277  	for k, v := range m {
   278  		newMap[k] = v
   279  	}
   280  	return newMap
   281  }
   282  
   283  // ToString converts an interface{} to a string
   284  func ToString(v interface{}) string {
   285  	if v == nil {
   286  		return ""
   287  	}
   288  	s, ok := v.(string)
   289  	if !ok {
   290  		return ""
   291  	}
   292  	return s
   293  }
   294  
   295  // IsNil checks a value, or a pointer for nil
   296  func IsNil(v interface{}) bool {
   297  	return v == nil || reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()
   298  }
   299  
   300  // MapStringInterfaceToStringString - convert map[string]interface{} to map[string]string given the item can be a string
   301  func MapStringInterfaceToStringString(data map[string]interface{}) map[string]string {
   302  	newData := make(map[string]string)
   303  
   304  	for k, v := range data {
   305  		newData[k] = ""
   306  		if v == nil {
   307  			continue
   308  		} else {
   309  			newData[k] = fmt.Sprintf("%+v", v)
   310  		}
   311  	}
   312  	return newData
   313  }
   314  
   315  // ConvertStringToUint -
   316  func ConvertStringToUint(val string) uint64 {
   317  	ret, _ := strconv.ParseUint(val, 10, 64)
   318  	return ret
   319  }
   320  
   321  // ConvertUnitToString -
   322  func ConvertUnitToString(val uint64) string {
   323  	return strconv.FormatUint(val, 10)
   324  }
   325  
   326  // ReadPrivateKeyFile - reads and parses the private key content
   327  func ReadPrivateKeyFile(privateKeyFile, passwordFile string) (*rsa.PrivateKey, error) {
   328  	keyBytes, err := os.ReadFile(privateKeyFile)
   329  	if err != nil {
   330  		return nil, err
   331  	}
   332  
   333  	// cleanup private key read bytes
   334  	defer func() {
   335  		for i := range keyBytes {
   336  			keyBytes[i] = 0
   337  		}
   338  	}()
   339  
   340  	if passwordFile != "" {
   341  		var passwordBuf []byte
   342  		var err error
   343  		// cleanup password bytes
   344  		defer func() {
   345  			for i := range passwordBuf {
   346  				passwordBuf[i] = 0
   347  			}
   348  		}()
   349  
   350  		passwordBuf, err = readPassword(passwordFile)
   351  		if err != nil {
   352  			return nil, err
   353  		}
   354  
   355  		if len(passwordBuf) > 0 {
   356  			key, err := parseRSAPrivateKeyFromPEMWithBytePassword(keyBytes, passwordBuf)
   357  			if err != nil {
   358  				return nil, err
   359  			}
   360  
   361  			return key, nil
   362  
   363  		}
   364  		log.Debug("password file empty, assuming unencrypted key")
   365  		return jwt.ParseRSAPrivateKeyFromPEM(keyBytes)
   366  	}
   367  
   368  	log.Debug("no password, assuming unencrypted key")
   369  	return jwt.ParseRSAPrivateKeyFromPEM(keyBytes)
   370  }
   371  
   372  func readPassword(passwordFile string) ([]byte, error) {
   373  	return os.ReadFile(passwordFile)
   374  }
   375  
   376  // ReadPublicKeyBytes - reads the public key bytes from file
   377  func ReadPublicKeyBytes(publicKeyFile string) ([]byte, error) {
   378  	keyBytes, err := os.ReadFile(publicKeyFile)
   379  	if err != nil {
   380  		return nil, err
   381  	}
   382  	return keyBytes, nil
   383  }
   384  
   385  // parseRSAPrivateKeyFromPEMWithBytePassword - tries to parse an rsa private key using password as bytes
   386  // inspired from jwt.ParseRSAPrivateKeyFromPEMWithPassword
   387  func parseRSAPrivateKeyFromPEMWithBytePassword(key []byte, password []byte) (*rsa.PrivateKey, error) {
   388  	var err error
   389  
   390  	// trim any spaces from the password
   391  	password = bytes.TrimSpace(password)
   392  
   393  	// Parse PEM block
   394  	var block *pem.Block
   395  	if block, _ = pem.Decode(key); block == nil {
   396  		return nil, fmt.Errorf("key must be pem encoded")
   397  	}
   398  
   399  	var parsedKey interface{}
   400  
   401  	var blockDecrypted []byte
   402  	if blockDecrypted, err = x509.DecryptPEMBlock(block, password); err != nil {
   403  		return nil, err
   404  	}
   405  
   406  	if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
   407  		if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
   408  			return nil, err
   409  		}
   410  	}
   411  
   412  	var pkey *rsa.PrivateKey
   413  	var ok bool
   414  	if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
   415  		return nil, fmt.Errorf("[apicauth] not a private key")
   416  	}
   417  
   418  	return pkey, nil
   419  }
   420  
   421  // ParsePublicKey - parses the public key content
   422  func ParsePublicKey(publicKey []byte) (*rsa.PublicKey, error) {
   423  	block, _ := pem.Decode(publicKey)
   424  	if block == nil {
   425  		return nil, fmt.Errorf("failed to decode public key")
   426  	}
   427  	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
   428  	if err != nil {
   429  		return nil, fmt.Errorf("failed to parse public key: %s", err)
   430  	}
   431  
   432  	p, ok := pub.(*rsa.PublicKey)
   433  	if !ok {
   434  		return nil, fmt.Errorf("expected public key type to be *rsa.PublicKey but received %T", pub)
   435  	}
   436  	return p, nil
   437  }
   438  
   439  // ParsePublicKeyDER - parse DER block from public key
   440  func ParsePublicKeyDER(publicKey []byte) ([]byte, error) {
   441  	if b64key, err := base64.StdEncoding.DecodeString(string(publicKey)); err == nil {
   442  		return b64key, nil
   443  	}
   444  
   445  	_, err := x509.ParsePKIXPublicKey(publicKey)
   446  	if err != nil {
   447  		pemBlock, _ := pem.Decode(publicKey)
   448  		if pemBlock == nil {
   449  			return nil, errors.New("data in key was not valid")
   450  		}
   451  		if pemBlock.Type != "PUBLIC KEY" {
   452  			return nil, errors.New("unsupported key type")
   453  		}
   454  		return pemBlock.Bytes, nil
   455  	}
   456  	return publicKey, nil
   457  }
   458  
   459  // ComputeKIDFromDER - compute key ID for public key
   460  func ComputeKIDFromDER(publicKey []byte) (kid string, err error) {
   461  	b64key, err := ParsePublicKeyDER(publicKey)
   462  	if err != nil {
   463  		return "", err
   464  	}
   465  	h := sha256.New() // create new hash with sha256 checksum
   466  	/* #nosec G104 */
   467  	if _, err := h.Write(b64key); err != nil { // add b64key to hash
   468  		return "", err
   469  	}
   470  	e := base64.StdEncoding.EncodeToString(h.Sum(nil)) // return string of base64 encoded hash
   471  	kid = strings.Split(e, "=")[0]
   472  	kid = strings.Replace(kid, "+", "-", -1)
   473  	kid = strings.Replace(kid, "/", "_", -1)
   474  	return
   475  }
   476  
   477  // GetStringFromMapInterface - returns the validated string for the map element
   478  func GetStringFromMapInterface(key string, data map[string]interface{}) string {
   479  	if e, ok := data[key]; ok && e != nil {
   480  		if value, ok := e.(string); ok {
   481  			return value
   482  		}
   483  	}
   484  	return ""
   485  }
   486  
   487  // GetStringArrayFromMapInterface - returns the validated string array for the map element
   488  func GetStringArrayFromMapInterface(key string, data map[string]interface{}) []string {
   489  	val := []string{}
   490  	if e, ok := data[key]; ok && e != nil {
   491  		if i, ok := e.([]interface{}); ok {
   492  			for _, u := range i {
   493  				if s, ok := u.(string); ok {
   494  					val = append(val, s)
   495  				}
   496  			}
   497  		}
   498  		if sa, ok := e.([]string); ok {
   499  			val = append(val, sa...)
   500  		}
   501  	}
   502  	return val
   503  }
   504  
   505  // ConvertToDomainNameCompliant - converts string to be domain name complaint
   506  func ConvertToDomainNameCompliant(str string) string {
   507  	// convert all letters to lower first
   508  	newName := strings.ToLower(str)
   509  
   510  	// parse name out. All valid parts must be '-', '.', a-z, or 0-9
   511  	re := regexp.MustCompile(`[-\.a-z0-9]*`)
   512  	matches := re.FindAllString(newName, -1)
   513  
   514  	// join all of the parts, separated with '-'. This in effect is substituting all illegal chars with a '-'
   515  	newName = strings.Join(matches, "-")
   516  
   517  	// The regex rule says that the name must not begin or end with a '-' or '.', so trim them off
   518  	newName = strings.TrimLeft(strings.TrimRight(newName, "-."), "-.")
   519  
   520  	// The regex rule also says that the name must not have a sequence of ".-", "-.", or "..", so replace them
   521  	r1 := strings.ReplaceAll(newName, "-.", "--")
   522  	r2 := strings.ReplaceAll(r1, ".-", "--")
   523  	return strings.ReplaceAll(r2, "..", "--")
   524  }
   525  
   526  func OrderStringsInMap[T any](input map[string]T) map[string]T {
   527  	keys := make([]string, 0, len(input))
   528  	for k := range input {
   529  		keys = append(keys, k)
   530  	}
   531  	sort.Strings(keys)
   532  
   533  	output := map[string]T{}
   534  	for _, k := range keys {
   535  		output[k] = input[k]
   536  	}
   537  	return output
   538  }
   539  
   540  func OrderedKeys[T any](input map[string]T) []string {
   541  	keys := make([]string, 0, len(input))
   542  	for k := range input {
   543  		keys = append(keys, k)
   544  	}
   545  	sort.Strings(keys)
   546  
   547  	return keys
   548  }
   549  
   550  func FormatUserAgent(agentType, version, sdkVersion, environmentName, agentName string, isDocker, isGRPC bool) string {
   551  	ua := ""
   552  	if agentType != "" && version != "" && sdkVersion != "" {
   553  		deploymentType := "binary"
   554  		if isDocker {
   555  			deploymentType = "docker"
   556  		}
   557  		ua = fmt.Sprintf("%s/%s SDK/%s %s %s %s", agentType, version, sdkVersion, environmentName, agentName, deploymentType)
   558  		if isGRPC {
   559  			ua = fmt.Sprintf("%s reactive", ua)
   560  		}
   561  	}
   562  	return ua
   563  }