github.com/SupenBysz/gf-admin-community@v0.7.4/internal/logic/sys_casbin/sys_adapter.go (about)

     1  // Copyright 2017 The sys_casbin Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sys_casbin
    16  
    17  import (
    18  	"fmt"
    19  	"github.com/SupenBysz/gf-admin-community/sys_model/sys_dao"
    20  	"runtime"
    21  
    22  	"github.com/casbin/casbin/v2/model"
    23  	"github.com/casbin/casbin/v2/persist"
    24  	"github.com/gogf/gf/v2/database/gdb"
    25  	"github.com/gogf/gf/v2/os/gctx"
    26  )
    27  
    28  type CasbinRule struct {
    29  	PType string `json:"ptype"`
    30  	V0    string `json:"v0"`
    31  	V1    string `json:"v1"`
    32  	V2    string `json:"v2"`
    33  	V3    string `json:"v3"`
    34  	V4    string `json:"v4"`
    35  	V5    string `json:"v5"`
    36  }
    37  
    38  // Adapter represents the gdb adapter for policy storage.
    39  type Adapter struct {
    40  	DriverName     string
    41  	DataSourceName string
    42  	TableName      string
    43  	Db             gdb.DB
    44  }
    45  
    46  // finalizer is the destructor for Adapter.
    47  func finalizer(a *Adapter) {
    48  	// 注意不用的时候不需要使用Close方法关闭数据库连接(并且gdb也没有提供Close方法),
    49  	// 数据库引擎底层采用了链接池设计,当链接不再使用时会自动关闭
    50  	a.Db = nil
    51  }
    52  
    53  // NewAdapter is the constructor for Adapter.
    54  func NewAdapter(driverName string, dataSourceName string) (*Adapter, error) {
    55  	a := &Adapter{}
    56  	a.DriverName = driverName
    57  	a.DataSourceName = dataSourceName
    58  	a.TableName = "casbin_rule"
    59  
    60  	// Open the DB, create it if not existed.
    61  	err := a.open()
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	// Call the destructor when the object is released.
    67  	runtime.SetFinalizer(a, finalizer)
    68  
    69  	return a, nil
    70  }
    71  
    72  // NewAdapterFromOptions is the constructor for Adapter with existed connection
    73  func NewAdapterFromOptions(adapter *Adapter) (*Adapter, error) {
    74  
    75  	if adapter.TableName == "" {
    76  		adapter.TableName = "casbin_rule"
    77  	}
    78  	if adapter.Db == nil {
    79  		err := adapter.open()
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  
    84  		runtime.SetFinalizer(adapter, finalizer)
    85  	}
    86  	return adapter, nil
    87  }
    88  
    89  func (a *Adapter) open() error {
    90  	a.Db = sys_dao.SysCasbin.DB()
    91  
    92  	return a.createTable()
    93  }
    94  
    95  func (a *Adapter) close() error {
    96  	// 注意不用的时候不需要使用Close方法关闭数据库连接(并且gdb也没有提供Close方法),
    97  	// 数据库引擎底层采用了链接池设计,当链接不再使用时会自动关闭
    98  	a.Db = nil
    99  	return nil
   100  }
   101  
   102  func (a *Adapter) createTable() error {
   103  	_, err := a.Db.Exec(gctx.New(), fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (ptype VARCHAR(10), v0 VARCHAR(256), v1 VARCHAR(256), v2 VARCHAR(256), v3 VARCHAR(256), v4 VARCHAR(256), v5 VARCHAR(256))", a.TableName))
   104  	return err
   105  }
   106  
   107  func (a *Adapter) dropTable() error {
   108  	_, err := a.Db.Exec(gctx.New(), fmt.Sprintf("DROP TABLE %s", a.TableName))
   109  	return err
   110  }
   111  
   112  func loadPolicyLine(line CasbinRule, model model.Model) {
   113  	lineText := line.PType
   114  	if line.V0 != "" {
   115  		lineText += ", " + line.V0
   116  	}
   117  	if line.V1 != "" {
   118  		lineText += ", " + line.V1
   119  	}
   120  	if line.V2 != "" {
   121  		lineText += ", " + line.V2
   122  	}
   123  	if line.V3 != "" {
   124  		lineText += ", " + line.V3
   125  	}
   126  	if line.V4 != "" {
   127  		lineText += ", " + line.V4
   128  	}
   129  	if line.V5 != "" {
   130  		lineText += ", " + line.V5
   131  	}
   132  
   133  	persist.LoadPolicyLine(lineText, model)
   134  }
   135  
   136  // LoadPolicy loads policy from database.
   137  func (a *Adapter) LoadPolicy(model model.Model) error {
   138  	var lines []CasbinRule
   139  
   140  	if err := a.Db.Model(a.TableName).Scan(&lines); err != nil {
   141  		return err
   142  	}
   143  
   144  	for _, line := range lines {
   145  		loadPolicyLine(line, model)
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  func savePolicyLine(ptype string, rule []string) CasbinRule {
   152  	line := CasbinRule{}
   153  
   154  	line.PType = ptype
   155  	if len(rule) > 0 {
   156  		line.V0 = rule[0]
   157  	}
   158  	if len(rule) > 1 {
   159  		line.V1 = rule[1]
   160  	}
   161  	if len(rule) > 2 {
   162  		line.V2 = rule[2]
   163  	}
   164  	if len(rule) > 3 {
   165  		line.V3 = rule[3]
   166  	}
   167  	if len(rule) > 4 {
   168  		line.V4 = rule[4]
   169  	}
   170  	if len(rule) > 5 {
   171  		line.V5 = rule[5]
   172  	}
   173  
   174  	return line
   175  }
   176  
   177  // SavePolicy saves policy to database.
   178  func (a *Adapter) SavePolicy(model model.Model) error {
   179  	err := a.dropTable()
   180  	if err != nil {
   181  		return err
   182  	}
   183  	err = a.createTable()
   184  	if err != nil {
   185  		return err
   186  	}
   187  
   188  	for ptype, ast := range model["p"] {
   189  		for _, rule := range ast.Policy {
   190  			line := savePolicyLine(ptype, rule)
   191  			_, err := a.Db.Model(a.TableName).Data(&line).Insert()
   192  			if err != nil {
   193  				return err
   194  			}
   195  		}
   196  	}
   197  
   198  	for ptype, ast := range model["g"] {
   199  		for _, rule := range ast.Policy {
   200  			line := savePolicyLine(ptype, rule)
   201  			_, err := a.Db.Model(a.TableName).Data(&line).Insert()
   202  			if err != nil {
   203  				return err
   204  			}
   205  		}
   206  	}
   207  
   208  	return nil
   209  }
   210  
   211  // AddPolicy adds a policy rule to the storage.
   212  func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
   213  	line := savePolicyLine(ptype, rule)
   214  	_, err := a.Db.Model(a.TableName).Data(&line).Insert()
   215  	return err
   216  }
   217  
   218  // RemovePolicy removes a policy rule from the storage.
   219  func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
   220  	line := savePolicyLine(ptype, rule)
   221  	err := rawDelete(a, line)
   222  	return err
   223  }
   224  
   225  // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
   226  func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
   227  	line := CasbinRule{}
   228  
   229  	line.PType = ptype
   230  	if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
   231  		line.V0 = fieldValues[0-fieldIndex]
   232  	}
   233  	if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
   234  		line.V1 = fieldValues[1-fieldIndex]
   235  	}
   236  	if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
   237  		line.V2 = fieldValues[2-fieldIndex]
   238  	}
   239  	if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
   240  		line.V3 = fieldValues[3-fieldIndex]
   241  	}
   242  	if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
   243  		line.V4 = fieldValues[4-fieldIndex]
   244  	}
   245  	if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
   246  		line.V5 = fieldValues[5-fieldIndex]
   247  	}
   248  	err := rawDelete(a, line)
   249  	return err
   250  }
   251  
   252  func rawDelete(a *Adapter, line CasbinRule) error {
   253  	db := a.Db.Model(a.TableName)
   254  
   255  	db.Where("ptype = ?", line.PType)
   256  	if line.V0 != "" {
   257  		db.Where("v0 = ?", line.V0)
   258  	}
   259  	if line.V1 != "" {
   260  		db.Where("v1 = ?", line.V1)
   261  	}
   262  	if line.V2 != "" {
   263  		db.Where("v2 = ?", line.V2)
   264  	}
   265  	if line.V3 != "" {
   266  		db.Where("v3 = ?", line.V3)
   267  	}
   268  	if line.V4 != "" {
   269  		db.Where("v4 = ?", line.V4)
   270  	}
   271  	if line.V5 != "" {
   272  		db.Where("v5 = ?", line.V5)
   273  	}
   274  
   275  	_, err := db.Delete()
   276  	return err
   277  }