gitlab.com/apertussolutions/u-root@v7.0.0+incompatible/cmds/core/scp/scp_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"os"
    13  	"path"
    14  	"testing"
    15  )
    16  
    17  func TestScpSource(t *testing.T) {
    18  	var w bytes.Buffer
    19  	var r bytes.Buffer
    20  
    21  	tf, err := ioutil.TempFile("", "TestScpSource")
    22  	if err != nil {
    23  		t.Fatalf("creating temp file: %v", err)
    24  	}
    25  	defer os.Remove(tf.Name())
    26  	tf.Write([]byte("test-file-contents"))
    27  
    28  	r.Write([]byte{0})
    29  	err = scpSource(&w, &r, tf.Name())
    30  	if err != nil {
    31  		t.Fatalf("error: %v", err)
    32  	}
    33  	expected := []byte(fmt.Sprintf("C0600 18 %s\ntest-file-contents", path.Base(tf.Name())))
    34  	expected = append(expected, 0)
    35  	if string(expected) != w.String() {
    36  		t.Fatalf("Got: %v\nExpected: %v", w.String(), string(expected))
    37  	}
    38  }
    39  
    40  func TestScpSink(t *testing.T) {
    41  	var w bytes.Buffer
    42  	var r bytes.Buffer
    43  
    44  	tf, err := ioutil.TempFile("", "TestScpSink")
    45  	if err != nil {
    46  		t.Fatalf("creating temp file: %v", err)
    47  	}
    48  	defer os.Remove(tf.Name())
    49  
    50  	r.Write([]byte(fmt.Sprintf("C0600 18 test\ntest-file-contents")))
    51  	// Post IO-copy success status
    52  	r.Write([]byte{0})
    53  
    54  	err = scpSink(&w, &r, tf.Name())
    55  	if err != nil {
    56  		t.Fatalf("error: %v", err)
    57  	}
    58  
    59  	// 1: Initial SUCCESS post to start scp
    60  	// 2: Success opening file tf.Name()
    61  	// 3: Success writing file
    62  	expected := []byte{0, 0, 0}
    63  	if string(expected) != w.String() {
    64  		t.Fatalf("Got: %v\nExpected: %v", w.Bytes(), expected)
    65  	}
    66  
    67  	m := make([]byte, 18)
    68  	n, err := tf.Read(m)
    69  	if err != nil {
    70  		t.Fatalf("IO error: %v", err)
    71  	}
    72  	if n != 18 {
    73  		t.Fatalf("Expected 18 bytes, got %v", n)
    74  	}
    75  
    76  	// Ensure EOF
    77  	_, err = tf.Read(m)
    78  	if err != io.EOF {
    79  		t.Fatalf("Expected EOF, got %v", err)
    80  	}
    81  
    82  	if string(m) != "test-file-contents" {
    83  		t.Fatalf("Expected 'test-file-contents', got '%v'", string(m))
    84  	}
    85  }