// Source: github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/ip_search.go

     1  package common
     2  
     3  import (
     4  	"database/sql"
     5  
     6  	qgen "github.com/Azareal/Gosora/query_gen"
     7  )
     8  
     9  var IPSearch IPSearcher
    10  
    11  type IPSearcher interface {
    12  	Lookup(ip string) (uids []int, e error)
    13  }
    14  
    15  type DefaultIPSearcher struct {
    16  	searchUsers        *sql.Stmt
    17  	searchTopics       *sql.Stmt
    18  	searchReplies      *sql.Stmt
    19  	searchUsersReplies *sql.Stmt
    20  }
    21  
    22  // NewDefaultIPSearcher gives you a new instance of DefaultIPSearcher
    23  func NewDefaultIPSearcher() (*DefaultIPSearcher, error) {
    24  	acc := qgen.NewAcc()
    25  	uu := "users"
    26  	q := func(tbl string) *sql.Stmt {
    27  		return acc.Select(uu).Columns("uid").InQ("uid", acc.Select(tbl).Columns("createdBy").Where("ip=?")).Prepare()
    28  	}
    29  	return &DefaultIPSearcher{
    30  		searchUsers:        acc.Select(uu).Columns("uid").Where("last_ip=? OR last_ip LIKE CONCAT('%-',?)").Prepare(),
    31  		searchTopics:       q("topics"),
    32  		searchReplies:      q("replies"),
    33  		searchUsersReplies: q("users_replies"),
    34  	}, acc.FirstError()
    35  }
    36  
    37  func (s *DefaultIPSearcher) Lookup(ip string) (uids []int, e error) {
    38  	var uid int
    39  	reqUserList := make(map[int]bool)
    40  	runQuery2 := func(rows *sql.Rows, e error) error {
    41  		if e != nil {
    42  			return e
    43  		}
    44  		defer rows.Close()
    45  
    46  		for rows.Next() {
    47  			if e := rows.Scan(&uid); e != nil {
    48  				return e
    49  			}
    50  			reqUserList[uid] = true
    51  		}
    52  		return rows.Err()
    53  	}
    54  	runQuery := func(stmt *sql.Stmt) error {
    55  		return runQuery2(stmt.Query(ip))
    56  	}
    57  
    58  	e = runQuery2(s.searchUsers.Query(ip, ip))
    59  	if e != nil {
    60  		return uids, e
    61  	}
    62  	e = runQuery(s.searchTopics)
    63  	if e != nil {
    64  		return uids, e
    65  	}
    66  	e = runQuery(s.searchReplies)
    67  	if e != nil {
    68  		return uids, e
    69  	}
    70  	e = runQuery(s.searchUsersReplies)
    71  	if e != nil {
    72  		return uids, e
    73  	}
    74  
    75  	// Convert the user ID map to a slice, then bulk load the users
    76  	uids = make([]int, len(reqUserList))
    77  	var i int
    78  	for userID := range reqUserList {
    79  		uids[i] = userID
    80  		i++
    81  	}
    82  
    83  	return uids, nil
    84  }