github.com/pusher/oauth2_proxy@v3.2.0+incompatible/validator.go (about) 1 package main 2 3 import ( 4 "encoding/csv" 5 "fmt" 6 "log" 7 "os" 8 "strings" 9 "sync/atomic" 10 "unsafe" 11 ) 12 13 // UserMap holds information from the authenticated emails file 14 type UserMap struct { 15 usersFile string 16 m unsafe.Pointer 17 } 18 19 // NewUserMap parses the authenticated emails file into a new UserMap 20 func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { 21 um := &UserMap{usersFile: usersFile} 22 m := make(map[string]bool) 23 atomic.StorePointer(&um.m, unsafe.Pointer(&m)) 24 if usersFile != "" { 25 log.Printf("using authenticated emails file %s", usersFile) 26 WatchForUpdates(usersFile, done, func() { 27 um.LoadAuthenticatedEmailsFile() 28 onUpdate() 29 }) 30 um.LoadAuthenticatedEmailsFile() 31 } 32 return um 33 } 34 35 // IsValid checks if an email is allowed 36 func (um *UserMap) IsValid(email string) (result bool) { 37 m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) 38 _, result = m[email] 39 return 40 } 41 42 // LoadAuthenticatedEmailsFile loads the authenticated emails file from disk 43 // and parses the contents as CSV 44 func (um *UserMap) LoadAuthenticatedEmailsFile() { 45 r, err := os.Open(um.usersFile) 46 if err != nil { 47 log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) 48 } 49 defer r.Close() 50 csvReader := csv.NewReader(r) 51 csvReader.Comma = ',' 52 csvReader.Comment = '#' 53 csvReader.TrimLeadingSpace = true 54 records, err := csvReader.ReadAll() 55 if err != nil { 56 log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) 57 return 58 } 59 updated := make(map[string]bool) 60 for _, r := range records { 61 address := strings.ToLower(strings.TrimSpace(r[0])) 62 updated[address] = true 63 } 64 atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) 65 } 66 67 func newValidatorImpl(domains []string, usersFile string, 68 done <-chan bool, onUpdate func()) func(string) bool { 69 validUsers := NewUserMap(usersFile, done, onUpdate) 70 71 var allowAll bool 72 for i, domain := range domains { 73 if domain == "*" { 74 allowAll = true 75 continue 76 } 77 domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) 78 } 79 80 validator := func(email string) (valid bool) { 81 if email == "" { 82 return 83 } 84 email = strings.ToLower(email) 85 for _, domain := range domains { 86 valid = valid || strings.HasSuffix(email, domain) 87 } 88 if !valid { 89 valid = validUsers.IsValid(email) 90 } 91 if allowAll { 92 valid = true 93 } 94 return valid 95 } 96 return validator 97 } 98 99 // NewValidator constructs a function to validate email addresses 100 func NewValidator(domains []string, usersFile string) func(string) bool { 101 return newValidatorImpl(domains, usersFile, nil, func() {}) 102 }