github.com/matrixorigin/matrixone@v1.2.0/pkg/util/errutil/context_test.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package errutil
    16  
    17  import (
    18  	"context"
    19  	goErrors "errors"
    20  	"fmt"
    21  	"reflect"
    22  	"testing"
    23  
    24  	"github.com/matrixorigin/matrixone/pkg/util/stack"
    25  )
    26  
    27  var ctx = context.Background()
    28  var testErr = goErrors.New("test error")
    29  var stackErr error = &withStack{cause: testErr, Stack: stack.Callers(1)}
    30  
    31  func TestGetContextTracer(t *testing.T) {
    32  	type args struct {
    33  		err error
    34  	}
    35  	tests := []struct {
    36  		name string
    37  		args args
    38  		want context.Context
    39  	}{
    40  		{
    41  			name: "nil",
    42  			args: args{err: goErrors.New("test error")},
    43  			want: nil,
    44  		},
    45  		{
    46  			name: "context",
    47  			args: args{err: WithContext(context.Background(), goErrors.New("test error"))},
    48  			want: context.Background(),
    49  		},
    50  	}
    51  	for _, tt := range tests {
    52  		t.Run(tt.name, func(t *testing.T) {
    53  			if got := GetContextTracer(tt.args.err); !reflect.DeepEqual(got, tt.want) && !reflect.DeepEqual(got.Context(), tt.want) {
    54  				t.Errorf("GetContextTracer() = %v, want %v", got, tt.want)
    55  			}
    56  		})
    57  	}
    58  }
    59  
    60  func TestHasContext(t *testing.T) {
    61  	type args struct {
    62  		err error
    63  	}
    64  	tests := []struct {
    65  		name string
    66  		args args
    67  		want bool
    68  	}{
    69  		{
    70  			name: "nil",
    71  			args: args{err: goErrors.New("test error")},
    72  			want: false,
    73  		},
    74  		{
    75  			name: "context",
    76  			args: args{err: WithContext(context.Background(), goErrors.New("test error"))},
    77  			want: true,
    78  		},
    79  	}
    80  	for _, tt := range tests {
    81  		t.Run(tt.name, func(t *testing.T) {
    82  			if got := HasContext(tt.args.err); got != tt.want {
    83  				t.Errorf("HasContext() = %v, want %v", got, tt.want)
    84  			}
    85  		})
    86  	}
    87  }
    88  
    89  func Test_withContext_Cause(t *testing.T) {
    90  	type fields struct {
    91  		cause error
    92  		ctx   context.Context
    93  	}
    94  	tests := []struct {
    95  		name    string
    96  		fields  fields
    97  		wantErr error
    98  	}{
    99  		{
   100  			name:    "normal",
   101  			fields:  fields{cause: testErr, ctx: ctx},
   102  			wantErr: testErr,
   103  		},
   104  		{
   105  			name:    "stack",
   106  			fields:  fields{stackErr, ctx},
   107  			wantErr: stackErr,
   108  		},
   109  	}
   110  	for _, tt := range tests {
   111  		t.Run(tt.name, func(t *testing.T) {
   112  			w := &withContext{
   113  				cause: tt.fields.cause,
   114  				ctx:   tt.fields.ctx,
   115  			}
   116  			if got := w.Cause(); !dummyEqualError(got, tt.wantErr) {
   117  				t.Errorf("Cause() error = %v, wantErr %v", got, tt.wantErr)
   118  			}
   119  		})
   120  	}
   121  }
   122  
   123  func Test_withContext_Context(t *testing.T) {
   124  	type fields struct {
   125  		cause error
   126  		ctx   context.Context
   127  	}
   128  	tests := []struct {
   129  		name   string
   130  		fields fields
   131  		want   context.Context
   132  	}{
   133  		{
   134  			name:   "normal",
   135  			fields: fields{cause: testErr, ctx: ctx},
   136  			want:   ctx,
   137  		},
   138  		{
   139  			name:   "stack",
   140  			fields: fields{stackErr, ctx},
   141  			want:   ctx,
   142  		},
   143  	}
   144  	for _, tt := range tests {
   145  		t.Run(tt.name, func(t *testing.T) {
   146  			w := &withContext{
   147  				cause: tt.fields.cause,
   148  				ctx:   tt.fields.ctx,
   149  			}
   150  			if got := w.Context(); !reflect.DeepEqual(got, tt.want) {
   151  				t.Errorf("Context() = %v, want %v", got, tt.want)
   152  			}
   153  		})
   154  	}
   155  }
   156  
   157  func Test_withContext_Error(t *testing.T) {
   158  	type fields struct {
   159  		cause error
   160  		ctx   context.Context
   161  	}
   162  	tests := []struct {
   163  		name   string
   164  		fields fields
   165  		want   string
   166  	}{
   167  		{
   168  			name:   "normal",
   169  			fields: fields{cause: testErr, ctx: ctx},
   170  			want:   "test error",
   171  		},
   172  		{
   173  			name:   "stack",
   174  			fields: fields{stackErr, ctx},
   175  			want:   "test error",
   176  		},
   177  	}
   178  	for _, tt := range tests {
   179  		t.Run(tt.name, func(t *testing.T) {
   180  			w := &withContext{
   181  				cause: tt.fields.cause,
   182  				ctx:   tt.fields.ctx,
   183  			}
   184  			if got := w.Error(); got != tt.want {
   185  				t.Errorf("Error() = %v, want %v", got, tt.want)
   186  			}
   187  		})
   188  	}
   189  }
   190  
   191  func Test_withContext_Unwrap(t *testing.T) {
   192  	type fields struct {
   193  		cause error
   194  		ctx   context.Context
   195  	}
   196  	tests := []struct {
   197  		name    string
   198  		fields  fields
   199  		wantErr error
   200  	}{
   201  		{
   202  			name:    "normal",
   203  			fields:  fields{cause: testErr, ctx: ctx},
   204  			wantErr: testErr,
   205  		},
   206  		{
   207  			name:    "stack",
   208  			fields:  fields{stackErr, ctx},
   209  			wantErr: stackErr,
   210  		},
   211  	}
   212  	for _, tt := range tests {
   213  		t.Run(tt.name, func(t *testing.T) {
   214  			w := &withContext{
   215  				cause: tt.fields.cause,
   216  				ctx:   tt.fields.ctx,
   217  			}
   218  			if got := w.Unwrap(); !dummyEqualError(got, tt.wantErr) {
   219  				t.Errorf("Unwrap() = %v, want %v", got, tt.wantErr)
   220  			}
   221  		})
   222  	}
   223  }
   224  
   225  func TestWithContext(t *testing.T) {
   226  	type args struct {
   227  		ctx context.Context
   228  		err error
   229  	}
   230  	tests := []struct {
   231  		name    string
   232  		args    args
   233  		wantErr bool
   234  	}{
   235  		{
   236  			name:    "normal",
   237  			args:    args{context.Background(), testErr},
   238  			wantErr: true,
   239  		},
   240  		{
   241  			name:    "nil",
   242  			args:    args{},
   243  			wantErr: false,
   244  		},
   245  	}
   246  	for _, tt := range tests {
   247  		t.Run(tt.name, func(t *testing.T) {
   248  			if err := WithContext(tt.args.ctx, tt.args.err); (err != nil) != tt.wantErr {
   249  				t.Errorf("TestWithContext() error = %v, wantErr %v", err, tt.wantErr)
   250  			}
   251  		})
   252  	}
   253  }
   254  
   255  func dummyEqualError(err1, err2 error) bool {
   256  	return fmt.Sprintf("%+v", err1) == fmt.Sprintf("%+v", err2)
   257  }