github.com/glycerine/xcryptossh@v7.0.4+incompatible/agent/forward.go (about) 1 // Copyright 2014 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package agent 6 7 import ( 8 "context" 9 "errors" 10 "io" 11 "net" 12 "sync" 13 14 ssh "github.com/glycerine/xcryptossh" 15 ) 16 17 // RequestAgentForwarding sets up agent forwarding for the session. 18 // ForwardToAgent or ForwardToRemote should be called to route 19 // the authentication requests. 20 func RequestAgentForwarding(session *ssh.Session) error { 21 ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) 22 if err != nil { 23 return err 24 } 25 if !ok { 26 return errors.New("forwarding request denied") 27 } 28 return nil 29 } 30 31 // ForwardToAgent routes authentication requests to the given keyring. 32 func ForwardToAgent(ctx context.Context, client *ssh.Client, keyring Agent) error { 33 channels := client.HandleChannelOpen(channelType) 34 if channels == nil { 35 return errors.New("agent: already have handler for " + channelType) 36 } 37 38 go func() { 39 for ch := range channels { 40 channel, reqs, err := ch.Accept() 41 if err != nil { 42 continue 43 } 44 go ssh.DiscardRequests(ctx, reqs, client.Halt) 45 go func() { 46 ServeAgent(keyring, channel) 47 channel.Close() 48 }() 49 } 50 }() 51 return nil 52 } 53 54 const channelType = "auth-agent@openssh.com" 55 56 // ForwardToRemote routes authentication requests to the ssh-agent 57 // process serving on the given unix socket. 58 func ForwardToRemote(ctx context.Context, client *ssh.Client, addr string) error { 59 channels := client.HandleChannelOpen(channelType) 60 if channels == nil { 61 return errors.New("agent: already have handler for " + channelType) 62 } 63 conn, err := net.Dial("unix", addr) 64 if err != nil { 65 return err 66 } 67 conn.Close() 68 69 go func() { 70 for ch := range channels { 71 channel, reqs, err := ch.Accept() 72 if err != nil { 73 continue 74 } 75 go ssh.DiscardRequests(ctx, reqs, client.Halt) 76 go forwardUnixSocket(channel, addr) 77 } 78 }() 79 return nil 80 } 81 82 func forwardUnixSocket(channel ssh.Channel, addr string) { 83 conn, err := net.Dial("unix", addr) 84 if err != nil { 85 return 86 } 87 88 var wg sync.WaitGroup 89 wg.Add(2) 90 go func() { 91 io.Copy(conn, channel) 92 conn.(*net.UnixConn).CloseWrite() 93 wg.Done() 94 }() 95 go func() { 96 io.Copy(channel, conn) 97 channel.CloseWrite() 98 wg.Done() 99 }() 100 101 wg.Wait() 102 conn.Close() 103 channel.Close() 104 }