github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/auth/hmac_test.go (about)

     1  package auth
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/DATA-DOG/go-sqlmock"
    10  	model "github.com/cloudreve/Cloudreve/v3/models"
    11  	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
    12  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    13  	"github.com/gin-gonic/gin"
    14  	"github.com/jinzhu/gorm"
    15  	"github.com/stretchr/testify/assert"
    16  )
    17  
    18  var mock sqlmock.Sqlmock
    19  
    20  func TestMain(m *testing.M) {
    21  	// 设置gin为测试模式
    22  	gin.SetMode(gin.TestMode)
    23  
    24  	// 初始化sqlmock
    25  	var db *sql.DB
    26  	var err error
    27  	db, mock, err = sqlmock.New()
    28  	if err != nil {
    29  		panic("An error was not expected when opening a stub database connection")
    30  	}
    31  
    32  	mockDB, _ := gorm.Open("mysql", db)
    33  	model.DB = mockDB
    34  	defer db.Close()
    35  
    36  	m.Run()
    37  }
    38  
    39  func TestHMACAuth_Sign(t *testing.T) {
    40  	asserts := assert.New(t)
    41  	auth := HMACAuth{
    42  		SecretKey: []byte(util.RandStringRunes(256)),
    43  	}
    44  
    45  	asserts.NotEmpty(auth.Sign("content", 0))
    46  }
    47  
    48  func TestHMACAuth_Check(t *testing.T) {
    49  	asserts := assert.New(t)
    50  	auth := HMACAuth{
    51  		SecretKey: []byte(util.RandStringRunes(256)),
    52  	}
    53  
    54  	// 正常,永不过期
    55  	{
    56  		sign := auth.Sign("content", 0)
    57  		asserts.NoError(auth.Check("content", sign))
    58  	}
    59  
    60  	// 过期
    61  	{
    62  		sign := auth.Sign("content", 1)
    63  		asserts.Error(auth.Check("content", sign))
    64  	}
    65  
    66  	// 签名格式错误
    67  	{
    68  		sign := auth.Sign("content", 1)
    69  		asserts.Error(auth.Check("content", sign+":"))
    70  	}
    71  
    72  	// 过期日期格式错误
    73  	{
    74  		asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed"))
    75  	}
    76  
    77  	// 签名有误
    78  	{
    79  		asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10)))
    80  	}
    81  }
    82  
    83  func TestInit(t *testing.T) {
    84  	asserts := assert.New(t)
    85  	mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
    86  	Init()
    87  	asserts.NoError(mock.ExpectationsWereMet())
    88  
    89  	// slave模式
    90  	conf.SystemConfig.Mode = "slave"
    91  	asserts.Panics(func() {
    92  		Init()
    93  	})
    94  }