github.com/rish1988/moby@v25.0.2+incompatible/client/image_pull_test.go (about)

     1  package client // import "github.com/docker/docker/client"
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/docker/docker/api/types/image"
    13  	"github.com/docker/docker/api/types/registry"
    14  	"github.com/docker/docker/errdefs"
    15  	"gotest.tools/v3/assert"
    16  	is "gotest.tools/v3/assert/cmp"
    17  )
    18  
    19  func TestImagePullReferenceParseError(t *testing.T) {
    20  	client := &Client{
    21  		client: newMockClient(func(req *http.Request) (*http.Response, error) {
    22  			return nil, nil
    23  		}),
    24  	}
    25  	// An empty reference is an invalid reference
    26  	_, err := client.ImagePull(context.Background(), "", image.PullOptions{})
    27  	if err == nil || !strings.Contains(err.Error(), "invalid reference format") {
    28  		t.Fatalf("expected an error, got %v", err)
    29  	}
    30  }
    31  
    32  func TestImagePullAnyError(t *testing.T) {
    33  	client := &Client{
    34  		client: newMockClient(errorMock(http.StatusInternalServerError, "Server error")),
    35  	}
    36  	_, err := client.ImagePull(context.Background(), "myimage", image.PullOptions{})
    37  	assert.Check(t, is.ErrorType(err, errdefs.IsSystem))
    38  }
    39  
    40  func TestImagePullStatusUnauthorizedError(t *testing.T) {
    41  	client := &Client{
    42  		client: newMockClient(errorMock(http.StatusUnauthorized, "Unauthorized error")),
    43  	}
    44  	_, err := client.ImagePull(context.Background(), "myimage", image.PullOptions{})
    45  	assert.Check(t, is.ErrorType(err, errdefs.IsUnauthorized))
    46  }
    47  
    48  func TestImagePullWithUnauthorizedErrorAndPrivilegeFuncError(t *testing.T) {
    49  	client := &Client{
    50  		client: newMockClient(errorMock(http.StatusUnauthorized, "Unauthorized error")),
    51  	}
    52  	privilegeFunc := func() (string, error) {
    53  		return "", fmt.Errorf("Error requesting privilege")
    54  	}
    55  	_, err := client.ImagePull(context.Background(), "myimage", image.PullOptions{
    56  		PrivilegeFunc: privilegeFunc,
    57  	})
    58  	if err == nil || err.Error() != "Error requesting privilege" {
    59  		t.Fatalf("expected an error requesting privilege, got %v", err)
    60  	}
    61  }
    62  
    63  func TestImagePullWithUnauthorizedErrorAndAnotherUnauthorizedError(t *testing.T) {
    64  	client := &Client{
    65  		client: newMockClient(errorMock(http.StatusUnauthorized, "Unauthorized error")),
    66  	}
    67  	privilegeFunc := func() (string, error) {
    68  		return "a-auth-header", nil
    69  	}
    70  	_, err := client.ImagePull(context.Background(), "myimage", image.PullOptions{
    71  		PrivilegeFunc: privilegeFunc,
    72  	})
    73  	assert.Check(t, is.ErrorType(err, errdefs.IsUnauthorized))
    74  }
    75  
    76  func TestImagePullWithPrivilegedFuncNoError(t *testing.T) {
    77  	expectedURL := "/images/create"
    78  	client := &Client{
    79  		client: newMockClient(func(req *http.Request) (*http.Response, error) {
    80  			if !strings.HasPrefix(req.URL.Path, expectedURL) {
    81  				return nil, fmt.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL)
    82  			}
    83  			auth := req.Header.Get(registry.AuthHeader)
    84  			if auth == "NotValid" {
    85  				return &http.Response{
    86  					StatusCode: http.StatusUnauthorized,
    87  					Body:       io.NopCloser(bytes.NewReader([]byte("Invalid credentials"))),
    88  				}, nil
    89  			}
    90  			if auth != "IAmValid" {
    91  				return nil, fmt.Errorf("invalid auth header: expected %s, got %s", "IAmValid", auth)
    92  			}
    93  			query := req.URL.Query()
    94  			fromImage := query.Get("fromImage")
    95  			if fromImage != "myimage" {
    96  				return nil, fmt.Errorf("fromimage not set in URL query properly. Expected '%s', got %s", "myimage", fromImage)
    97  			}
    98  			tag := query.Get("tag")
    99  			if tag != "latest" {
   100  				return nil, fmt.Errorf("tag not set in URL query properly. Expected '%s', got %s", "latest", tag)
   101  			}
   102  			return &http.Response{
   103  				StatusCode: http.StatusOK,
   104  				Body:       io.NopCloser(bytes.NewReader([]byte("hello world"))),
   105  			}, nil
   106  		}),
   107  	}
   108  	privilegeFunc := func() (string, error) {
   109  		return "IAmValid", nil
   110  	}
   111  	resp, err := client.ImagePull(context.Background(), "myimage", image.PullOptions{
   112  		RegistryAuth:  "NotValid",
   113  		PrivilegeFunc: privilegeFunc,
   114  	})
   115  	if err != nil {
   116  		t.Fatal(err)
   117  	}
   118  	body, err := io.ReadAll(resp)
   119  	if err != nil {
   120  		t.Fatal(err)
   121  	}
   122  	if string(body) != "hello world" {
   123  		t.Fatalf("expected 'hello world', got %s", string(body))
   124  	}
   125  }
   126  
   127  func TestImagePullWithoutErrors(t *testing.T) {
   128  	expectedURL := "/images/create"
   129  	expectedOutput := "hello world"
   130  	pullCases := []struct {
   131  		all           bool
   132  		reference     string
   133  		expectedImage string
   134  		expectedTag   string
   135  	}{
   136  		{
   137  			all:           false,
   138  			reference:     "myimage",
   139  			expectedImage: "myimage",
   140  			expectedTag:   "latest",
   141  		},
   142  		{
   143  			all:           false,
   144  			reference:     "myimage:tag",
   145  			expectedImage: "myimage",
   146  			expectedTag:   "tag",
   147  		},
   148  		{
   149  			all:           true,
   150  			reference:     "myimage",
   151  			expectedImage: "myimage",
   152  			expectedTag:   "",
   153  		},
   154  		{
   155  			all:           true,
   156  			reference:     "myimage:anything",
   157  			expectedImage: "myimage",
   158  			expectedTag:   "",
   159  		},
   160  	}
   161  	for _, pullCase := range pullCases {
   162  		client := &Client{
   163  			client: newMockClient(func(req *http.Request) (*http.Response, error) {
   164  				if !strings.HasPrefix(req.URL.Path, expectedURL) {
   165  					return nil, fmt.Errorf("Expected URL '%s', got '%s'", expectedURL, req.URL)
   166  				}
   167  				query := req.URL.Query()
   168  				fromImage := query.Get("fromImage")
   169  				if fromImage != pullCase.expectedImage {
   170  					return nil, fmt.Errorf("fromimage not set in URL query properly. Expected '%s', got %s", pullCase.expectedImage, fromImage)
   171  				}
   172  				tag := query.Get("tag")
   173  				if tag != pullCase.expectedTag {
   174  					return nil, fmt.Errorf("tag not set in URL query properly. Expected '%s', got %s", pullCase.expectedTag, tag)
   175  				}
   176  				return &http.Response{
   177  					StatusCode: http.StatusOK,
   178  					Body:       io.NopCloser(bytes.NewReader([]byte(expectedOutput))),
   179  				}, nil
   180  			}),
   181  		}
   182  		resp, err := client.ImagePull(context.Background(), pullCase.reference, image.PullOptions{
   183  			All: pullCase.all,
   184  		})
   185  		if err != nil {
   186  			t.Fatal(err)
   187  		}
   188  		body, err := io.ReadAll(resp)
   189  		if err != nil {
   190  			t.Fatal(err)
   191  		}
   192  		if string(body) != expectedOutput {
   193  			t.Fatalf("expected '%s', got %s", expectedOutput, string(body))
   194  		}
   195  	}
   196  }