github.com/pkg/sftp@v1.13.6/server_test.go (about) 1 package sftp 2 3 import ( 4 "bytes" 5 "errors" 6 "io" 7 "os" 8 "path" 9 "runtime" 10 "sync" 11 "syscall" 12 "testing" 13 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 ) 17 18 func clientServerPair(t *testing.T) (*Client, *Server) { 19 cr, sw := io.Pipe() 20 sr, cw := io.Pipe() 21 var options []ServerOption 22 if *testAllocator { 23 options = append(options, WithAllocator()) 24 } 25 server, err := NewServer(struct { 26 io.Reader 27 io.WriteCloser 28 }{sr, sw}, options...) 29 if err != nil { 30 t.Fatal(err) 31 } 32 go server.Serve() 33 client, err := NewClientPipe(cr, cw) 34 if err != nil { 35 t.Fatalf("%+v\n", err) 36 } 37 return client, server 38 } 39 40 type sshFxpTestBadExtendedPacket struct { 41 ID uint32 42 Extension string 43 Data string 44 } 45 46 func (p sshFxpTestBadExtendedPacket) id() uint32 { return p.ID } 47 48 func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) { 49 l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 50 4 + len(p.Extension) + 51 4 + len(p.Data) 52 53 b := make([]byte, 4, l) 54 b = append(b, sshFxpExtended) 55 b = marshalUint32(b, p.ID) 56 b = marshalString(b, p.Extension) 57 b = marshalString(b, p.Data) 58 59 return b, nil 60 } 61 62 func checkServerAllocator(t *testing.T, server *Server) { 63 if server.pktMgr.alloc == nil { 64 return 65 } 66 checkAllocatorBeforeServerClose(t, server.pktMgr.alloc) 67 server.Close() 68 checkAllocatorAfterServerClose(t, server.pktMgr.alloc) 69 } 70 71 // test that errors are sent back when we request an invalid extended packet operation 72 // this validates the following rfc draft is followed https://tools.ietf.org/html/draft-ietf-secsh-filexfer-extensions-00 73 func TestInvalidExtendedPacket(t *testing.T) { 74 client, server := clientServerPair(t) 75 defer client.Close() 76 defer server.Close() 77 78 badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"} 79 typ, data, err := client.clientConn.sendPacket(nil, badPacket) 80 if err != nil { 81 t.Fatalf("unexpected error from sendPacket: %s", err) 82 } 83 if typ != sshFxpStatus { 84 t.Fatalf("received non-FPX_STATUS packet: %v", typ) 85 } 86 87 err = unmarshalStatus(badPacket.id(), data) 88 statusErr, ok := err.(*StatusError) 89 if !ok { 90 t.Fatal("failed to convert error from unmarshalStatus to *StatusError") 91 } 92 if statusErr.Code != sshFxOPUnsupported { 93 t.Errorf("statusErr.Code => %d, wanted %d", statusErr.Code, sshFxOPUnsupported) 94 } 95 checkServerAllocator(t, server) 96 } 97 98 // test that server handles concurrent requests correctly 99 func TestConcurrentRequests(t *testing.T) { 100 skipIfWindows(t) 101 filename := "/etc/passwd" 102 if runtime.GOOS == "plan9" { 103 filename = "/lib/ndb/local" 104 } 105 client, server := clientServerPair(t) 106 defer client.Close() 107 defer server.Close() 108 109 concurrency := 2 110 var wg sync.WaitGroup 111 wg.Add(concurrency) 112 113 for i := 0; i < concurrency; i++ { 114 go func() { 115 defer wg.Done() 116 117 for j := 0; j < 1024; j++ { 118 f, err := client.Open(filename) 119 if err != nil { 120 t.Errorf("failed to open file: %v", err) 121 continue 122 } 123 if err := f.Close(); err != nil { 124 t.Errorf("failed t close file: %v", err) 125 } 126 } 127 }() 128 } 129 wg.Wait() 130 checkServerAllocator(t, server) 131 } 132 133 // Test error conversion 134 func TestStatusFromError(t *testing.T) { 135 type test struct { 136 err error 137 pkt *sshFxpStatusPacket 138 } 139 tpkt := func(id, code uint32) *sshFxpStatusPacket { 140 return &sshFxpStatusPacket{ 141 ID: id, 142 StatusError: StatusError{Code: code}, 143 } 144 } 145 testCases := []test{ 146 {syscall.ENOENT, tpkt(1, sshFxNoSuchFile)}, 147 {&os.PathError{Err: syscall.ENOENT}, 148 tpkt(2, sshFxNoSuchFile)}, 149 {&os.PathError{Err: errors.New("foo")}, tpkt(3, sshFxFailure)}, 150 {ErrSSHFxEOF, tpkt(4, sshFxEOF)}, 151 {ErrSSHFxOpUnsupported, tpkt(5, sshFxOPUnsupported)}, 152 {io.EOF, tpkt(6, sshFxEOF)}, 153 {os.ErrNotExist, tpkt(7, sshFxNoSuchFile)}, 154 } 155 for _, tc := range testCases { 156 tc.pkt.StatusError.msg = tc.err.Error() 157 assert.Equal(t, tc.pkt, statusFromError(tc.pkt.ID, tc.err)) 158 } 159 } 160 161 // This was written to test a race b/w open immediately followed by a stat. 162 // Previous to this the Open would trigger the use of a worker pool, then the 163 // stat packet would come in an hit the pool and return faster than the open 164 // (returning a file-not-found error). 165 // The below by itself wouldn't trigger the race however, I needed to add a 166 // small sleep in the openpacket code to trigger the issue. I wanted to add a 167 // way to inject that in the code but right now there is no good place for it. 168 // I'm thinking after I convert the server into a request-server backend I 169 // might be able to do something with the runWorker method passed into the 170 // packet manager. But with the 2 implementations fo the server it just doesn't 171 // fit well right now. 172 func TestOpenStatRace(t *testing.T) { 173 client, server := clientServerPair(t) 174 defer client.Close() 175 defer server.Close() 176 177 // openpacket finishes to fast to trigger race in tests 178 // need to add a small sleep on server to openpackets somehow 179 tmppath := path.Join(os.TempDir(), "stat_race") 180 pflags := flags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) 181 ch := make(chan result, 3) 182 id1 := client.nextID() 183 client.dispatchRequest(ch, &sshFxpOpenPacket{ 184 ID: id1, 185 Path: tmppath, 186 Pflags: pflags, 187 }) 188 id2 := client.nextID() 189 client.dispatchRequest(ch, &sshFxpLstatPacket{ 190 ID: id2, 191 Path: tmppath, 192 }) 193 testreply := func(id uint32) { 194 r := <-ch 195 switch r.typ { 196 case sshFxpAttrs, sshFxpHandle: // ignore 197 case sshFxpStatus: 198 err := normaliseError(unmarshalStatus(id, r.data)) 199 assert.NoError(t, err, "race hit, stat before open") 200 default: 201 t.Fatal("unexpected type:", r.typ) 202 } 203 } 204 testreply(id1) 205 testreply(id2) 206 os.Remove(tmppath) 207 checkServerAllocator(t, server) 208 } 209 210 // Ensure that proper error codes are returned for non existent files, such 211 // that they are mapped back to a 'not exists' error on the client side. 212 func TestStatNonExistent(t *testing.T) { 213 client, server := clientServerPair(t) 214 defer client.Close() 215 defer server.Close() 216 217 for _, file := range []string{"/doesnotexist", "/doesnotexist/a/b"} { 218 _, err := client.Stat(file) 219 if !os.IsNotExist(err) { 220 t.Errorf("expected 'does not exist' err for file %q. got: %v", file, err) 221 } 222 } 223 } 224 225 func TestServerWithBrokenClient(t *testing.T) { 226 validInit := sp(&sshFxInitPacket{Version: 3}) 227 brokenOpen := sp(&sshFxpOpenPacket{Path: "foo"}) 228 brokenOpen = brokenOpen[:len(brokenOpen)-2] 229 230 for _, clientInput := range [][]byte{ 231 // Packet length zero (never valid). This used to crash the server. 232 {0, 0, 0, 0}, 233 append(validInit, 0, 0, 0, 0), 234 235 // Client hangs up mid-packet. 236 append(validInit, brokenOpen...), 237 } { 238 srv, err := NewServer(struct { 239 io.Reader 240 io.WriteCloser 241 }{ 242 bytes.NewReader(clientInput), 243 &sink{}, 244 }) 245 require.NoError(t, err) 246 247 err = srv.Serve() 248 assert.Error(t, err) 249 srv.Close() 250 } 251 }