github.com/TeaOSLab/EdgeNode@v1.3.8/internal/iplibrary/ip_list_sqlite.go (about)

     1  // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  
     3  package iplibrary
     4  
     5  import (
     6  	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
     7  	"github.com/TeaOSLab/EdgeNode/internal/events"
     8  	"github.com/TeaOSLab/EdgeNode/internal/goman"
     9  	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
    10  	"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
    11  	"github.com/TeaOSLab/EdgeNode/internal/utils/idles"
    12  	"github.com/iwind/TeaGo/Tea"
    13  	"os"
    14  	"path/filepath"
    15  	"time"
    16  )
    17  
    18  type SQLiteIPList struct {
    19  	db *dbs.DB
    20  
    21  	itemTableName    string
    22  	versionTableName string
    23  
    24  	deleteExpiredItemsStmt   *dbs.Stmt
    25  	deleteItemStmt           *dbs.Stmt
    26  	insertItemStmt           *dbs.Stmt
    27  	selectItemsStmt          *dbs.Stmt
    28  	selectMaxItemVersionStmt *dbs.Stmt
    29  
    30  	selectVersionStmt *dbs.Stmt
    31  	updateVersionStmt *dbs.Stmt
    32  
    33  	cleanTicker *time.Ticker
    34  
    35  	dir string
    36  
    37  	isClosed bool
    38  }
    39  
    40  func NewSQLiteIPList() (*SQLiteIPList, error) {
    41  	var db = &SQLiteIPList{
    42  		itemTableName:    "ipItems",
    43  		versionTableName: "versions",
    44  		dir:              filepath.Clean(Tea.Root + "/data"),
    45  		cleanTicker:      time.NewTicker(24 * time.Hour),
    46  	}
    47  	err := db.init()
    48  	return db, err
    49  }
    50  
    51  func (this *SQLiteIPList) init() error {
    52  	// 检查目录是否存在
    53  	_, err := os.Stat(this.dir)
    54  	if err != nil {
    55  		err = os.MkdirAll(this.dir, 0777)
    56  		if err != nil {
    57  			return err
    58  		}
    59  		remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
    60  	}
    61  
    62  	var path = this.dir + "/ip_list.db"
    63  
    64  	db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
    65  	if err != nil {
    66  		return err
    67  	}
    68  	db.SetMaxOpenConns(1)
    69  
    70  	//_, err = db.Exec("VACUUM")
    71  	//if err != nil {
    72  	//	return err
    73  	//}
    74  
    75  	this.db = db
    76  
    77  	// 恢复数据库
    78  	var recoverEnv, _ = os.LookupEnv("EdgeRecover")
    79  	if len(recoverEnv) > 0 {
    80  		for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
    81  			_, _ = db.Exec(`REINDEX "` + indexName + `"`)
    82  		}
    83  	}
    84  
    85  	// 初始化数据库
    86  	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
    87    "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
    88    "listId" integer DEFAULT 0,
    89    "listType" varchar(32),
    90    "isGlobal" integer(1) DEFAULT 0,
    91    "type" varchar(16),
    92    "itemId" integer DEFAULT 0,
    93    "ipFrom" varchar(64) DEFAULT 0,
    94    "ipTo" varchar(64) DEFAULT 0,
    95    "expiredAt" integer DEFAULT 0,
    96    "eventLevel" varchar(32),
    97    "isDeleted" integer(1) DEFAULT 0,
    98    "version" integer DEFAULT 0,
    99    "nodeId" integer DEFAULT 0,
   100    "serverId" integer DEFAULT 0
   101  );
   102  
   103  CREATE INDEX IF NOT EXISTS "ip_list_itemId"
   104  ON "` + this.itemTableName + `" (
   105    "itemId" ASC
   106  );
   107  
   108  CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
   109  ON "` + this.itemTableName + `" (
   110    "expiredAt" ASC
   111  );
   112  `)
   113  	if err != nil {
   114  		return err
   115  	}
   116  
   117  	_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
   118    "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
   119    "version" integer DEFAULT 0
   120  );
   121  `)
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	// 初始化SQL语句
   127  	this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE  "expiredAt">0 AND "expiredAt"<?`)
   128  	if err != nil {
   129  		return err
   130  	}
   131  
   132  	this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
   133  	if err != nil {
   134  		return err
   135  	}
   136  
   137  	this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
   153  	if err != nil {
   154  		return err
   155  	}
   156  
   157  	this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
   158  	if err != nil {
   159  		return err
   160  	}
   161  
   162  	this.db = db
   163  
   164  	goman.New(func() {
   165  		events.OnClose(func() {
   166  			_ = this.Close()
   167  			this.cleanTicker.Stop()
   168  		})
   169  
   170  		idles.RunTicker(this.cleanTicker, func() {
   171  			deleteErr := this.DeleteExpiredItems()
   172  			if deleteErr != nil {
   173  				remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+deleteErr.Error())
   174  			}
   175  		})
   176  	})
   177  
   178  	return nil
   179  }
   180  
   181  // Name 数据库名称代号
   182  func (this *SQLiteIPList) Name() string {
   183  	return "sqlite"
   184  }
   185  
   186  // DeleteExpiredItems 删除过期的条目
   187  func (this *SQLiteIPList) DeleteExpiredItems() error {
   188  	if this.isClosed {
   189  		return nil
   190  	}
   191  
   192  	_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
   193  	return err
   194  }
   195  
   196  func (this *SQLiteIPList) AddItem(item *pb.IPItem) error {
   197  	if this.isClosed {
   198  		return nil
   199  	}
   200  
   201  	_, err := this.deleteItemStmt.Exec(item.Id)
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	// 如果是删除,则不再创建新记录
   207  	if item.IsDeleted {
   208  		return this.UpdateMaxVersion(item.Version)
   209  	}
   210  
   211  	_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	return this.UpdateMaxVersion(item.Version)
   217  }
   218  
   219  func (this *SQLiteIPList) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error) {
   220  	if this.isClosed {
   221  		return
   222  	}
   223  
   224  	rows, err := this.selectItemsStmt.Query(offset, size)
   225  	if err != nil {
   226  		return nil, false, err
   227  	}
   228  	defer func() {
   229  		_ = rows.Close()
   230  	}()
   231  
   232  	for rows.Next() {
   233  		//  "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
   234  		var pbItem = &pb.IPItem{}
   235  		err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
   236  		if err != nil {
   237  			return nil, false, err
   238  		}
   239  		items = append(items, pbItem)
   240  	}
   241  
   242  	goNext = int64(len(items)) == size
   243  	return
   244  }
   245  
   246  // ReadMaxVersion 读取当前最大版本号
   247  func (this *SQLiteIPList) ReadMaxVersion() (int64, error) {
   248  	if this.isClosed {
   249  		return 0, nil
   250  	}
   251  
   252  	// from version table
   253  	{
   254  		var row = this.selectVersionStmt.QueryRow()
   255  		if row == nil {
   256  			return 0, nil
   257  		}
   258  		var version int64
   259  		err := row.Scan(&version)
   260  		if err == nil {
   261  			return version, nil
   262  		}
   263  	}
   264  
   265  	// from items table
   266  	{
   267  		var row = this.selectMaxItemVersionStmt.QueryRow()
   268  		if row == nil {
   269  			return 0, nil
   270  		}
   271  		var version int64
   272  		err := row.Scan(&version)
   273  		if err != nil {
   274  			return 0, nil
   275  		}
   276  
   277  		return version, nil
   278  	}
   279  }
   280  
   281  // UpdateMaxVersion 修改版本号
   282  func (this *SQLiteIPList) UpdateMaxVersion(version int64) error {
   283  	if this.isClosed {
   284  		return nil
   285  	}
   286  
   287  	_, err := this.updateVersionStmt.Exec(version)
   288  	return err
   289  }
   290  
   291  func (this *SQLiteIPList) Close() error {
   292  	this.isClosed = true
   293  
   294  	if this.db != nil {
   295  		for _, stmt := range []*dbs.Stmt{
   296  			this.deleteExpiredItemsStmt,
   297  			this.deleteItemStmt,
   298  			this.insertItemStmt,
   299  			this.selectItemsStmt,
   300  			this.selectMaxItemVersionStmt, // ipItems table
   301  
   302  			this.selectVersionStmt, // versions table
   303  			this.updateVersionStmt,
   304  		} {
   305  			if stmt != nil {
   306  				_ = stmt.Close()
   307  			}
   308  		}
   309  
   310  		return this.db.Close()
   311  	}
   312  	return nil
   313  }