github.com/dtroyer-salad/og2/v2@v2.0.0-20240412154159-c47231610877/internal/ioutil/io_test.go (about)

     1  /*
     2  Copyright The ORAS Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package ioutil
    17  
    18  import (
    19  	"bytes"
    20  	"errors"
    21  	"io"
    22  	"os"
    23  	"reflect"
    24  	"strconv"
    25  	"testing"
    26  
    27  	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
    28  	"oras.land/oras-go/v2/content"
    29  	"oras.land/oras-go/v2/internal/spec"
    30  )
    31  
    32  func TestUnwrapNopCloser(t *testing.T) {
    33  	var reader struct {
    34  		io.Reader
    35  	}
    36  	var readerWithWriterTo struct {
    37  		io.Reader
    38  		io.WriterTo
    39  	}
    40  
    41  	tests := []struct {
    42  		name string
    43  		rc   io.Reader
    44  		want io.Reader
    45  	}{
    46  		{
    47  			name: "nil",
    48  		},
    49  		{
    50  			name: "no-op closer with plain io.Reader",
    51  			rc:   io.NopCloser(reader),
    52  			want: reader,
    53  		},
    54  		{
    55  			name: "no-op closer with io.WriteTo",
    56  			rc:   io.NopCloser(readerWithWriterTo),
    57  			want: readerWithWriterTo,
    58  		},
    59  		{
    60  			name: "any ReadCloser",
    61  			rc:   os.Stdin,
    62  			want: os.Stdin,
    63  		},
    64  	}
    65  	for _, tt := range tests {
    66  		t.Run(tt.name, func(t *testing.T) {
    67  			if got := UnwrapNopCloser(tt.rc); !reflect.DeepEqual(got, tt.want) {
    68  				t.Errorf("UnwrapNopCloser() = %v, want %v", got, tt.want)
    69  			}
    70  		})
    71  	}
    72  }
    73  
    74  func TestCopyBuffer(t *testing.T) {
    75  	blob := []byte("foo")
    76  	type args struct {
    77  		src  io.Reader
    78  		buf  []byte
    79  		desc ocispec.Descriptor
    80  	}
    81  	tests := []struct {
    82  		name         string
    83  		args         args
    84  		wantDst      string
    85  		wantErr      error
    86  		blob         []byte
    87  		bufSize      int64
    88  		resumeOffset int64
    89  	}{
    90  		{
    91  			name:    "exact buffer size, no errors",
    92  			wantDst: "foo",
    93  			wantErr: nil,
    94  			blob:    blob,
    95  			bufSize: 3,
    96  		},
    97  		{
    98  			name:         "exact buffer size, no errors, resume",
    99  			wantDst:      "foo",
   100  			wantErr:      nil,
   101  			blob:         blob,
   102  			bufSize:      3,
   103  			resumeOffset: 1,
   104  		},
   105  		{
   106  			name:    "small buffer size, no errors",
   107  			wantDst: "foo",
   108  			wantErr: nil,
   109  			blob:    blob,
   110  			bufSize: 1,
   111  		},
   112  		{
   113  			name:         "small buffer size, no errors, resume",
   114  			wantDst:      "foo",
   115  			wantErr:      nil,
   116  			blob:         blob,
   117  			bufSize:      1,
   118  			resumeOffset: 1,
   119  		},
   120  		{
   121  			name:    "big buffer size, no errors",
   122  			wantDst: "foo",
   123  			wantErr: nil,
   124  			blob:    blob,
   125  			bufSize: 5,
   126  		},
   127  		{
   128  			name:         "big buffer size, no errors, resume",
   129  			wantDst:      "foo",
   130  			wantErr:      nil,
   131  			blob:         blob,
   132  			bufSize:      5,
   133  			resumeOffset: 1,
   134  		},
   135  		{
   136  			name:    "wrong digest",
   137  			wantDst: "foo",
   138  			wantErr: content.ErrMismatchedDigest,
   139  			blob:    []byte("bar"),
   140  			bufSize: 3,
   141  		},
   142  		{
   143  			name:         "wrong digest, resume",
   144  			wantDst:      "foo",
   145  			wantErr:      content.ErrMismatchedDigest,
   146  			blob:         []byte("bar"),
   147  			bufSize:      3,
   148  			resumeOffset: 1,
   149  		},
   150  		{
   151  			name:    "wrong size, descriptor size is smaller",
   152  			wantDst: "foo",
   153  			wantErr: content.ErrTrailingData,
   154  			blob:    []byte("fo"),
   155  			bufSize: 3,
   156  		},
   157  		{
   158  			name:         "wrong size, descriptor size is smaller, resume",
   159  			wantDst:      "foo",
   160  			wantErr:      content.ErrTrailingData,
   161  			blob:         []byte("fo"),
   162  			bufSize:      3,
   163  			resumeOffset: 1,
   164  		},
   165  		{
   166  			name:    "wrong size, descriptor size is larger",
   167  			wantDst: "foo",
   168  			wantErr: io.ErrUnexpectedEOF,
   169  			blob:    []byte("fooo"),
   170  			bufSize: 3,
   171  		},
   172  		// This case is, ugh, the whole point and it isn't verifying!!!
   173  		// {
   174  		// 	name:         "wrong size, descriptor size is larger, resume",
   175  		// 	wantDst:      "foo",
   176  		// 	wantErr:      io.EOF,
   177  		// 	blob:         []byte("fooo"),
   178  		// 	bufSize:      3,
   179  		// 	resumeOffset: 2,
   180  		// },
   181  	}
   182  	for _, tt := range tests {
   183  		t.Run(tt.name, func(t *testing.T) {
   184  			dst := &bytes.Buffer{}
   185  			args := args{
   186  				bytes.NewReader(blob),
   187  				make([]byte, tt.bufSize),
   188  				content.NewDescriptorFromBytes("test", tt.blob),
   189  			}
   190  			if tt.resumeOffset > 0 {
   191  				// Make the starting Hash and run over the starting content
   192  				h := args.desc.Digest.Algorithm().Hash()
   193  				h.Write(tt.blob[0 : tt.resumeOffset-1])
   194  				eh, _ := content.EncodeHash(h)
   195  
   196  				// Add the Annotations for resume
   197  				args.desc.Annotations = map[string]string{
   198  					spec.AnnotationResumeDownload: "true",
   199  					spec.AnnotationResumeOffset:   strconv.FormatInt(tt.resumeOffset, 10),
   200  					spec.AnnotationResumeHash:     eh,
   201  				}
   202  			}
   203  			err := CopyBuffer(dst, args.src, args.buf, args.desc)
   204  			if !errors.Is(err, tt.wantErr) {
   205  				t.Errorf("CopyBuffer() error = %v, wantErr %v", err, tt.wantErr)
   206  				return
   207  			}
   208  			gotDst := dst.String()
   209  			if err == nil && gotDst != tt.wantDst {
   210  				t.Errorf("CopyBuffer() = %v, want %v", gotDst, tt.wantDst)
   211  			}
   212  		})
   213  	}
   214  }