cuelabs.dev/go/oci/ociregistry@v0.0.0-20240906074133-82eb438dd565/ociserver/error_test.go (about) 1 // Copyright 2023 CUE Labs AG 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ociserver 16 17 import ( 18 "context" 19 "io" 20 "net/http" 21 "net/http/httptest" 22 "testing" 23 24 "cuelabs.dev/go/oci/ociregistry" 25 26 "github.com/go-quicktest/qt" 27 ) 28 29 func TestCustomErrorWriter(t *testing.T) { 30 // Test that if an Interface method returns an HTTPError error, the 31 // HTTP status code is derived from the OCI error code in preference 32 // to the HTTPError status code. 33 r := New(&ociregistry.Funcs{}, &Options{ 34 WriteError: func(w http.ResponseWriter, _ *http.Request, err error) { 35 w.Header().Set("Some-Header", "a value") 36 ociregistry.WriteError(w, err) 37 }, 38 }) 39 s := httptest.NewServer(r) 40 defer s.Close() 41 resp, err := http.Get(s.URL + "/v2/foo/manifests/sometag") 42 qt.Assert(t, qt.IsNil(err)) 43 defer resp.Body.Close() 44 qt.Assert(t, qt.Equals(resp.Header.Get("Some-Header"), "a value")) 45 } 46 47 func TestHTTPStatusOverriddenByErrorCode(t *testing.T) { 48 // Test that if an Interface method returns an HTTPError error, the 49 // HTTP status code is derived from the OCI error code in preference 50 // to the HTTPError status code. 51 r := New(&ociregistry.Funcs{ 52 GetTag_: func(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) { 53 return nil, ociregistry.NewHTTPError(ociregistry.ErrNameUnknown, http.StatusUnauthorized, nil, nil) 54 }, 55 }, nil) 56 s := httptest.NewServer(r) 57 defer s.Close() 58 resp, err := http.Get(s.URL + "/v2/foo/manifests/sometag") 59 qt.Assert(t, qt.IsNil(err)) 60 defer resp.Body.Close() 61 body, _ := io.ReadAll(resp.Body) 62 qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusNotFound)) 63 qt.Assert(t, qt.JSONEquals(body, &ociregistry.WireErrors{ 64 Errors: []ociregistry.WireError{{ 65 Code_: ociregistry.ErrNameUnknown.Code(), 66 Message: "401 Unauthorized: name unknown: repository name not known to registry", 67 }}, 68 })) 69 } 70 71 func TestHTTPStatusUsedForUnknownErrorCode(t *testing.T) { 72 // Test that if an Interface method returns an HTTPError error, that 73 // HTTP status code is used when the code isn't known to be 74 // associated with a particular HTTP status. 75 r := New(&ociregistry.Funcs{ 76 GetTag_: func(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) { 77 return nil, ociregistry.NewHTTPError(ociregistry.NewError("foo", "SOMECODE", nil), http.StatusTeapot, nil, nil) 78 }, 79 }, nil) 80 s := httptest.NewServer(r) 81 defer s.Close() 82 resp, err := http.Get(s.URL + "/v2/foo/manifests/sometag") 83 qt.Assert(t, qt.IsNil(err)) 84 defer resp.Body.Close() 85 body, _ := io.ReadAll(resp.Body) 86 qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusTeapot)) 87 qt.Assert(t, qt.JSONEquals(body, &ociregistry.WireErrors{ 88 Errors: []ociregistry.WireError{{ 89 Code_: "SOMECODE", 90 Message: "foo", 91 }}, 92 })) 93 }