github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/cmd/auth/authtest/authtest.go (about) 1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // authtest is a diagnostic tool for implementations of the GOAUTH protocol 6 // described in https://golang.org/issue/26232. 7 // 8 // It accepts a single URL as an argument, and executes the GOAUTH protocol to 9 // fetch and display the headers for that URL. 10 // 11 // CAUTION: authtest logs the GOAUTH responses, which may include user 12 // credentials, to stderr. Do not post its output unless you are certain that 13 // all of the credentials involved are fake! 14 package main 15 16 import ( 17 "bufio" 18 "bytes" 19 "flag" 20 "fmt" 21 exec "golang.org/x/sys/execabs" 22 "io" 23 "log" 24 "net/http" 25 "net/textproto" 26 "net/url" 27 "os" 28 "path/filepath" 29 "strings" 30 ) 31 32 var v = flag.Bool("v", false, "if true, log GOAUTH responses to stderr") 33 34 func main() { 35 log.SetFlags(log.LstdFlags | log.Lshortfile) 36 flag.Parse() 37 args := flag.Args() 38 if len(args) != 1 { 39 log.Fatalf("usage: [GOAUTH=CMD...] %s URL", filepath.Base(os.Args[0])) 40 } 41 42 resp := try(args[0], nil) 43 if resp.StatusCode == http.StatusOK { 44 return 45 } 46 47 resp = try(args[0], resp) 48 if resp.StatusCode != http.StatusOK { 49 os.Exit(1) 50 } 51 } 52 53 func try(url string, prev *http.Response) *http.Response { 54 req := new(http.Request) 55 if prev != nil { 56 *req = *prev.Request 57 } else { 58 var err error 59 req, err = http.NewRequest("HEAD", os.Args[1], nil) 60 if err != nil { 61 log.Fatal(err) 62 } 63 } 64 65 goauth: 66 for _, argList := range strings.Split(os.Getenv("GOAUTH"), ";") { 67 // TODO(golang.org/issue/26849): If we escape quoted strings in GOFLAGS, use 68 // the same quoting here. 69 args := strings.Split(argList, " ") 70 if len(args) == 0 || args[0] == "" { 71 log.Fatalf("invalid or empty command in GOAUTH") 72 } 73 74 creds, err := getCreds(args, prev) 75 if err != nil { 76 log.Fatal(err) 77 } 78 for _, c := range creds { 79 if c.Apply(req) { 80 fmt.Fprintf(os.Stderr, "# request to %s\n", req.URL) 81 fmt.Fprintf(os.Stderr, "%s %s %s\n", req.Method, req.URL, req.Proto) 82 req.Header.Write(os.Stderr) 83 fmt.Fprintln(os.Stderr) 84 break goauth 85 } 86 } 87 } 88 89 resp, err := http.DefaultClient.Do(req) 90 if err != nil { 91 log.Fatal(err) 92 } 93 defer resp.Body.Close() 94 95 if resp.StatusCode != http.StatusOK && resp.StatusCode < 400 || resp.StatusCode > 500 { 96 log.Fatalf("unexpected status: %v", resp.Status) 97 } 98 99 fmt.Fprintf(os.Stderr, "# response from %s\n", resp.Request.URL) 100 formatHead(os.Stderr, resp) 101 return resp 102 } 103 104 func formatHead(out io.Writer, resp *http.Response) { 105 fmt.Fprintf(out, "%s %s\n", resp.Proto, resp.Status) 106 if err := resp.Header.Write(out); err != nil { 107 log.Fatal(err) 108 } 109 fmt.Fprintln(out) 110 } 111 112 type Cred struct { 113 URLPrefixes []*url.URL 114 Header http.Header 115 } 116 117 func (c Cred) Apply(req *http.Request) bool { 118 if req.URL == nil { 119 return false 120 } 121 ok := false 122 for _, prefix := range c.URLPrefixes { 123 if prefix.Host == req.URL.Host && 124 (req.URL.Path == prefix.Path || 125 (strings.HasPrefix(req.URL.Path, prefix.Path) && 126 (strings.HasSuffix(prefix.Path, "/") || 127 req.URL.Path[len(prefix.Path)] == '/'))) { 128 ok = true 129 break 130 } 131 } 132 if !ok { 133 return false 134 } 135 136 for k, vs := range c.Header { 137 req.Header.Del(k) 138 for _, v := range vs { 139 req.Header.Add(k, v) 140 } 141 } 142 return true 143 } 144 145 func (c Cred) String() string { 146 var buf strings.Builder 147 for _, u := range c.URLPrefixes { 148 fmt.Fprintln(&buf, u) 149 } 150 buf.WriteString("\n") 151 c.Header.Write(&buf) 152 buf.WriteString("\n") 153 return buf.String() 154 } 155 156 func getCreds(args []string, resp *http.Response) ([]Cred, error) { 157 cmd := exec.Command(args[0], args[1:]...) 158 cmd.Stderr = os.Stderr 159 160 if resp != nil { 161 u := *resp.Request.URL 162 u.RawQuery = "" 163 cmd.Args = append(cmd.Args, u.String()) 164 } 165 166 var head strings.Builder 167 if resp != nil { 168 formatHead(&head, resp) 169 } 170 cmd.Stdin = strings.NewReader(head.String()) 171 172 fmt.Fprintf(os.Stderr, "# %s\n", strings.Join(cmd.Args, " ")) 173 out, err := cmd.Output() 174 if err != nil { 175 return nil, fmt.Errorf("%s: %v", strings.Join(cmd.Args, " "), err) 176 } 177 os.Stderr.Write(out) 178 os.Stderr.WriteString("\n") 179 180 var creds []Cred 181 r := textproto.NewReader(bufio.NewReader(bytes.NewReader(out))) 182 line := 0 183 readLoop: 184 for { 185 var prefixes []*url.URL 186 for { 187 prefix, err := r.ReadLine() 188 if err == io.EOF { 189 if len(prefixes) > 0 { 190 return nil, fmt.Errorf("line %d: %v", line, io.ErrUnexpectedEOF) 191 } 192 break readLoop 193 } 194 line++ 195 196 if prefix == "" { 197 if len(prefixes) == 0 { 198 return nil, fmt.Errorf("line %d: unexpected newline", line) 199 } 200 break 201 } 202 u, err := url.Parse(prefix) 203 if err != nil { 204 return nil, fmt.Errorf("line %d: malformed URL: %v", line, err) 205 } 206 if u.Scheme != "https" { 207 return nil, fmt.Errorf("line %d: non-HTTPS URL %q", line, prefix) 208 } 209 if len(u.RawQuery) > 0 { 210 return nil, fmt.Errorf("line %d: unexpected query string in URL %q", line, prefix) 211 } 212 if len(u.Fragment) > 0 { 213 return nil, fmt.Errorf("line %d: unexpected fragment in URL %q", line, prefix) 214 } 215 prefixes = append(prefixes, u) 216 } 217 218 header, err := r.ReadMIMEHeader() 219 if err != nil { 220 return nil, fmt.Errorf("headers at line %d: %v", line, err) 221 } 222 if len(header) > 0 { 223 creds = append(creds, Cred{ 224 URLPrefixes: prefixes, 225 Header: http.Header(header), 226 }) 227 } 228 } 229 230 return creds, nil 231 }