github.com/pusher/oauth2_proxy@v3.2.0+incompatible/validator.go (about)

     1  package main
     2  
     3  import (
     4  	"encoding/csv"
     5  	"fmt"
     6  	"log"
     7  	"os"
     8  	"strings"
     9  	"sync/atomic"
    10  	"unsafe"
    11  )
    12  
// UserMap holds information from the authenticated emails file.
//
// The allowed addresses live in a map[string]bool that is never mutated after
// it has been published: each load of the file builds a fresh map and swaps
// it in with atomic.StorePointer, and lookups read it back with
// atomic.LoadPointer.
type UserMap struct {
	usersFile string         // path of the authenticated emails file; "" means no file is used
	m         unsafe.Pointer // *map[string]bool of allowed addresses; access only via sync/atomic
}
    18  
    19  // NewUserMap parses the authenticated emails file into a new UserMap
    20  func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
    21  	um := &UserMap{usersFile: usersFile}
    22  	m := make(map[string]bool)
    23  	atomic.StorePointer(&um.m, unsafe.Pointer(&m))
    24  	if usersFile != "" {
    25  		log.Printf("using authenticated emails file %s", usersFile)
    26  		WatchForUpdates(usersFile, done, func() {
    27  			um.LoadAuthenticatedEmailsFile()
    28  			onUpdate()
    29  		})
    30  		um.LoadAuthenticatedEmailsFile()
    31  	}
    32  	return um
    33  }
    34  
    35  // IsValid checks if an email is allowed
    36  func (um *UserMap) IsValid(email string) (result bool) {
    37  	m := *(*map[string]bool)(atomic.LoadPointer(&um.m))
    38  	_, result = m[email]
    39  	return
    40  }
    41  
    42  // LoadAuthenticatedEmailsFile loads the authenticated emails file from disk
    43  // and parses the contents as CSV
    44  func (um *UserMap) LoadAuthenticatedEmailsFile() {
    45  	r, err := os.Open(um.usersFile)
    46  	if err != nil {
    47  		log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err)
    48  	}
    49  	defer r.Close()
    50  	csvReader := csv.NewReader(r)
    51  	csvReader.Comma = ','
    52  	csvReader.Comment = '#'
    53  	csvReader.TrimLeadingSpace = true
    54  	records, err := csvReader.ReadAll()
    55  	if err != nil {
    56  		log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err)
    57  		return
    58  	}
    59  	updated := make(map[string]bool)
    60  	for _, r := range records {
    61  		address := strings.ToLower(strings.TrimSpace(r[0]))
    62  		updated[address] = true
    63  	}
    64  	atomic.StorePointer(&um.m, unsafe.Pointer(&updated))
    65  }
    66  
    67  func newValidatorImpl(domains []string, usersFile string,
    68  	done <-chan bool, onUpdate func()) func(string) bool {
    69  	validUsers := NewUserMap(usersFile, done, onUpdate)
    70  
    71  	var allowAll bool
    72  	for i, domain := range domains {
    73  		if domain == "*" {
    74  			allowAll = true
    75  			continue
    76  		}
    77  		domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
    78  	}
    79  
    80  	validator := func(email string) (valid bool) {
    81  		if email == "" {
    82  			return
    83  		}
    84  		email = strings.ToLower(email)
    85  		for _, domain := range domains {
    86  			valid = valid || strings.HasSuffix(email, domain)
    87  		}
    88  		if !valid {
    89  			valid = validUsers.IsValid(email)
    90  		}
    91  		if allowAll {
    92  			valid = true
    93  		}
    94  		return valid
    95  	}
    96  	return validator
    97  }
    98  
    99  // NewValidator constructs a function to validate email addresses
   100  func NewValidator(domains []string, usersFile string) func(string) bool {
   101  	return newValidatorImpl(domains, usersFile, nil, func() {})
   102  }