github.com/0chain/gosdk@v1.17.11/zboxcore/zboxutil/util.go (about)

     1  package zboxutil
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"math"
    12  	"math/bits"
    13  	"net/http"
    14  	"path/filepath"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"errors"
    19  
    20  	thrown "github.com/0chain/errors"
    21  	"github.com/0chain/gosdk/zboxcore/allocationchange"
    22  	"github.com/0chain/gosdk/zboxcore/blockchain"
    23  	"github.com/h2non/filetype"
    24  	"github.com/hitenjain14/fasthttp"
    25  	"github.com/lithammer/shortuuid/v3"
    26  	"github.com/minio/sha256-simd"
    27  	"github.com/valyala/bytebufferpool"
    28  	"golang.org/x/crypto/chacha20poly1305"
    29  	"golang.org/x/crypto/scrypt"
    30  )
    31  
    32  const EncryptedFolderName = "encrypted"
    33  
    34  var BufferPool bytebufferpool.Pool
    35  
    36  type lazybuf struct {
    37  	path       string
    38  	buf        []byte
    39  	w          int
    40  	volAndPath string
    41  	volLen     int
    42  }
    43  
    44  func (b *lazybuf) index(i int) byte {
    45  	if b.buf != nil {
    46  		return b.buf[i]
    47  	}
    48  	return b.path[i]
    49  }
    50  
    51  func (b *lazybuf) append(c byte) {
    52  	if b.buf == nil {
    53  		if b.w < len(b.path) && b.path[b.w] == c {
    54  			b.w++
    55  			return
    56  		}
    57  		b.buf = make([]byte, len(b.path))
    58  		copy(b.buf, b.path[:b.w])
    59  	}
    60  	b.buf[b.w] = c
    61  	b.w++
    62  }
    63  
    64  func (b *lazybuf) string() string {
    65  	if b.buf == nil {
    66  		return b.volAndPath[:b.volLen+b.w]
    67  	}
    68  	return b.volAndPath[:b.volLen] + string(b.buf[:b.w])
    69  }
    70  
    71  // GetFileContentType returns the content type of the file based on reading the first 10KB of the file
    72  //   - ext is the extension of the file, shouldn't be empty
    73  //   - out is the file content
    74  func GetFileContentType(ext string, out io.ReadSeeker) (string, error) {
    75  
    76  	if ext != "" {
    77  		if content, ok := mimeDB[strings.TrimPrefix(ext, ".")]; ok {
    78  			return content.ContentType, nil
    79  		}
    80  	}
    81  
    82  	buffer := make([]byte, 10240)
    83  	n, err := out.Read(buffer)
    84  	defer out.Seek(0, 0) //nolint
    85  
    86  	if err != nil && err != io.EOF {
    87  		return "", err
    88  	}
    89  	buffer = buffer[:n]
    90  
    91  	kind, _ := filetype.Match(buffer)
    92  	if kind == filetype.Unknown {
    93  		return "application/octet-stream", nil
    94  	}
    95  
    96  	return kind.MIME.Value, nil
    97  }
    98  
    99  // GetFullRemotePath returns the full remote path by combining the local path and remote path
   100  //   - localPath is the local path of the file
   101  //   - remotePath is the remote path of the file
   102  func GetFullRemotePath(localPath, remotePath string) string {
   103  	if remotePath == "" || strings.HasSuffix(remotePath, "/") {
   104  		remotePath = strings.TrimRight(remotePath, "/")
   105  		_, fileName := filepath.Split(localPath)
   106  		remotePath = fmt.Sprintf("%s/%s", remotePath, fileName)
   107  	}
   108  	return remotePath
   109  }
   110  
   111  // NewConnectionId generate new connection id.
   112  // Connection is used to track the upload/download progress and redeem the cost of the operation from the network.
   113  // It's in the short uuid format. Check here for more on this format: https://pkg.go.dev/github.com/lithammer/shortuuid/v3@v3.0.7
   114  func NewConnectionId() string {
   115  	return shortuuid.New()
   116  }
   117  
   118  // IsRemoteAbs returns true if the path is remote absolute path
   119  //   - path is the path to check
   120  func IsRemoteAbs(path string) bool {
   121  	return strings.HasPrefix(path, "/")
   122  }
   123  
   124  // RemoteClean returns the cleaned remote path
   125  //   - path is the path to clean
   126  func RemoteClean(path string) string {
   127  	originalPath := path
   128  	volLen := 0 //volumeNameLen(path)
   129  	path = path[volLen:]
   130  	if path == "" {
   131  		if volLen > 1 && originalPath[1] != ':' {
   132  			// should be UNC
   133  			return path //FromSlash(originalPath)
   134  		}
   135  		return originalPath + "."
   136  	}
   137  	rooted := path[0] == '/' //os.IsPathSeparator(path[0])
   138  	// Invariants:
   139  	//	reading from path; r is index of next byte to process.
   140  	//	writing to buf; w is index of next byte to write.
   141  	//	dotdot is index in buf where .. must stop, either because
   142  	//		it is the leading slash or it is a leading ../../.. prefix.
   143  	n := len(path)
   144  	out := lazybuf{path: path, volAndPath: originalPath, volLen: volLen}
   145  	r, dotdot := 0, 0
   146  	if rooted {
   147  		out.append('/') //(Separator)
   148  		r, dotdot = 1, 1
   149  	}
   150  	for r < n {
   151  		switch {
   152  		case path[r] == '/' || path[r] == '\\': //os.IsPathSeparator(path[r]):
   153  			// empty path element
   154  			r++
   155  		case path[r] == '.' && (r+1 == n || path[r+1] == '/'): //os.IsPathSeparator(path[r+1])):
   156  			// . element
   157  			r++
   158  		case path[r] == '.' && path[r+1] == '.' && (r+2 == n || path[r+2] == '/'): //os.IsPathSeparator(path[r+2])):
   159  			// .. element: remove to last separator
   160  			r += 2
   161  			switch {
   162  			case out.w > dotdot:
   163  				// can backtrack
   164  				out.w--
   165  				for out.w > dotdot && !((out.index(out.w)) == '/') { //!os.IsPathSeparator(out.index(out.w)) {
   166  					out.w--
   167  				}
   168  			case !rooted:
   169  				// cannot backtrack, but not rooted, so append .. element.
   170  				if out.w > 0 {
   171  					out.append('/') //Separator)
   172  				}
   173  				out.append('.')
   174  				out.append('.')
   175  				dotdot = out.w
   176  			}
   177  		default:
   178  			// real path element.
   179  			// add slash if needed
   180  			if rooted && out.w != 1 || !rooted && out.w != 0 {
   181  				out.append('/') //(Separator)
   182  			}
   183  			// copy element
   184  			for ; r < n && !(path[r] == '/' || path[r] == '\\'); r++ { //!os.IsPathSeparator(path[r]); r++ {
   185  				out.append(path[r])
   186  			}
   187  		}
   188  	}
   189  	// Turn empty string into "."
   190  	if out.w == 0 {
   191  		out.append('.')
   192  	}
   193  	return out.string() //(FromSlash(out.string())
   194  }
   195  
   196  func Encrypt(key, text []byte) ([]byte, error) {
   197  	block, err := aes.NewCipher(key)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  	b := base64.StdEncoding.EncodeToString(text)
   202  	ciphertext := make([]byte, aes.BlockSize+len(b))
   203  	iv := ciphertext[:aes.BlockSize]
   204  	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
   205  		return nil, err
   206  	}
   207  	cfb := cipher.NewCFBEncrypter(block, iv)
   208  	cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(b))
   209  	return ciphertext, nil
   210  }
   211  
   212  func Decrypt(key, text []byte) ([]byte, error) {
   213  	block, err := aes.NewCipher(key)
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	if len(text) < aes.BlockSize {
   218  		return nil, errors.New("ciphertext too short")
   219  	}
   220  	iv := text[:aes.BlockSize]
   221  	text = text[aes.BlockSize:]
   222  	cfb := cipher.NewCFBDecrypter(block, iv)
   223  	cfb.XORKeyStream(text, text)
   224  	data, err := base64.StdEncoding.DecodeString(string(text))
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	return data, nil
   229  }
   230  
   231  func GetRefsHash(r []byte) string {
   232  	hash := sha256.New()
   233  	hash.Write(r)
   234  	var buf []byte
   235  	buf = hash.Sum(buf)
   236  	return string(buf)
   237  }
   238  
   239  func GetActiveBlobbers(dirMask uint32, blobbers []*blockchain.StorageNode) []*blockchain.StorageNode {
   240  	var c, pos int
   241  	var r []*blockchain.StorageNode
   242  	for i := dirMask; i != 0; i &= ^(1 << pos) {
   243  		pos = bits.TrailingZeros32(i)
   244  		r = append(r, blobbers[pos])
   245  		c++
   246  	}
   247  
   248  	return r
   249  }
   250  
   251  func GetRateLimitValue(r *http.Response) (int, error) {
   252  	rlStr := r.Header.Get("X-Rate-Limit-Limit")
   253  	durStr := r.Header.Get("X-Rate-Limit-Duration")
   254  
   255  	rl, err := strconv.ParseFloat(rlStr, 64)
   256  	if err != nil {
   257  		return 0, err
   258  	}
   259  
   260  	dur, err := strconv.ParseFloat(durStr, 64)
   261  	if err != nil {
   262  		return 0, err
   263  	}
   264  
   265  	return int(math.Ceil(rl / dur)), nil
   266  }
   267  
   268  func GetFastRateLimitValue(r *fasthttp.Response) (int, error) {
   269  	rlStr := r.Header.Peek("X-Rate-Limit-Limit")
   270  	durStr := r.Header.Peek("X-Rate-Limit-Duration")
   271  
   272  	rl, err := strconv.ParseFloat(string(rlStr), 64)
   273  	if err != nil {
   274  		return 0, err
   275  	}
   276  
   277  	dur, err := strconv.ParseFloat(string(durStr), 64)
   278  	if err != nil {
   279  		return 0, err
   280  	}
   281  
   282  	return int(math.Ceil(rl / dur)), nil
   283  }
   284  
   285  func MajorError(errors []error) error {
   286  	countError := make(map[error]int)
   287  	for _, value := range errors {
   288  		if value != nil {
   289  			countError[value] += 1
   290  		}
   291  	}
   292  	maxFreq := 0
   293  	var maxKey error
   294  	for key, value := range countError {
   295  		if value > maxFreq {
   296  			maxKey = key
   297  			maxFreq = value
   298  		}
   299  	}
   300  	return maxKey
   301  }
   302  
   303  const (
   304  	keySize      = 32
   305  	nonceSize    = 12
   306  	saltSize     = 32
   307  	tagSize      = 16
   308  	scryptN      = 32768
   309  	scryptR      = 8
   310  	scryptP      = 1
   311  	scryptKeyLen = 32
   312  )
   313  
   314  func ScryptEncrypt(key, text []byte) ([]byte, error) {
   315  	if len(key) == 0 {
   316  		return nil, errors.New("scrypt: key cannot be empty")
   317  	}
   318  	if len(text) == 0 {
   319  		return nil, errors.New("scrypt: plaintext cannot be empty")
   320  	}
   321  	salt := make([]byte, saltSize)
   322  	if _, err := rand.Read(salt); err != nil {
   323  		return nil, err
   324  	}
   325  
   326  	derivedKey, err := scrypt.Key(key, salt, scryptN, scryptR, scryptP, scryptKeyLen)
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	nonce := make([]byte, nonceSize)
   331  	if _, err := rand.Read(nonce); err != nil {
   332  		return nil, err
   333  	}
   334  	aead, err := chacha20poly1305.New(derivedKey)
   335  	if err != nil {
   336  		return nil, err
   337  	}
   338  
   339  	ciphertext := aead.Seal(nil, nonce, text, nil)
   340  	ciphertext = append(salt, ciphertext...)
   341  	ciphertext = append(nonce, ciphertext...)
   342  
   343  	return ciphertext, nil
   344  }
   345  
   346  func ScryptDecrypt(key, ciphertext []byte) ([]byte, error) {
   347  	if len(ciphertext) < saltSize+nonceSize+tagSize {
   348  		return nil, errors.New("ciphertext too short")
   349  	}
   350  
   351  	nonce := ciphertext[:nonceSize]
   352  	salt := ciphertext[nonceSize : nonceSize+saltSize]
   353  	text := ciphertext[saltSize+nonceSize:]
   354  
   355  	derivedKey, err := scrypt.Key(key, salt, scryptN, scryptR, scryptP, scryptKeyLen)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  	aead, err := chacha20poly1305.New(derivedKey)
   360  	if err != nil {
   361  		return nil, err
   362  	}
   363  	plaintext, err := aead.Open(nil, nonce, text, nil)
   364  	if err != nil {
   365  		return nil, err
   366  	}
   367  
   368  	return plaintext, nil
   369  }
   370  
   371  // Returns the error message code, message should be strictly of the
   372  // format: ".... err: {"code" : <return_this>, ...}, ..."
   373  func GetErrorMessageCode(errorMsg string) (string, error) {
   374  	// find index of "err"
   375  	targetWord := `err:`
   376  	idx := strings.Index(errorMsg, targetWord)
   377  	if idx == -1 {
   378  		return "", thrown.New("invalid_params", "message doesn't contain `err` field")
   379  
   380  	}
   381  	var a = make(map[string]string)
   382  	if idx+5 >= len(errorMsg) {
   383  		return "", thrown.New("invalid_format", "err field is not proper json")
   384  	}
   385  	err := json.Unmarshal([]byte(errorMsg[idx+5:]), &a)
   386  	if err != nil {
   387  		return "", thrown.New("invalid_format", "err field is not proper json")
   388  	}
   389  	return a["code"], nil
   390  
   391  }
   392  
   393  // Returns transpose of 2-D slice
   394  // Example: Given matrix [[a, b], [c, d], [e, f]] returns [[a, c, e], [b, d, f]]
   395  func Transpose(matrix [][]allocationchange.AllocationChange) [][]allocationchange.AllocationChange {
   396  	rowLength := len(matrix)
   397  	if rowLength == 0 {
   398  		return matrix
   399  	}
   400  	columnLength := len(matrix[0])
   401  	transposedMatrix := make([][]allocationchange.AllocationChange, columnLength)
   402  	for i := range transposedMatrix {
   403  		transposedMatrix[i] = make([]allocationchange.AllocationChange, rowLength)
   404  	}
   405  	for i := 0; i < columnLength; i++ {
   406  		for j := 0; j < rowLength; j++ {
   407  			transposedMatrix[i][j] = matrix[j][i]
   408  		}
   409  	}
   410  	return transposedMatrix
   411  }