github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/service/admin/policy.go (about) 1 package admin 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" 9 "net/http" 10 "net/url" 11 "os" 12 "path/filepath" 13 "strconv" 14 "strings" 15 "time" 16 17 model "github.com/cloudreve/Cloudreve/v3/models" 18 "github.com/cloudreve/Cloudreve/v3/pkg/auth" 19 "github.com/cloudreve/Cloudreve/v3/pkg/cache" 20 "github.com/cloudreve/Cloudreve/v3/pkg/conf" 21 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos" 22 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" 23 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss" 24 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" 25 "github.com/cloudreve/Cloudreve/v3/pkg/request" 26 "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 27 "github.com/cloudreve/Cloudreve/v3/pkg/util" 28 "github.com/gin-gonic/gin" 29 cossdk "github.com/tencentyun/cos-go-sdk-v5" 30 ) 31 32 // PathTestService 本地路径测试服务 33 type PathTestService struct { 34 Path string `json:"path" binding:"required"` 35 } 36 37 // SlaveTestService 从机测试服务 38 type SlaveTestService struct { 39 Secret string `json:"secret" binding:"required"` 40 Server string `json:"server" binding:"required"` 41 } 42 43 // SlavePingService 从机相应ping 44 type SlavePingService struct { 45 Callback string `json:"callback" binding:"required"` 46 } 47 48 // AddPolicyService 存储策略添加服务 49 type AddPolicyService struct { 50 Policy model.Policy `json:"policy" binding:"required"` 51 } 52 53 // PolicyService 存储策略ID服务 54 type PolicyService struct { 55 ID uint `uri:"id" json:"id" binding:"required"` 56 Region string `json:"region"` 57 } 58 59 // Delete 删除存储策略 60 func (service *PolicyService) Delete() serializer.Response { 61 // 禁止删除默认策略 62 if service.ID == 1 { 63 return serializer.Err(serializer.CodeDeleteDefaultPolicy, "", nil) 64 } 65 66 policy, err := model.GetPolicyByID(service.ID) 67 if err != nil { 68 return serializer.Err(serializer.CodePolicyNotExist, "", err) 69 } 70 71 // 检查是否有文件使用 72 total := 0 73 row := model.DB.Model(&model.File{}).Where("policy_id = ?", service.ID). 74 Select("count(id)").Row() 75 row.Scan(&total) 76 if total > 0 { 77 return serializer.Err(serializer.CodePolicyUsedByFiles, strconv.Itoa(total), nil) 78 } 79 80 // 检查用户组使用 81 var groups []model.Group 82 model.DB.Model(&model.Group{}).Where( 83 "policies like ?", 84 fmt.Sprintf("%%[%d]%%", service.ID), 85 ).Find(&groups) 86 87 if len(groups) > 0 { 88 return serializer.Err(serializer.CodePolicyUsedByGroups, strconv.Itoa(len(groups)), nil) 89 } 90 91 model.DB.Delete(&policy) 92 policy.ClearCache() 93 94 return serializer.Response{} 95 } 96 97 // Get 获取存储策略详情 98 func (service *PolicyService) Get() serializer.Response { 99 policy, err := model.GetPolicyByID(service.ID) 100 if err != nil { 101 return serializer.Err(serializer.CodePolicyNotExist, "", err) 102 } 103 104 return serializer.Response{Data: policy} 105 } 106 107 // GetOAuth 获取 OneDrive OAuth 地址 108 func (service *PolicyService) GetOAuth(c *gin.Context, policyType string) serializer.Response { 109 policy, err := model.GetPolicyByID(service.ID) 110 if err != nil || policy.Type != policyType { 111 return serializer.Err(serializer.CodePolicyNotExist, "", nil) 112 } 113 114 util.SetSession(c, map[string]interface{}{ 115 policyType + "_oauth_policy": policy.ID, 116 }) 117 118 var redirect string 119 switch policy.Type { 120 case "onedrive": 121 client, err := onedrive.NewClient(&policy) 122 if err != nil { 123 return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err) 124 } 125 126 redirect = client.OAuthURL(context.Background(), []string{ 127 "offline_access", 128 "files.readwrite.all", 129 }) 130 case "googledrive": 131 client, err := googledrive.NewClient(&policy) 132 if err != nil { 133 return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err) 134 } 135 136 redirect = client.OAuthURL(context.Background(), googledrive.RequiredScope) 137 } 138 139 // Delete token cache 140 cache.Deletes([]string{policy.BucketName}, policyType+"_") 141 142 return serializer.Response{Data: redirect} 143 } 144 145 // AddSCF 创建回调云函数 146 func (service *PolicyService) AddSCF() serializer.Response { 147 policy, err := model.GetPolicyByID(service.ID) 148 if err != nil { 149 return serializer.Err(serializer.CodePolicyNotExist, "", nil) 150 } 151 152 if err := cos.CreateSCF(&policy, service.Region); err != nil { 153 return serializer.ParamErr("Failed to create SCF function", err) 154 } 155 156 return serializer.Response{} 157 } 158 159 // AddCORS 创建跨域策略 160 func (service *PolicyService) AddCORS() serializer.Response { 161 policy, err := model.GetPolicyByID(service.ID) 162 if err != nil { 163 return serializer.Err(serializer.CodePolicyNotExist, "", nil) 164 } 165 166 switch policy.Type { 167 case "oss": 168 handler, err := oss.NewDriver(&policy) 169 if err != nil { 170 return serializer.Err(serializer.CodeAddCORS, "", err) 171 } 172 if err := handler.CORS(); err != nil { 173 return serializer.Err(serializer.CodeAddCORS, "", err) 174 } 175 case "cos": 176 u, _ := url.Parse(policy.Server) 177 b := &cossdk.BaseURL{BucketURL: u} 178 handler := cos.Driver{ 179 Policy: &policy, 180 HTTPClient: request.NewClient(), 181 Client: cossdk.NewClient(b, &http.Client{ 182 Transport: &cossdk.AuthorizationTransport{ 183 SecretID: policy.AccessKey, 184 SecretKey: policy.SecretKey, 185 }, 186 }), 187 } 188 189 if err := handler.CORS(); err != nil { 190 return serializer.Err(serializer.CodeAddCORS, "", err) 191 } 192 case "s3": 193 handler, err := s3.NewDriver(&policy) 194 if err != nil { 195 return serializer.Err(serializer.CodeAddCORS, "", err) 196 } 197 198 if err := handler.CORS(); err != nil { 199 return serializer.Err(serializer.CodeAddCORS, "", err) 200 } 201 default: 202 return serializer.Err(serializer.CodePolicyNotAllowed, "", nil) 203 } 204 205 return serializer.Response{} 206 } 207 208 // Test 从机响应ping 209 func (service *SlavePingService) Test() serializer.Response { 210 master, err := url.Parse(service.Callback) 211 if err != nil { 212 return serializer.ParamErr("Failed to parse Master site url: "+err.Error(), nil) 213 } 214 215 controller, _ := url.Parse("/api/v3/site/ping") 216 217 r := request.NewClient() 218 res, err := r.Request( 219 "GET", 220 master.ResolveReference(controller).String(), 221 nil, 222 request.WithTimeout(time.Duration(10)*time.Second), 223 ).DecodeResponse() 224 225 if err != nil { 226 return serializer.Err(serializer.CodeSlavePingMaster, err.Error(), nil) 227 } 228 229 version := conf.BackendVersion 230 if conf.IsPro == "true" { 231 version += "-pro" 232 } 233 if res.Data.(string) != version { 234 return serializer.Err(serializer.CodeVersionMismatch, "Master: "+res.Data.(string)+", Slave: "+version, nil) 235 } 236 237 return serializer.Response{} 238 } 239 240 // Test 测试从机通信 241 func (service *SlaveTestService) Test() serializer.Response { 242 slave, err := url.Parse(service.Server) 243 if err != nil { 244 return serializer.ParamErr("Failed to parse slave node server URL: "+err.Error(), nil) 245 } 246 247 controller, _ := url.Parse("/api/v3/slave/ping") 248 249 // 请求正文 250 body := map[string]string{ 251 "callback": model.GetSiteURL().String(), 252 } 253 bodyByte, _ := json.Marshal(body) 254 255 r := request.NewClient() 256 res, err := r.Request( 257 "POST", 258 slave.ResolveReference(controller).String(), 259 bytes.NewReader(bodyByte), 260 request.WithTimeout(time.Duration(10)*time.Second), 261 request.WithCredential( 262 auth.HMACAuth{SecretKey: []byte(service.Secret)}, 263 int64(model.GetIntSetting("slave_api_timeout", 60)), 264 ), 265 ).DecodeResponse() 266 if err != nil { 267 return serializer.ParamErr("Failed to connect to slave node: "+err.Error(), nil) 268 } 269 270 if res.Code != 0 { 271 return serializer.ParamErr("Successfully connected to slave node, but slave returns: "+res.Msg, nil) 272 } 273 274 return serializer.Response{} 275 } 276 277 // Add 添加存储策略 278 func (service *AddPolicyService) Add() serializer.Response { 279 if service.Policy.Type != "local" && service.Policy.Type != "remote" { 280 service.Policy.DirNameRule = strings.TrimPrefix(service.Policy.DirNameRule, "/") 281 } 282 283 if service.Policy.ID > 0 { 284 if err := model.DB.Save(&service.Policy).Error; err != nil { 285 return serializer.DBErr("Failed to save policy", err) 286 } 287 } else { 288 if err := model.DB.Create(&service.Policy).Error; err != nil { 289 return serializer.DBErr("Failed to create policy", err) 290 } 291 } 292 293 service.Policy.ClearCache() 294 295 return serializer.Response{Data: service.Policy.ID} 296 } 297 298 // Test 测试本地路径 299 func (service *PathTestService) Test() serializer.Response { 300 policy := model.Policy{DirNameRule: service.Path} 301 path := policy.GeneratePath(1, "/My File") 302 path = filepath.Join(path, "test.txt") 303 file, err := util.CreatNestedFile(util.RelativePath(path)) 304 if err != nil { 305 return serializer.ParamErr(fmt.Sprintf("Failed to create \"%s\": %s", path, err.Error()), nil) 306 } 307 308 file.Close() 309 os.Remove(path) 310 311 return serializer.Response{} 312 } 313 314 // Policies 列出存储策略 315 func (service *AdminListService) Policies() serializer.Response { 316 var res []model.Policy 317 total := 0 318 319 tx := model.DB.Model(&model.Policy{}) 320 if service.OrderBy != "" { 321 tx = tx.Order(service.OrderBy) 322 } 323 324 for k, v := range service.Conditions { 325 tx = tx.Where(k+" = ?", v) 326 } 327 328 // 计算总数用于分页 329 tx.Count(&total) 330 331 // 查询记录 332 tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) 333 334 // 统计每个策略的文件使用 335 statics := make(map[uint][2]int, len(res)) 336 policyIds := make([]uint, 0, len(res)) 337 for i := 0; i < len(res); i++ { 338 policyIds = append(policyIds, res[i].ID) 339 } 340 341 rows, _ := model.DB.Model(&model.File{}).Where("policy_id in (?)", policyIds). 342 Select("policy_id,count(id),sum(size)").Group("policy_id").Rows() 343 344 for rows.Next() { 345 policyId := uint(0) 346 total := [2]int{} 347 rows.Scan(&policyId, &total[0], &total[1]) 348 349 statics[policyId] = total 350 } 351 352 return serializer.Response{Data: map[string]interface{}{ 353 "total": total, 354 "items": res, 355 "statics": statics, 356 }} 357 }