go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/devshell/server.go (about) 1 // Copyright 2017 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package devshell implements Devshell protocol for locally getting auth token. 16 // 17 // Some Google Cloud tools know how to use it for authentication. 18 package devshell 19 20 import ( 21 "bytes" 22 "context" 23 "encoding/json" 24 "fmt" 25 "net" 26 "strconv" 27 "strings" 28 "sync" 29 30 "golang.org/x/oauth2" 31 32 "go.chromium.org/luci/common/clock" 33 "go.chromium.org/luci/common/logging" 34 "go.chromium.org/luci/common/runtime/paniccatcher" 35 36 "go.chromium.org/luci/auth/integration/internal/localsrv" 37 ) 38 39 // EnvKey is the name of the environment variable which contains the Devshell 40 // server port number which is picked up by Devshell clients. 41 const EnvKey = "DEVSHELL_CLIENT_PORT" 42 43 // Server runs a Devshell server. 44 type Server struct { 45 // Source is used to obtain OAuth2 tokens. 46 Source oauth2.TokenSource 47 // Email is the email associated with the token. 48 Email string 49 // Port is a local TCP port to bind to or 0 to allow the OS to pick one. 50 Port int 51 52 srv localsrv.Server 53 } 54 55 // Start launches background goroutine with the serving loop. 56 // 57 // The provided context is used as base context for request handlers and for 58 // logging. 59 // 60 // The server must be eventually stopped with Stop(). 61 func (s *Server) Start(ctx context.Context) (*net.TCPAddr, error) { 62 return s.srv.Start(ctx, "devshell", s.Port, s.serve) 63 } 64 65 // Stop closes the listening socket, notifies pending requests to abort and 66 // stops the internal serving goroutine. 67 // 68 // Safe to call multiple times. Once stopped, the server cannot be started again 69 // (make a new instance of Server instead). 70 // 71 // Uses the given context for the deadline when waiting for the serving loop 72 // to stop. 73 func (s *Server) Stop(ctx context.Context) error { 74 return s.srv.Stop(ctx) 75 } 76 77 // serve runs the serving loop. 78 func (s *Server) serve(ctx context.Context, l net.Listener, wg *sync.WaitGroup) error { 79 for { 80 conn, err := l.Accept() 81 if err != nil { 82 return err 83 } 84 85 client := &client{ 86 conn: conn, 87 source: s.Source, 88 email: s.Email, 89 ctx: ctx, 90 } 91 92 wg.Add(1) 93 go func() { 94 defer wg.Done() 95 96 paniccatcher.Do(func() { 97 if err := client.handle(); err != nil { 98 logging.Fields{ 99 logging.ErrorKey: err, 100 }.Errorf(client.ctx, "failed to handle client request") 101 } 102 }, func(p *paniccatcher.Panic) { 103 logging.Fields{ 104 "panicReason": p.Reason, 105 }.Errorf(client.ctx, "panic during client handshake:\n%s", p.Stack) 106 }) 107 }() 108 } 109 } 110 111 type client struct { 112 conn net.Conn 113 114 source oauth2.TokenSource 115 email string 116 117 ctx context.Context 118 } 119 120 func (c *client) handle() error { 121 defer c.conn.Close() 122 123 if _, err := c.readRequest(); err != nil { 124 if err := c.sendResponse([]any{err.Error()}); err != nil { 125 return fmt.Errorf("failed to send error: %v", err) 126 } 127 return nil 128 } 129 130 // Get the token. 131 t, err := c.source.Token() 132 if err != nil { 133 if err := c.sendResponse([]any{"cannot get access token"}); err != nil { 134 return fmt.Errorf("failed to send error: %v", err) 135 } 136 return err 137 } 138 139 // Expiration is in seconds from now so compute the correct format. 140 expiry := int(t.Expiry.Sub(clock.Now(c.ctx)).Seconds()) 141 142 return c.sendResponse([]any{c.email, nil, t.AccessToken, expiry}) 143 } 144 145 func (c *client) readRequest() ([]any, error) { 146 header := make([]byte, 6) 147 if _, err := c.conn.Read(header); err != nil { 148 return nil, fmt.Errorf("failed to read the header: %v", err) 149 } 150 151 // The first six bytes contain the length separated by a newline. 152 str := strings.SplitN(string(header), "\n", 2) 153 if len(str) != 2 { 154 return nil, fmt.Errorf("no newline in the first 6 bytes") 155 } 156 157 l, err := strconv.Atoi(str[0]) 158 if err != nil { 159 return nil, fmt.Errorf("length is not a number: %v", err) 160 } 161 162 data := make([]byte, l) 163 copy(data, str[1][:]) 164 165 // Read the rest of the message. 166 if l > len(str[1]) { 167 if _, err := c.conn.Read(data[len(str[1]):]); err != nil { 168 return nil, fmt.Errorf("failed to receive request: %v", err) 169 } 170 } 171 172 // Parse the message to ensure it's a correct JSON. 173 request := []any{} 174 if err := json.Unmarshal(data, &request); err != nil { 175 return nil, fmt.Errorf("failed to deserialize from JSON: %v", err) 176 } 177 178 return request, nil 179 } 180 181 func (c *client) sendResponse(response []any) error { 182 // Encode the response as JSON array (aka JsPbLite format). 183 payload, err := json.Marshal(response) 184 if err != nil { 185 return fmt.Errorf("failed to serialize to JSON: %v", err) 186 } 187 188 var buf bytes.Buffer 189 buf.WriteString(fmt.Sprintf("%d\n", len(payload))) 190 buf.Write(payload) 191 if _, err := c.conn.Write(buf.Bytes()); err != nil { 192 return fmt.Errorf("failed to send response: %v", err) 193 } 194 return nil 195 }