github.com/aitjcize/Overlord@v0.0.0-20240314041920-104a804cf5e8/overlord/auth.go (about)

     1  // Copyright 2015 The Chromium OS Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  package overlord
     6  
import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/bcrypt"
)
    22  
    23  const (
    24  	maxFailCount  = 10
    25  	blockDuration = 30 * time.Minute
    26  )
    27  
    28  func getRequestIP(r *http.Request) string {
    29  	if ips, ok := r.Header["X-Forwarded-For"]; ok {
    30  		return ips[len(ips)-1]
    31  	}
    32  	idx := strings.LastIndex(r.RemoteAddr, ":")
    33  	return r.RemoteAddr[:idx]
    34  }
    35  
// basicAuthHTTPHandlerDecorator wraps a handler with the checks of a
// BasicAuth. WrapHandler sets handler and WrapHandlerFunc sets handlerFunc;
// ServeHTTP calls handler when it is non-nil and handlerFunc otherwise.
type basicAuthHTTPHandlerDecorator struct {
	auth        *BasicAuth
	handler     http.Handler
	handlerFunc http.HandlerFunc
}
    41  
    42  // ServeHTTP implements the http.Handler interface.
    43  func (d *basicAuthHTTPHandlerDecorator) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    44  	if d.auth.IsBlocked(r) {
    45  		http.Error(w, fmt.Sprintf("%s: %s", http.StatusText(http.StatusUnauthorized),
    46  			"too many retries"), http.StatusUnauthorized)
    47  		return
    48  	}
    49  
    50  	username, password, ok := r.BasicAuth()
    51  	if !ok {
    52  		d.auth.Unauthorized(w, r, "authorization failed", false)
    53  		return
    54  	}
    55  
    56  	pass, err := d.auth.Authenticate(username, password)
    57  	if !pass {
    58  		d.auth.Unauthorized(w, r, err.Error(), true)
    59  		return
    60  	}
    61  	d.auth.ResetFailCount(r)
    62  
    63  	if d.handler != nil {
    64  		d.handler.ServeHTTP(w, r)
    65  	} else {
    66  		d.handlerFunc(w, r)
    67  	}
    68  }
    69  
// BasicAuth provides WrapHandler and WrapHandlerFunc, which turn an
// http.Handler into an HTTP basic-auth enabled handler.
//
// Credentials are loaded from an htpasswd file by NewBasicAuth. Failed logins
// are counted per client IP (see Unauthorized); an IP that reaches
// maxFailCount failures is rejected by IsBlocked until blockDuration has
// passed.
type BasicAuth struct {
	Realm   string            // realm sent in the WWW-Authenticate challenge
	secrets map[string]string // username -> bcrypt hash, filled by NewBasicAuth
	Disable bool              // Disable basic auth function, pass through

	blockedIps  map[string]time.Time // client IP -> time the IP was blocked
	failedCount map[string]int       // client IP -> recorded failed logins
	mutex       sync.RWMutex         // guards blockedIps and failedCount
}
    81  
    82  // NewBasicAuth creates a BasicAuth object
    83  func NewBasicAuth(realm, htpasswd string, disable bool) *BasicAuth {
    84  	secrets := make(map[string]string)
    85  
    86  	auth := &BasicAuth{
    87  		Realm:       realm,
    88  		secrets:     secrets,
    89  		Disable:     disable,
    90  		blockedIps:  make(map[string]time.Time),
    91  		failedCount: make(map[string]int),
    92  	}
    93  
    94  	f, err := os.Open(htpasswd)
    95  	if err != nil {
    96  		log.Printf("Warning: %s", err.Error())
    97  		auth.Disable = true
    98  		return auth
    99  	}
   100  
   101  	b := bufio.NewReader(f)
   102  	for {
   103  		line, _, err := b.ReadLine()
   104  		if err == io.EOF {
   105  			break
   106  		}
   107  		if line[0] == '#' {
   108  			continue
   109  		}
   110  		parts := strings.Split(string(line), ":")
   111  		if len(parts) != 2 {
   112  			continue
   113  		}
   114  		matched, err := regexp.Match("^\\$2[ay]\\$.*$", []byte(parts[1]))
   115  		if err != nil {
   116  			panic(err)
   117  		}
   118  		if !matched {
   119  			log.Printf("BasicAuth: user %s: password encryption scheme not supported, ignored.\n", parts[0])
   120  			continue
   121  		}
   122  		secrets[parts[0]] = parts[1]
   123  	}
   124  
   125  	return auth
   126  }
   127  
   128  // WrapHandler wraps an http.Hanlder and provide HTTP basic-auth.
   129  func (auth *BasicAuth) WrapHandler(h http.Handler) http.Handler {
   130  	if auth.Disable {
   131  		return h
   132  	}
   133  	return &basicAuthHTTPHandlerDecorator{auth, h, nil}
   134  }
   135  
   136  // WrapHandlerFunc wraps an http.HanlderFunc and provide HTTP basic-auth.
   137  func (auth *BasicAuth) WrapHandlerFunc(h http.HandlerFunc) http.Handler {
   138  	if auth.Disable {
   139  		return h
   140  	}
   141  	return &basicAuthHTTPHandlerDecorator{auth, nil, h}
   142  }
   143  
   144  // Authenticate authenticate an user with the provided user and passwd.
   145  func (auth *BasicAuth) Authenticate(user, passwd string) (bool, error) {
   146  	deniedError := errors.New("permission denied")
   147  
   148  	passwdHash, ok := auth.secrets[user]
   149  	if !ok {
   150  		return false, deniedError
   151  	}
   152  
   153  	if bcrypt.CompareHashAndPassword([]byte(passwdHash), []byte(passwd)) != nil {
   154  		return false, deniedError
   155  	}
   156  
   157  	return true, nil
   158  }
   159  
   160  // IsBlocked returns true if the given IP is blocked.
   161  func (auth *BasicAuth) IsBlocked(r *http.Request) bool {
   162  	ip := getRequestIP(r)
   163  
   164  	auth.mutex.RLock()
   165  	t, ok := auth.blockedIps[ip]
   166  	auth.mutex.RUnlock()
   167  	if !ok {
   168  		return false
   169  	}
   170  
   171  	if time.Now().Sub(t) < blockDuration {
   172  		log.Printf("BasicAuth: IP %s attempted to login, blocked\n", ip)
   173  		return true
   174  	}
   175  
   176  	// Unblock the user because of timeout
   177  	auth.mutex.Lock()
   178  	defer auth.mutex.Unlock()
   179  
   180  	delete(auth.failedCount, ip)
   181  	delete(auth.blockedIps, ip)
   182  
   183  	return false
   184  }
   185  
   186  // ResetFailCount resets the fail count for the given IP.
   187  func (auth *BasicAuth) ResetFailCount(r *http.Request) {
   188  	auth.mutex.Lock()
   189  	defer auth.mutex.Unlock()
   190  
   191  	ip := getRequestIP(r)
   192  	delete(auth.failedCount, ip)
   193  }
   194  
   195  // Unauthorized returns a 401 Unauthorized response.
   196  func (auth *BasicAuth) Unauthorized(w http.ResponseWriter, r *http.Request,
   197  	msg string, record bool) {
   198  
   199  	auth.mutex.Lock()
   200  	defer auth.mutex.Unlock()
   201  
   202  	// Record failure
   203  	if record {
   204  		ip := getRequestIP(r)
   205  		if _, ok := auth.failedCount[ip]; !ok {
   206  			auth.failedCount[ip] = 0
   207  		}
   208  		if ip != "127.0.0.1" {
   209  			// Only count for non-trusted IP.
   210  			auth.failedCount[ip]++
   211  		}
   212  
   213  		log.Printf("BasicAuth: IP %s failed to login, count: %d\n", ip,
   214  			auth.failedCount[ip])
   215  
   216  		if auth.failedCount[ip] >= maxFailCount {
   217  			auth.blockedIps[ip] = time.Now()
   218  			log.Printf("BasicAuth: IP %s (%s) is blocked\n", ip, r.UserAgent())
   219  		}
   220  	}
   221  
   222  	w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%s", auth.Realm))
   223  	http.Error(w, fmt.Sprintf("%s: %s", http.StatusText(http.StatusUnauthorized),
   224  		msg), http.StatusUnauthorized)
   225  }