github.com/cnotch/ipchub@v1.1.0/provider/auth/manager.go (about)

     1  // Copyright (c) 2019,CAOHONGJU All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package auth
     6  
     7  import (
     8  	"strings"
     9  	"sync"
    10  
    11  	"github.com/cnotch/xlog"
    12  )
    13  
    14  var globalM = &manager{
    15  	m: make(map[string]*User),
    16  }
    17  
    18  func init() {
    19  	// 默认为内存提供者,避免没有初始化全局函数调用问题
    20  	globalM.Reset(&memProvider{})
    21  }
    22  
    23  // Reset 重置用户提供者
    24  func Reset(provider UserProvider) {
    25  	globalM.Reset(provider)
    26  }
    27  
    28  // All 获取所有的用户
    29  func All() []*User {
    30  	return globalM.All()
    31  }
    32  
    33  // Get 获取取指定名称的用户
    34  func Get(userName string) *User {
    35  	return globalM.Get(userName)
    36  }
    37  
    38  // Del 删除指定名称的用户
    39  func Del(userName string) error {
    40  	return globalM.Del(userName)
    41  }
    42  
    43  // Save 保存用户
    44  func Save(src *User, updatePassword bool) error {
    45  	return globalM.Save(src, updatePassword)
    46  }
    47  
    48  // Flush 刷新用户
    49  func Flush() error {
    50  	return globalM.Flush()
    51  }
    52  
    53  type manager struct {
    54  	lock sync.RWMutex
    55  	m    map[string]*User // 用户map
    56  	l    []*User          // 用户list
    57  
    58  	saves   []*User // 自上次Flush后新的保存和删除的用户
    59  	removes []*User
    60  
    61  	provider UserProvider
    62  }
    63  
    64  func (m *manager) Reset(provider UserProvider) {
    65  	m.lock.Lock()
    66  	defer m.lock.Unlock()
    67  
    68  	m.m = make(map[string]*User)
    69  	m.l = m.l[:0]
    70  	m.saves = m.saves[:0]
    71  	m.removes = m.removes[:0]
    72  	m.provider = provider
    73  
    74  	users, err := provider.LoadAll()
    75  	if err != nil {
    76  		panic("Load user fail")
    77  	}
    78  
    79  	if cap(m.l) < len(users) {
    80  		m.l = make([]*User, 0, len(users))
    81  	}
    82  
    83  	// 加入缓存
    84  	for _, u := range users {
    85  		if err := u.init(); err != nil {
    86  			xlog.Warnf("user table init failed: `%v`", err)
    87  			continue // 忽略错误的配置
    88  		}
    89  		m.m[u.Name] = u
    90  		m.l = append(m.l, u)
    91  	}
    92  }
    93  
    94  func (m *manager) Get(userName string) *User {
    95  	m.lock.RLock()
    96  	defer m.lock.RUnlock()
    97  
    98  	userName = strings.ToLower(userName)
    99  	u, ok := m.m[userName]
   100  	if ok {
   101  		return u
   102  	}
   103  	return nil
   104  }
   105  
   106  func (m *manager) Del(userName string) error {
   107  	m.lock.Lock()
   108  	defer m.lock.Unlock()
   109  
   110  	userName = strings.ToLower(userName)
   111  	u, ok := m.m[userName]
   112  
   113  	if ok {
   114  		delete(m.m, userName)
   115  
   116  		// 从完整列表中删除
   117  		for i, u2 := range m.l {
   118  			if u.Name == u2.Name {
   119  				m.l = append(m.l[:i], m.l[i+1:]...)
   120  				break
   121  			}
   122  		}
   123  
   124  		// 从保存列表中删除
   125  		for i, u2 := range m.saves {
   126  			if u.Name == u2.Name {
   127  				m.saves = append(m.saves[:i], m.saves[i+1:]...)
   128  				break
   129  			}
   130  		}
   131  
   132  		m.removes = append(m.removes, u)
   133  	}
   134  	return nil
   135  }
   136  
   137  func (m *manager) Save(newu *User, updatePassword bool) error {
   138  	m.lock.Lock()
   139  	defer m.lock.Unlock()
   140  
   141  	err := newu.init()
   142  	if err != nil {
   143  		return err
   144  	}
   145  
   146  	u, ok := m.m[newu.Name]
   147  
   148  	if ok { // 更新
   149  		u.CopyFrom(newu, updatePassword)
   150  
   151  		save := true
   152  		// 如果保存列表存在,不新增
   153  		for _, u2 := range m.saves {
   154  			if u.Name == u2.Name {
   155  				save = false
   156  				break
   157  			}
   158  		}
   159  
   160  		if save {
   161  			m.saves = append(m.saves, u)
   162  		}
   163  	} else { // 新增
   164  		u = newu
   165  		m.m[u.Name] = u
   166  
   167  		m.l = append(m.l, u)
   168  		m.saves = append(m.saves, u)
   169  
   170  		for i, u2 := range m.removes {
   171  			if u.Name == u2.Name {
   172  				m.removes = append(m.removes[:i], m.removes[i+1:]...)
   173  				break
   174  			}
   175  		}
   176  	}
   177  	return nil
   178  }
   179  
   180  func (m *manager) Flush() error {
   181  	m.lock.Lock()
   182  	defer m.lock.Unlock()
   183  
   184  	if len(m.saves)+len(m.removes) == 0 {
   185  		return nil
   186  	}
   187  
   188  	err := m.provider.Flush(m.l, m.saves, m.removes)
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	m.saves = m.saves[:0]
   194  	m.removes = m.removes[:0]
   195  	return nil
   196  }
   197  
   198  func (m *manager) All() []*User {
   199  	m.lock.RLock()
   200  	defer m.lock.RUnlock()
   201  
   202  	users := make([]*User, len(m.l))
   203  	copy(users, m.l)
   204  	return users
   205  }