github.com/e154/smart-home@v0.17.2-0.20240311175135-e530a6e5cd45/common/web/digest.go (about)

     1  // This file is part of the Smart Home
     2  // Program complex distribution https://github.com/e154/smart-home
     3  // Copyright (C) 2023, Filippov Alex
     4  //
     5  // This library is free software: you can redistribute it and/or
     6  // modify it under the terms of the GNU Lesser General Public
     7  // License as published by the Free Software Foundation; either
     8  // version 3 of the License, or (at your option) any later version.
     9  //
    10  // This library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    13  // Library General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public
    16  // License along with this library.  If not, see
    17  // <https://www.gnu.org/licenses/>.
    18  
    19  package web
    20  
    21  import (
    22  	"crypto/md5"
    23  	"crypto/rand"
    24  	"encoding/base64"
    25  	"fmt"
    26  	"io"
    27  	"net/http"
    28  	"net/url"
    29  	"strings"
    30  )
    31  
    32  type myjar struct {
    33  	jar map[string][]*http.Cookie
    34  }
    35  
    36  // DigestHeaders tracks the state of authentication
    37  type DigestHeaders struct {
    38  	Realm     string
    39  	Qop       string
    40  	Method    string
    41  	Nonce     string
    42  	Opaque    string
    43  	Algorithm string
    44  	HA1       string
    45  	HA2       string
    46  	Cnonce    string
    47  	Path      string
    48  	Nc        int16
    49  	Username  string
    50  	Password  string
    51  }
    52  
    53  func (p *myjar) SetCookies(u *url.URL, cookies []*http.Cookie) {
    54  	p.jar[u.Host] = cookies
    55  }
    56  
    57  func (p *myjar) Cookies(u *url.URL) []*http.Cookie {
    58  	return p.jar[u.Host]
    59  }
    60  
    61  func (d *DigestHeaders) digestChecksum() {
    62  	switch d.Algorithm {
    63  	case "MD5":
    64  		// A1
    65  		h := md5.New()
    66  		A1 := fmt.Sprintf("%s:%s:%s", d.Username, d.Realm, d.Password)
    67  		io.WriteString(h, A1)
    68  		d.HA1 = fmt.Sprintf("%x", h.Sum(nil))
    69  
    70  		// A2
    71  		h = md5.New()
    72  		A2 := fmt.Sprintf("%s:%s", d.Method, d.Path)
    73  		io.WriteString(h, A2)
    74  		d.HA2 = fmt.Sprintf("%x", h.Sum(nil))
    75  	case "MD5-sess":
    76  		// A1
    77  		h := md5.New()
    78  		A1 := fmt.Sprintf("%s:%s:%s", d.Username, d.Realm, d.Password)
    79  		io.WriteString(h, A1)
    80  		haPre := fmt.Sprintf("%x", h.Sum(nil))
    81  		h = md5.New()
    82  		A1 = fmt.Sprintf("%s:%s:%s", haPre, d.Nonce, d.Cnonce)
    83  		io.WriteString(h, A1)
    84  		d.HA1 = fmt.Sprintf("%x", h.Sum(nil))
    85  
    86  		// A2
    87  		h = md5.New()
    88  		A2 := fmt.Sprintf("%s:%s", d.Method, d.Path)
    89  		io.WriteString(h, A2)
    90  		d.HA2 = fmt.Sprintf("%x", h.Sum(nil))
    91  	default:
    92  		//token
    93  	}
    94  }
    95  
    96  // ApplyAuth adds proper auth header to the passed request
    97  func (d *DigestHeaders) ApplyAuth(req *http.Request) {
    98  	d.Nc += 0x1
    99  	d.Cnonce = randomKey()
   100  	d.Method = req.Method
   101  	d.Path = req.URL.RequestURI()
   102  	d.digestChecksum()
   103  	response := h(strings.Join([]string{d.HA1, d.Nonce, fmt.Sprintf("%08x", d.Nc),
   104  		d.Cnonce, d.Qop, d.HA2}, ":"))
   105  	AuthHeader := fmt.Sprintf(`Digest username="%s", realm="%s", nonce="%s", uri="%s", cnonce="%s", nc=%08x, qop=%s, response="%s", algorithm=%s`,
   106  		d.Username, d.Realm, d.Nonce, d.Path, d.Cnonce, d.Nc, d.Qop, response, d.Algorithm)
   107  	if d.Opaque != "" {
   108  		AuthHeader = fmt.Sprintf(`%s, opaque="%s"`, AuthHeader, d.Opaque)
   109  	}
   110  	req.Header.Set("Authorization", AuthHeader)
   111  }
   112  
   113  // Auth authenticates against a given URI
   114  func (d *DigestHeaders) Auth(username string, password string, uri string) (*DigestHeaders, error) {
   115  
   116  	client := &http.Client{}
   117  	jar := &myjar{}
   118  	jar.jar = make(map[string][]*http.Cookie)
   119  	client.Jar = jar
   120  
   121  	req, err := http.NewRequest("GET", uri, nil)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	resp, err := client.Do(req)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	if resp.StatusCode == 401 {
   131  
   132  		authn := digestAuthParams(resp)
   133  		algorithm := authn["algorithm"]
   134  		d := &DigestHeaders{}
   135  		u, _ := url.Parse(uri)
   136  		d.Path = u.RequestURI()
   137  		d.Realm = authn["realm"]
   138  		d.Qop = authn["qop"]
   139  		d.Nonce = authn["nonce"]
   140  		d.Opaque = authn["opaque"]
   141  		if algorithm == "" {
   142  			d.Algorithm = "MD5"
   143  		} else {
   144  			d.Algorithm = authn["algorithm"]
   145  		}
   146  		d.Nc = 0x0
   147  		d.Username = username
   148  		d.Password = password
   149  
   150  		req, _ = http.NewRequest("GET", uri, nil)
   151  		d.ApplyAuth(req)
   152  		resp, err = client.Do(req)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  		if resp.StatusCode != 200 {
   157  			d = &DigestHeaders{}
   158  			err = fmt.Errorf("response status code was %v", resp.StatusCode)
   159  		}
   160  		return d, err
   161  	}
   162  	return nil, fmt.Errorf("response status code should have been 401, it was %v", resp.StatusCode)
   163  }
   164  
   165  /*
   166  Parse Authorization header from the http.Request. Returns a map of
   167  auth parameters or nil if the header is not a valid parsable Digest
   168  auth header.
   169  */
   170  func digestAuthParams(r *http.Response) map[string]string {
   171  	s := strings.SplitN(r.Header.Get("Www-Authenticate"), " ", 2)
   172  	if len(s) != 2 || s[0] != "Digest" {
   173  		return nil
   174  	}
   175  
   176  	result := map[string]string{}
   177  	for _, kv := range strings.Split(s[1], ",") {
   178  		parts := strings.SplitN(kv, "=", 2)
   179  		if len(parts) != 2 {
   180  			continue
   181  		}
   182  		result[strings.Trim(parts[0], "\" ")] = strings.Trim(parts[1], "\" ")
   183  	}
   184  	return result
   185  }
   186  
   187  func randomKey() string {
   188  	k := make([]byte, 12)
   189  	for bytes := 0; bytes < len(k); {
   190  		n, err := rand.Read(k[bytes:])
   191  		if err != nil {
   192  			panic("rand.Read() failed")
   193  		}
   194  		bytes += n
   195  	}
   196  	return base64.StdEncoding.EncodeToString(k)
   197  }
   198  
   199  /*
   200  H function for MD5 algorithm (returns a lower-case hex MD5 digest)
   201  */
   202  func h(data string) string {
   203  	digest := md5.New()
   204  	digest.Write([]byte(data))
   205  	return fmt.Sprintf("%x", digest.Sum(nil))
   206  }