oras.land/oras-go/v2@v2.5.1-0.20240520045656-aef90e4d04c4/internal/httputil/seek_test.go (about)

     1  /*
     2  Copyright The ORAS Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package httputil
    17  
    18  import (
    19  	"bytes"
    20  	"fmt"
    21  	"io"
    22  	"math"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"testing"
    26  )
    27  
    28  func Test_readSeekCloser_Read(t *testing.T) {
    29  	content := []byte("hello world")
    30  	path := "/testpath"
    31  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  		if r.URL.Path != path {
    33  			w.WriteHeader(http.StatusNotFound)
    34  			return
    35  		}
    36  		if _, err := w.Write(content); err != nil {
    37  			t.Errorf("failed to write %q: %v", r.URL, err)
    38  		}
    39  	}))
    40  	defer ts.Close()
    41  
    42  	client := ts.Client()
    43  	resp, err := client.Get(ts.URL + path)
    44  	if err != nil {
    45  		t.Fatalf("failed to do request: %v", err)
    46  	}
    47  	rsc := NewReadSeekCloser(client, resp.Request, resp.Body, int64(len(content)))
    48  	buf := bytes.NewBuffer(nil)
    49  	if _, err := buf.ReadFrom(rsc); err != nil {
    50  		t.Errorf("fail to read: %v", err)
    51  	}
    52  	if got := buf.Bytes(); !bytes.Equal(got, content) {
    53  		t.Errorf("readSeekCloser.Read() = %v, want %v", got, content)
    54  	}
    55  	if err := rsc.Close(); err != nil {
    56  		t.Errorf("fail to close: %v", err)
    57  	}
    58  	if !rsc.(*readSeekCloser).closed {
    59  		t.Errorf("readSeekCloser not closed")
    60  	}
    61  }
    62  
    63  func Test_readSeekCloser_Seek(t *testing.T) {
    64  	content := []byte("hello world")
    65  	path := "/testpath"
    66  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    67  		if r.URL.Path != path {
    68  			w.WriteHeader(http.StatusNotFound)
    69  			return
    70  		}
    71  		rangeHeader := r.Header.Get("Range")
    72  		if rangeHeader == "" {
    73  			w.WriteHeader(http.StatusOK)
    74  			if _, err := w.Write(content); err != nil {
    75  				t.Errorf("failed to write %q: %v", r.URL, err)
    76  			}
    77  			return
    78  		}
    79  		var start, end int
    80  		_, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end)
    81  		if err != nil {
    82  			t.Errorf("invalid range header: %s", rangeHeader)
    83  			w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
    84  			return
    85  		}
    86  		if start < 0 || start > end || start >= len(content) {
    87  			t.Errorf("invalid range: %s", rangeHeader)
    88  			w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
    89  			return
    90  		}
    91  		end++
    92  		if end > len(content) {
    93  			end = len(content)
    94  		}
    95  		w.WriteHeader(http.StatusPartialContent)
    96  		if _, err := w.Write(content[start:end]); err != nil {
    97  			t.Errorf("failed to write %q: %v", r.URL, err)
    98  		}
    99  	}))
   100  	defer ts.Close()
   101  
   102  	client := ts.Client()
   103  	resp, err := client.Get(ts.URL + path)
   104  	if err != nil {
   105  		t.Fatalf("failed to do request: %v", err)
   106  	}
   107  	rsc := NewReadSeekCloser(client, resp.Request, resp.Body, int64(len(content)))
   108  
   109  	tests := []struct {
   110  		name       string
   111  		offset     int64
   112  		whence     int
   113  		wantOffset int64
   114  		n          int64
   115  		want       []byte
   116  		skipSeek   bool
   117  	}{
   118  		{
   119  			name:     "read from initial response",
   120  			n:        3,
   121  			want:     []byte("hel"),
   122  			skipSeek: true,
   123  		},
   124  		{
   125  			name:       "seek to skip",
   126  			offset:     2,
   127  			whence:     io.SeekCurrent,
   128  			wantOffset: 5,
   129  			n:          4,
   130  			want:       []byte(" wor"),
   131  		},
   132  		{
   133  			name:       "seek to the beginning",
   134  			offset:     0,
   135  			whence:     io.SeekStart,
   136  			wantOffset: 0,
   137  			n:          5,
   138  			want:       []byte("hello"),
   139  		},
   140  		{
   141  			name:       "seek to middle",
   142  			offset:     6,
   143  			whence:     io.SeekStart,
   144  			wantOffset: 6,
   145  			n:          math.MaxInt64,
   146  			want:       []byte("world"),
   147  		},
   148  		{
   149  			name:       "seek from end",
   150  			offset:     -4,
   151  			whence:     io.SeekEnd,
   152  			wantOffset: 7,
   153  			n:          3,
   154  			want:       []byte("orl"),
   155  		},
   156  		{
   157  			name:       "seek to the end",
   158  			offset:     0,
   159  			whence:     io.SeekEnd,
   160  			wantOffset: 11,
   161  			n:          5,
   162  			want:       nil,
   163  		},
   164  		{
   165  			name:       "seek beyond the end",
   166  			offset:     42,
   167  			whence:     io.SeekStart,
   168  			wantOffset: 42,
   169  			n:          10,
   170  			want:       nil,
   171  		},
   172  	}
   173  	for _, tt := range tests {
   174  		t.Run(tt.name, func(t *testing.T) {
   175  			if !tt.skipSeek {
   176  				got, err := rsc.Seek(tt.offset, tt.whence)
   177  				if err != nil {
   178  					t.Errorf("readSeekCloser.Seek() error = %v", err)
   179  				}
   180  				if got != tt.wantOffset {
   181  					t.Errorf("readSeekCloser.Read() = %v, want %v", got, tt.wantOffset)
   182  				}
   183  			}
   184  			buf := bytes.NewBuffer(nil)
   185  			if _, err := buf.ReadFrom(io.LimitReader(rsc, tt.n)); err != nil {
   186  				t.Errorf("fail to read: %v", err)
   187  			}
   188  			if got := buf.Bytes(); !bytes.Equal(got, tt.want) {
   189  				t.Errorf("readSeekCloser.Read() = %v, want %v", got, tt.want)
   190  			}
   191  		})
   192  	}
   193  
   194  	_, err = rsc.Seek(-1, io.SeekStart)
   195  	if err == nil {
   196  		t.Errorf("readSeekCloser.Seek() error = %v, wantErr %v", err, true)
   197  	}
   198  
   199  	if err := rsc.Close(); err != nil {
   200  		t.Errorf("fail to close: %v", err)
   201  	}
   202  	if !rsc.(*readSeekCloser).closed {
   203  		t.Errorf("readSeekCloser not closed")
   204  	}
   205  
   206  	_, err = rsc.Seek(0, io.SeekStart)
   207  	if err == nil {
   208  		t.Errorf("readSeekCloser.Seek() error = %v, wantErr %v", err, true)
   209  	}
   210  }