github.com/gramework/gramework@v1.8.1-0.20231027140105-82555c9057f5/grypto/providers/scrypt/scrypt_provider.go (about)

     1  package scrypt
     2  
     3  import (
     4  	"crypto/subtle"
     5  	"fmt"
     6  
     7  	"github.com/gramework/gramework"
     8  	"github.com/gramework/gramework/grypto/internal/mcf"
     9  	"github.com/gramework/gramework/grypto/salt"
    10  	"golang.org/x/crypto/scrypt"
    11  )
    12  
// log is the package-level logger; every entry it emits carries a "package"
// field identifying this package.
var log = gramework.Logger.WithField("package", "gramework/grypto/providers/scrypt")
    14  
const (
	// DefaultN, DefaultR and DefaultP are the scrypt cost parameters that New
	// configures; they are handed to scrypt.Key as N, r and p.
	// DefaultN is 2^16.
	DefaultN = 1 << 16
	DefaultR = 10
	DefaultP = 2

	// DefaultKeyLen is the key length New configures (handed to scrypt.Key),
	// DefaultSaltLen the salt length New configures (handed to salt.Generate).
	DefaultKeyLen  = 128
	DefaultSaltLen = 64

	// prefix is the leading marker prefixValid looks for on a hash, and
	// prefixLen is its length in bytes.
	prefix    = "$scrypt$"
	prefixLen = len(prefix)

	// paramsFmt is the layout used both to render parameters (String) and to
	// parse them back (paramsFromMCF): K is the key length, N, R and P the
	// scrypt cost parameters, L the salt length.
	paramsFmt = "K=%d,N=%d,R=%d,P=%d,L=%d"

	// maxInt is the largest value an int can hold on the current platform.
	maxInt = int(^uint(0) >> 1)
)
    30  
var (
	// DefaultProvider is a package-level Provider created by New, i.e. one
	// configured with the Default* parameters.
	DefaultProvider = New()

	// providerName is the algorithm identifier passed to mcf.Encode and
	// mcf.Decode.
	providerName = []byte("scrypt")
)
    36  
// Provider handles internal state and algorithm parameters.
//
// initialized records that setDefaultIfNeeded has already run for this
// provider; params holds the scrypt settings used by Hash and NeedsRehash.
type Provider struct {
	initialized bool
	params      *scryptParams
}
    42  
    43  type scryptParams struct {
    44  	keyLen  int
    45  	n, r, p int
    46  	saltLen int
    47  }
    48  
    49  // Equals returns true if params p equals to params p1
    50  func (p *scryptParams) Equals(p1 *scryptParams) bool {
    51  	return p.keyLen == p1.keyLen &&
    52  		p.n == p1.n &&
    53  		p.r == p1.r &&
    54  		p.p == p1.p &&
    55  		p.saltLen == p1.saltLen
    56  }
    57  
    58  // String returns params p as a MCF
    59  func (p *scryptParams) String() string {
    60  	return fmt.Sprintf(paramsFmt, p.keyLen, p.n, p.r, p.p, p.saltLen)
    61  }
    62  
    63  // New returns new scrypt provider
    64  func New() *Provider {
    65  	return &Provider{
    66  		params: &scryptParams{
    67  			keyLen:  DefaultKeyLen,
    68  			n:       DefaultN,
    69  			r:       DefaultR,
    70  			p:       DefaultP,
    71  			saltLen: DefaultSaltLen,
    72  		},
    73  	}
    74  }
    75  
    76  func (p *Provider) setDefaultIfNeeded() {
    77  	if !p.initialized {
    78  		p.initialized = true
    79  		if p.params == nil {
    80  			p.params = &scryptParams{
    81  				n: DefaultN,
    82  				r: DefaultR,
    83  				p: DefaultP,
    84  			}
    85  			return
    86  		}
    87  
    88  		if p.params.n <= 1 || p.params.n&(p.params.n-1) != 0 {
    89  			log.Warn("N must be > 1 and a power of 2, resetting to defaults")
    90  		}
    91  		if uint64(p.params.r)*uint64(p.params.p) >= 1<<30 || p.params.r > maxInt/128/p.params.p || p.params.r > maxInt/256 || p.params.n > maxInt/128/p.params.r {
    92  			log.Warn("parameters are too large, resettings to defaults")
    93  		}
    94  		return
    95  	}
    96  }
    97  
    98  // Hash returns scrypt hash of plaintext
    99  func (p *Provider) Hash(plaintext []byte) []byte {
   100  	p.setDefaultIfNeeded()
   101  	saltBytes := salt.Generate(p.params.saltLen)
   102  	key, _ := scrypt.Key(plaintext, saltBytes, p.params.n, p.params.r, p.params.p, p.params.keyLen)
   103  	return mcf.Encode(providerName, p.params.String(), saltBytes, key)
   104  }
   105  
   106  // HashString returns scrypt hash of plaintext
   107  func (p *Provider) HashString(plaintext string) []byte {
   108  	return p.Hash([]byte(plaintext))
   109  }
   110  
   111  // NeedsRehash checks if provided hash needs rehash
   112  func (p *Provider) NeedsRehash(hash []byte) bool {
   113  	if !prefixValid(hash) {
   114  		return true
   115  	}
   116  
   117  	mcfP := paramsFromMCF(hash[prefixLen:])
   118  	return !p.params.Equals(mcfP)
   119  }
   120  
   121  // Valid checks if provided plaintext is valid for given hash
   122  func (p *Provider) Valid(hash, plain []byte) bool {
   123  	if !prefixValid(hash) {
   124  		return false
   125  	}
   126  
   127  	_, params, saltBytes, expectedKey, err := mcf.Decode(hash, providerName)
   128  	if err != nil {
   129  		return false
   130  	}
   131  	hashParams := paramsFromMCF([]byte(params))
   132  
   133  	key, _ := scrypt.Key(plain, saltBytes, hashParams.n, hashParams.r, hashParams.p, hashParams.keyLen)
   134  
   135  	return subtle.ConstantTimeCompare(key, expectedKey) == 1
   136  }
   137  
   138  func prefixValid(hash []byte) bool {
   139  	return len(hash) > prefixLen || string(hash[:prefixLen]) == prefix
   140  }
   141  
   142  func paramsFromMCF(unprefixedHash []byte) *scryptParams {
   143  	// paramsFmt = "K=%d,N=%d,R=%d,P=%d,L=%d"
   144  	p := &scryptParams{}
   145  	fmt.Sscanf(string(unprefixedHash), paramsFmt, &p.keyLen, &p.n, &p.r, &p.p, &p.saltLen)
   146  	return p
   147  }