github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/service/admin/policy.go (about)

     1  package admin
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"path/filepath"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	model "github.com/cloudreve/Cloudreve/v3/models"
    18  	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
    19  	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
    20  	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
    21  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos"
    22  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
    23  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
    24  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
    25  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
    26  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    27  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    28  	"github.com/gin-gonic/gin"
    29  	cossdk "github.com/tencentyun/cos-go-sdk-v5"
    30  )
    31  
    32  // PathTestService 本地路径测试服务
    33  type PathTestService struct {
    34  	Path string `json:"path" binding:"required"`
    35  }
    36  
    37  // SlaveTestService 从机测试服务
    38  type SlaveTestService struct {
    39  	Secret string `json:"secret" binding:"required"`
    40  	Server string `json:"server" binding:"required"`
    41  }
    42  
    43  // SlavePingService 从机相应ping
    44  type SlavePingService struct {
    45  	Callback string `json:"callback" binding:"required"`
    46  }
    47  
    48  // AddPolicyService 存储策略添加服务
    49  type AddPolicyService struct {
    50  	Policy model.Policy `json:"policy" binding:"required"`
    51  }
    52  
    53  // PolicyService 存储策略ID服务
    54  type PolicyService struct {
    55  	ID     uint   `uri:"id" json:"id" binding:"required"`
    56  	Region string `json:"region"`
    57  }
    58  
    59  // Delete 删除存储策略
    60  func (service *PolicyService) Delete() serializer.Response {
    61  	// 禁止删除默认策略
    62  	if service.ID == 1 {
    63  		return serializer.Err(serializer.CodeDeleteDefaultPolicy, "", nil)
    64  	}
    65  
    66  	policy, err := model.GetPolicyByID(service.ID)
    67  	if err != nil {
    68  		return serializer.Err(serializer.CodePolicyNotExist, "", err)
    69  	}
    70  
    71  	// 检查是否有文件使用
    72  	total := 0
    73  	row := model.DB.Model(&model.File{}).Where("policy_id = ?", service.ID).
    74  		Select("count(id)").Row()
    75  	row.Scan(&total)
    76  	if total > 0 {
    77  		return serializer.Err(serializer.CodePolicyUsedByFiles, strconv.Itoa(total), nil)
    78  	}
    79  
    80  	// 检查用户组使用
    81  	var groups []model.Group
    82  	model.DB.Model(&model.Group{}).Where(
    83  		"policies like ?",
    84  		fmt.Sprintf("%%[%d]%%", service.ID),
    85  	).Find(&groups)
    86  
    87  	if len(groups) > 0 {
    88  		return serializer.Err(serializer.CodePolicyUsedByGroups, strconv.Itoa(len(groups)), nil)
    89  	}
    90  
    91  	model.DB.Delete(&policy)
    92  	policy.ClearCache()
    93  
    94  	return serializer.Response{}
    95  }
    96  
    97  // Get 获取存储策略详情
    98  func (service *PolicyService) Get() serializer.Response {
    99  	policy, err := model.GetPolicyByID(service.ID)
   100  	if err != nil {
   101  		return serializer.Err(serializer.CodePolicyNotExist, "", err)
   102  	}
   103  
   104  	return serializer.Response{Data: policy}
   105  }
   106  
   107  // GetOAuth 获取 OneDrive OAuth 地址
   108  func (service *PolicyService) GetOAuth(c *gin.Context, policyType string) serializer.Response {
   109  	policy, err := model.GetPolicyByID(service.ID)
   110  	if err != nil || policy.Type != policyType {
   111  		return serializer.Err(serializer.CodePolicyNotExist, "", nil)
   112  	}
   113  
   114  	util.SetSession(c, map[string]interface{}{
   115  		policyType + "_oauth_policy": policy.ID,
   116  	})
   117  
   118  	var redirect string
   119  	switch policy.Type {
   120  	case "onedrive":
   121  		client, err := onedrive.NewClient(&policy)
   122  		if err != nil {
   123  			return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err)
   124  		}
   125  
   126  		redirect = client.OAuthURL(context.Background(), []string{
   127  			"offline_access",
   128  			"files.readwrite.all",
   129  		})
   130  	case "googledrive":
   131  		client, err := googledrive.NewClient(&policy)
   132  		if err != nil {
   133  			return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err)
   134  		}
   135  
   136  		redirect = client.OAuthURL(context.Background(), googledrive.RequiredScope)
   137  	}
   138  
   139  	// Delete token cache
   140  	cache.Deletes([]string{policy.BucketName}, policyType+"_")
   141  
   142  	return serializer.Response{Data: redirect}
   143  }
   144  
   145  // AddSCF 创建回调云函数
   146  func (service *PolicyService) AddSCF() serializer.Response {
   147  	policy, err := model.GetPolicyByID(service.ID)
   148  	if err != nil {
   149  		return serializer.Err(serializer.CodePolicyNotExist, "", nil)
   150  	}
   151  
   152  	if err := cos.CreateSCF(&policy, service.Region); err != nil {
   153  		return serializer.ParamErr("Failed to create SCF function", err)
   154  	}
   155  
   156  	return serializer.Response{}
   157  }
   158  
   159  // AddCORS 创建跨域策略
   160  func (service *PolicyService) AddCORS() serializer.Response {
   161  	policy, err := model.GetPolicyByID(service.ID)
   162  	if err != nil {
   163  		return serializer.Err(serializer.CodePolicyNotExist, "", nil)
   164  	}
   165  
   166  	switch policy.Type {
   167  	case "oss":
   168  		handler, err := oss.NewDriver(&policy)
   169  		if err != nil {
   170  			return serializer.Err(serializer.CodeAddCORS, "", err)
   171  		}
   172  		if err := handler.CORS(); err != nil {
   173  			return serializer.Err(serializer.CodeAddCORS, "", err)
   174  		}
   175  	case "cos":
   176  		u, _ := url.Parse(policy.Server)
   177  		b := &cossdk.BaseURL{BucketURL: u}
   178  		handler := cos.Driver{
   179  			Policy:     &policy,
   180  			HTTPClient: request.NewClient(),
   181  			Client: cossdk.NewClient(b, &http.Client{
   182  				Transport: &cossdk.AuthorizationTransport{
   183  					SecretID:  policy.AccessKey,
   184  					SecretKey: policy.SecretKey,
   185  				},
   186  			}),
   187  		}
   188  
   189  		if err := handler.CORS(); err != nil {
   190  			return serializer.Err(serializer.CodeAddCORS, "", err)
   191  		}
   192  	case "s3":
   193  		handler, err := s3.NewDriver(&policy)
   194  		if err != nil {
   195  			return serializer.Err(serializer.CodeAddCORS, "", err)
   196  		}
   197  
   198  		if err := handler.CORS(); err != nil {
   199  			return serializer.Err(serializer.CodeAddCORS, "", err)
   200  		}
   201  	default:
   202  		return serializer.Err(serializer.CodePolicyNotAllowed, "", nil)
   203  	}
   204  
   205  	return serializer.Response{}
   206  }
   207  
   208  // Test 从机响应ping
   209  func (service *SlavePingService) Test() serializer.Response {
   210  	master, err := url.Parse(service.Callback)
   211  	if err != nil {
   212  		return serializer.ParamErr("Failed to parse Master site url: "+err.Error(), nil)
   213  	}
   214  
   215  	controller, _ := url.Parse("/api/v3/site/ping")
   216  
   217  	r := request.NewClient()
   218  	res, err := r.Request(
   219  		"GET",
   220  		master.ResolveReference(controller).String(),
   221  		nil,
   222  		request.WithTimeout(time.Duration(10)*time.Second),
   223  	).DecodeResponse()
   224  
   225  	if err != nil {
   226  		return serializer.Err(serializer.CodeSlavePingMaster, err.Error(), nil)
   227  	}
   228  
   229  	version := conf.BackendVersion
   230  	if conf.IsPro == "true" {
   231  		version += "-pro"
   232  	}
   233  	if res.Data.(string) != version {
   234  		return serializer.Err(serializer.CodeVersionMismatch, "Master: "+res.Data.(string)+", Slave: "+version, nil)
   235  	}
   236  
   237  	return serializer.Response{}
   238  }
   239  
   240  // Test 测试从机通信
   241  func (service *SlaveTestService) Test() serializer.Response {
   242  	slave, err := url.Parse(service.Server)
   243  	if err != nil {
   244  		return serializer.ParamErr("Failed to parse slave node server URL: "+err.Error(), nil)
   245  	}
   246  
   247  	controller, _ := url.Parse("/api/v3/slave/ping")
   248  
   249  	// 请求正文
   250  	body := map[string]string{
   251  		"callback": model.GetSiteURL().String(),
   252  	}
   253  	bodyByte, _ := json.Marshal(body)
   254  
   255  	r := request.NewClient()
   256  	res, err := r.Request(
   257  		"POST",
   258  		slave.ResolveReference(controller).String(),
   259  		bytes.NewReader(bodyByte),
   260  		request.WithTimeout(time.Duration(10)*time.Second),
   261  		request.WithCredential(
   262  			auth.HMACAuth{SecretKey: []byte(service.Secret)},
   263  			int64(model.GetIntSetting("slave_api_timeout", 60)),
   264  		),
   265  	).DecodeResponse()
   266  	if err != nil {
   267  		return serializer.ParamErr("Failed to connect to slave node: "+err.Error(), nil)
   268  	}
   269  
   270  	if res.Code != 0 {
   271  		return serializer.ParamErr("Successfully connected to slave node, but slave returns: "+res.Msg, nil)
   272  	}
   273  
   274  	return serializer.Response{}
   275  }
   276  
   277  // Add 添加存储策略
   278  func (service *AddPolicyService) Add() serializer.Response {
   279  	if service.Policy.Type != "local" && service.Policy.Type != "remote" {
   280  		service.Policy.DirNameRule = strings.TrimPrefix(service.Policy.DirNameRule, "/")
   281  	}
   282  
   283  	if service.Policy.ID > 0 {
   284  		if err := model.DB.Save(&service.Policy).Error; err != nil {
   285  			return serializer.DBErr("Failed to save policy", err)
   286  		}
   287  	} else {
   288  		if err := model.DB.Create(&service.Policy).Error; err != nil {
   289  			return serializer.DBErr("Failed to create policy", err)
   290  		}
   291  	}
   292  
   293  	service.Policy.ClearCache()
   294  
   295  	return serializer.Response{Data: service.Policy.ID}
   296  }
   297  
   298  // Test 测试本地路径
   299  func (service *PathTestService) Test() serializer.Response {
   300  	policy := model.Policy{DirNameRule: service.Path}
   301  	path := policy.GeneratePath(1, "/My File")
   302  	path = filepath.Join(path, "test.txt")
   303  	file, err := util.CreatNestedFile(util.RelativePath(path))
   304  	if err != nil {
   305  		return serializer.ParamErr(fmt.Sprintf("Failed to create \"%s\": %s", path, err.Error()), nil)
   306  	}
   307  
   308  	file.Close()
   309  	os.Remove(path)
   310  
   311  	return serializer.Response{}
   312  }
   313  
   314  // Policies 列出存储策略
   315  func (service *AdminListService) Policies() serializer.Response {
   316  	var res []model.Policy
   317  	total := 0
   318  
   319  	tx := model.DB.Model(&model.Policy{})
   320  	if service.OrderBy != "" {
   321  		tx = tx.Order(service.OrderBy)
   322  	}
   323  
   324  	for k, v := range service.Conditions {
   325  		tx = tx.Where(k+" = ?", v)
   326  	}
   327  
   328  	// 计算总数用于分页
   329  	tx.Count(&total)
   330  
   331  	// 查询记录
   332  	tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res)
   333  
   334  	// 统计每个策略的文件使用
   335  	statics := make(map[uint][2]int, len(res))
   336  	policyIds := make([]uint, 0, len(res))
   337  	for i := 0; i < len(res); i++ {
   338  		policyIds = append(policyIds, res[i].ID)
   339  	}
   340  
   341  	rows, _ := model.DB.Model(&model.File{}).Where("policy_id in (?)", policyIds).
   342  		Select("policy_id,count(id),sum(size)").Group("policy_id").Rows()
   343  
   344  	for rows.Next() {
   345  		policyId := uint(0)
   346  		total := [2]int{}
   347  		rows.Scan(&policyId, &total[0], &total[1])
   348  
   349  		statics[policyId] = total
   350  	}
   351  
   352  	return serializer.Response{Data: map[string]interface{}{
   353  		"total":   total,
   354  		"items":   res,
   355  		"statics": statics,
   356  	}}
   357  }