github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/auth/auth.go (about)

     1  package auth
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/url"
     9  	"sort"
    10  	"strings"
    11  	"time"
    12  
    13  	model "github.com/cloudreve/Cloudreve/v3/models"
    14  	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
    15  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    16  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    17  )
    18  
    19  var (
    20  	ErrAuthFailed        = serializer.NewError(serializer.CodeInvalidSign, "invalid sign", nil)
    21  	ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
    22  	ErrExpiresMissing    = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
    23  	ErrExpired           = serializer.NewError(serializer.CodeSignExpired, "signature expired", nil)
    24  )
    25  
    26  const CrHeaderPrefix = "X-Cr-"
    27  
    28  // General 通用的认证接口
    29  var General Auth
    30  
    31  // Auth 鉴权认证
    32  type Auth interface {
    33  	// 对给定Body进行签名,expires为0表示永不过期
    34  	Sign(body string, expires int64) string
    35  	// 对给定Body和Sign进行检查
    36  	Check(body string, sign string) error
    37  }
    38  
    39  // SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
    40  // 请求正文、`X-Cr-`开头的header进行签名
    41  func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
    42  	// 处理有效期
    43  	if expires > 0 {
    44  		expires += time.Now().Unix()
    45  	}
    46  
    47  	// 生成签名
    48  	sign := instance.Sign(getSignContent(r), expires)
    49  
    50  	// 将签名加到请求Header中
    51  	r.Header["Authorization"] = []string{"Bearer " + sign}
    52  	return r
    53  }
    54  
    55  // CheckRequest 对复杂请求进行签名验证
    56  func CheckRequest(instance Auth, r *http.Request) error {
    57  	var (
    58  		sign []string
    59  		ok   bool
    60  	)
    61  	if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
    62  		return ErrAuthHeaderMissing
    63  	}
    64  	sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
    65  
    66  	return instance.Check(getSignContent(r), sign[0])
    67  }
    68  
    69  // getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果请求 path 为从机上传 API,
    70  // 则不对正文签名。返回待签名/验证的字符串
    71  func getSignContent(r *http.Request) (rawSignString string) {
    72  	// 读取所有body正文
    73  	var body = []byte{}
    74  	if !strings.Contains(r.URL.Path, "/api/v3/slave/upload/") {
    75  		if r.Body != nil {
    76  			body, _ = ioutil.ReadAll(r.Body)
    77  			_ = r.Body.Close()
    78  			r.Body = ioutil.NopCloser(bytes.NewReader(body))
    79  		}
    80  	}
    81  
    82  	// 决定要签名的header
    83  	var signedHeader []string
    84  	for k, _ := range r.Header {
    85  		if strings.HasPrefix(k, CrHeaderPrefix) && k != CrHeaderPrefix+"Filename" {
    86  			signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
    87  		}
    88  	}
    89  	sort.Strings(signedHeader)
    90  
    91  	// 读取所有待签名Header
    92  	rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
    93  
    94  	return rawSignString
    95  }
    96  
    97  // SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证
    98  func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) {
    99  	// 处理有效期
   100  	if expires != 0 {
   101  		expires += time.Now().Unix()
   102  	}
   103  
   104  	base, err := url.Parse(uri)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	// 生成签名
   110  	sign := instance.Sign(base.Path, expires)
   111  
   112  	// 将签名加到URI中
   113  	queries := base.Query()
   114  	queries.Set("sign", sign)
   115  	base.RawQuery = queries.Encode()
   116  
   117  	return base, nil
   118  }
   119  
   120  // CheckURI 对URI进行鉴权
   121  func CheckURI(instance Auth, url *url.URL) error {
   122  	//获取待验证的签名正文
   123  	queries := url.Query()
   124  	sign := queries.Get("sign")
   125  	queries.Del("sign")
   126  	url.RawQuery = queries.Encode()
   127  
   128  	return instance.Check(url.Path, sign)
   129  }
   130  
   131  // Init 初始化通用鉴权器
   132  func Init() {
   133  	var secretKey string
   134  	if conf.SystemConfig.Mode == "master" {
   135  		secretKey = model.GetSettingByName("secret_key")
   136  	} else {
   137  		secretKey = conf.SlaveConfig.Secret
   138  		if secretKey == "" {
   139  			util.Log().Panic("SlaveSecret is not set, please specify it in config file.")
   140  		}
   141  	}
   142  	General = HMACAuth{
   143  		SecretKey: []byte(secretKey),
   144  	}
   145  }