github.com/coyove/common@v0.0.0-20240403014525-f70e643f9de8/session/session.go (about)

     1  package session
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/sha1"
     8  	"encoding/binary"
     9  	"fmt"
    10  	"strconv"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/coyove/common/rand"
    15  )
    16  
    17  var repository struct {
    18  	sync.Mutex
    19  	rand      *rand.Rand
    20  	blk       cipher.Block
    21  	oldTokens map[[16]byte]bool
    22  }
    23  
    // TTL is how long a token stays valid, in seconds (86400 = 24 hours).
// Consume rejects a token whose embedded Unix timestamp is more than TTL
// seconds in the past.
const TTL = 86400
    27  
    28  func init() {
    29  	repository.rand = rand.New()
    30  	iv := repository.rand.Fetch(16)
    31  	repository.blk, _ = aes.NewCipher(iv)
    32  	repository.oldTokens = make(map[[16]byte]bool)
    33  }
    34  
    35  // New returns a token for the session
    36  func New(extra string) (tok [16]byte) {
    37  	ts := uint32(time.Now().Unix())
    38  	binary.LittleEndian.PutUint32(tok[:4], ts)
    39  	copy(tok[4:8], extra)
    40  	x := sha1.Sum(tok[:8])
    41  	copy(tok[8:], x[:])
    42  	repository.blk.Encrypt(tok[:], tok[:])
    43  
    44  	if repository.rand.Intn(1024) == 0 {
    45  		go func() {
    46  			repository.Lock()
    47  			now := uint32(time.Now().Unix())
    48  			for tok := range repository.oldTokens {
    49  				ts := binary.LittleEndian.Uint32(tok[:4])
    50  				if ts > now || now-ts > TTL {
    51  					delete(repository.oldTokens, tok)
    52  				}
    53  			}
    54  			repository.Unlock()
    55  		}()
    56  	}
    57  
    58  	return
    59  }
    60  
    61  // NewString returns a string token for the session
    62  func NewString(extra string) string {
    63  	return fmt.Sprintf("%x", New(extra))
    64  }
    65  
    66  // Consume validates the token and consumes it (if true)
    67  func Consume(tok [16]byte, extra string) bool {
    68  	repository.blk.Decrypt(tok[:], tok[:])
    69  	x := sha1.Sum(tok[:8])
    70  	if !bytes.Equal(x[:8], tok[8:]) {
    71  		return false
    72  	}
    73  
    74  	if string(tok[4:8]) != extra[:4] {
    75  		return false
    76  	}
    77  
    78  	now := uint32(time.Now().Unix())
    79  	ts := binary.LittleEndian.Uint32(tok[:4])
    80  
    81  	if now < ts {
    82  		return false
    83  	}
    84  	if now-ts > TTL {
    85  		return false
    86  	}
    87  
    88  	repository.Lock()
    89  	if repository.oldTokens[tok] {
    90  		repository.Unlock()
    91  		return false
    92  	}
    93  	repository.oldTokens[tok] = true
    94  	repository.Unlock()
    95  	return true
    96  }
    97  
    98  // ConsumeString validates the token and consumes it (if true)
    99  func ConsumeString(tok string, extra string) bool {
   100  	if len(tok) != 32 {
   101  		return false
   102  	}
   103  
   104  	var t [16]byte
   105  	for i := 0; i < 16; i++ {
   106  		n, err := strconv.ParseInt(tok[i*2:i*2+2], 16, 64)
   107  		if err != nil {
   108  			return false
   109  		}
   110  		t[i] = byte(n)
   111  	}
   112  
   113  	return Consume(t, extra)
   114  }