github.com/fastwego/offiaccount@v1.0.1/apis/oauth/oauth_test.go (about)

     1  // Copyright 2020 FastWeGo
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth
    16  
    17  import (
    18  	"fmt"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"os"
    22  	"reflect"
    23  	"testing"
    24  
    25  	"github.com/fastwego/offiaccount"
    26  )
    27  
    28  var MockSvr *httptest.Server
    29  var MockSvrHandler *http.ServeMux
    30  
    31  func TestMain(m *testing.M) {
    32  	// Mock Server
    33  	MockSvrHandler = http.NewServeMux()
    34  	MockSvr = httptest.NewServer(MockSvrHandler)
    35  	offiaccount.WXServerUrl = MockSvr.URL // 拦截发往微信服务器的请求
    36  
    37  	os.Exit(m.Run())
    38  }
    39  
    40  func TestAuth(t *testing.T) {
    41  	// Mock
    42  	MockSvrHandler.HandleFunc(apiAuth, func(w http.ResponseWriter, r *http.Request) {
    43  		_, _ = w.Write([]byte(`{ "errcode":0,"errmsg":"ok"}`))
    44  	})
    45  
    46  	type args struct {
    47  		access_token string
    48  		openid       string
    49  	}
    50  	tests := []struct {
    51  		name        string
    52  		args        args
    53  		wantIsValid bool
    54  		wantErr     bool
    55  	}{
    56  		{name: "case1", args: args{access_token: "", openid: ""}, wantIsValid: true, wantErr: false},
    57  	}
    58  	for _, tt := range tests {
    59  		t.Run(tt.name, func(t *testing.T) {
    60  			gotIsValid, err := Auth(tt.args.access_token, tt.args.openid)
    61  			if (err != nil) != tt.wantErr {
    62  				t.Errorf("Auth() error = %v, wantErr %v", err, tt.wantErr)
    63  				return
    64  			}
    65  			if gotIsValid != tt.wantIsValid {
    66  				t.Errorf("Auth() gotIsValid = %v, want %v", gotIsValid, tt.wantIsValid)
    67  			}
    68  		})
    69  	}
    70  }
    71  
    72  func TestGetAccessToken(t *testing.T) {
    73  	// Mock
    74  	MockSvrHandler.HandleFunc(apiAccessToken, func(w http.ResponseWriter, r *http.Request) {
    75  		_, _ = w.Write([]byte(`{
    76  		  "access_token":"ACCESS_TOKEN",
    77  		  "expires_in":7200,
    78  		  "refresh_token":"REFRESH_TOKEN",
    79  		  "openid":"OPENID",
    80  		  "scope":"SCOPE" 
    81  		}`))
    82  	})
    83  
    84  	type args struct {
    85  		appid  string
    86  		secret string
    87  		code   string
    88  	}
    89  	tests := []struct {
    90  		name                 string
    91  		args                 args
    92  		wantOauthAccessToken OauthAccessToken
    93  		wantErr              bool
    94  	}{
    95  		{name: "case1", args: args{appid: "", secret: "", code: ""}, wantOauthAccessToken: OauthAccessToken{
    96  			AccessToken:  "ACCESS_TOKEN",
    97  			ExpiresIn:    7200,
    98  			RefreshToken: "REFRESH_TOKEN",
    99  			Openid:       "OPENID",
   100  			Scope:        "SCOPE",
   101  		}, wantErr: false},
   102  	}
   103  	for _, tt := range tests {
   104  		t.Run(tt.name, func(t *testing.T) {
   105  			gotOauthAccessToken, err := GetAccessToken(tt.args.appid, tt.args.secret, tt.args.code)
   106  			if (err != nil) != tt.wantErr {
   107  				t.Errorf("GetAccessToken() error = %v, wantErr %v", err, tt.wantErr)
   108  				return
   109  			}
   110  			if !reflect.DeepEqual(gotOauthAccessToken, tt.wantOauthAccessToken) {
   111  				t.Errorf("GetAccessToken() gotOauthAccessToken = %v, want %v", gotOauthAccessToken, tt.wantOauthAccessToken)
   112  			}
   113  		})
   114  	}
   115  }
   116  
   117  func TestGetAuthorizeUrl(t *testing.T) {
   118  	type args struct {
   119  		appid       string
   120  		redirectUri string
   121  		scope       string
   122  		state       string
   123  	}
   124  	tests := []struct {
   125  		name             string
   126  		args             args
   127  		wantAuthorizeUrl string
   128  	}{
   129  		{name: "case1", args: args{appid: "appid", redirectUri: "https://fastwego.dev/api/weixin/oauth", scope: ScopeSnsapiUserinfo, state: "STATE"}, wantAuthorizeUrl: "https://open.weixin.qq.com/connect/oauth2/authorize?appid=appid&redirect_uri=https%3A%2F%2Ffastwego.dev%2Fapi%2Fweixin%2Foauth&response_type=code&scope=snsapi_userinfo&state=STATE"},
   130  		{name: "case2", args: args{appid: "appid", redirectUri: "https://fastwego.dev/api/weixin/oauth", scope: ScopeSnsapiBase, state: "STATE"}, wantAuthorizeUrl: "https://open.weixin.qq.com/connect/oauth2/authorize?appid=appid&redirect_uri=https%3A%2F%2Ffastwego.dev%2Fapi%2Fweixin%2Foauth&response_type=code&scope=snsapi_base&state=STATE"},
   131  	}
   132  	for _, tt := range tests {
   133  		t.Run(tt.name, func(t *testing.T) {
   134  			gotAuthorizeUrl := GetAuthorizeUrl(tt.args.appid, tt.args.redirectUri, tt.args.scope, tt.args.state)
   135  			fmt.Println(gotAuthorizeUrl)
   136  			if gotAuthorizeUrl != tt.wantAuthorizeUrl {
   137  				t.Errorf("GetAuthorizeUrl() = %v \n want %v", gotAuthorizeUrl, tt.wantAuthorizeUrl)
   138  			}
   139  		})
   140  	}
   141  }
   142  
   143  func TestGetUserInfo(t *testing.T) {
   144  	// Mock
   145  	MockSvrHandler.HandleFunc(apiUserInfo, func(w http.ResponseWriter, r *http.Request) {
   146  		_, _ = w.Write([]byte(`{   
   147            "openid":"OPENID",
   148  		  "nickname": "NICKNAME",
   149  		  "sex":1,
   150  		  "province":"PROVINCE",
   151  		  "city":"CITY",
   152  		  "country":"COUNTRY",
   153  		  "headimgurl":"http://thirdwx.qlogo.cn/mmopen/g3MonUZtNHkdmzicIlibx6iaFqAc56vxLSUfpb6n5WKSYVY0ChQKkiaJSgQ1dZuTOgvLLrhJbERQQ4eMsv84eavHiaiceqxibJxCfHe/46",
   154  		  "privilege":[ "PRIVILEGE1","PRIVILEGE2"     ],
   155  		  "unionid": "o6_bmasdasdsad6_2sgVt7hMZOPfL"
   156  		}`))
   157  	})
   158  
   159  	type args struct {
   160  		access_token string
   161  		openid       string
   162  		lang         string
   163  	}
   164  	tests := []struct {
   165  		name              string
   166  		args              args
   167  		wantOauthUserInfo OauthUserInfo
   168  		wantErr           bool
   169  	}{
   170  		{name: "case1", args: args{access_token: "", openid: "", lang: LANG_zh_CN}, wantOauthUserInfo: OauthUserInfo{
   171  			Openid:     "OPENID",
   172  			Nickname:   "NICKNAME",
   173  			Sex:        1,
   174  			Province:   "PROVINCE",
   175  			City:       "CITY",
   176  			Country:    "COUNTRY",
   177  			Headimgurl: "http://thirdwx.qlogo.cn/mmopen/g3MonUZtNHkdmzicIlibx6iaFqAc56vxLSUfpb6n5WKSYVY0ChQKkiaJSgQ1dZuTOgvLLrhJbERQQ4eMsv84eavHiaiceqxibJxCfHe/46",
   178  			Privilege:  []string{"PRIVILEGE1", "PRIVILEGE2"},
   179  			Unionid:    "o6_bmasdasdsad6_2sgVt7hMZOPfL",
   180  		}},
   181  	}
   182  	for _, tt := range tests {
   183  		t.Run(tt.name, func(t *testing.T) {
   184  			gotOauthUserInfo, err := GetUserInfo(tt.args.access_token, tt.args.openid, tt.args.lang)
   185  			if (err != nil) != tt.wantErr {
   186  				t.Errorf("GetUserInfo() error = %v, wantErr %v", err, tt.wantErr)
   187  				return
   188  			}
   189  			if !reflect.DeepEqual(gotOauthUserInfo, tt.wantOauthUserInfo) {
   190  				t.Errorf("GetUserInfo() gotOauthUserInfo = \n%v,\n want \n%v", gotOauthUserInfo, tt.wantOauthUserInfo)
   191  			}
   192  		})
   193  	}
   194  }
   195  
   196  func TestRefreshToken(t *testing.T) {
   197  	// Mock
   198  	MockSvrHandler.HandleFunc(apiRefreshToken, func(w http.ResponseWriter, r *http.Request) {
   199  		_, _ = w.Write([]byte(`{
   200  		  "access_token":"ACCESS_TOKEN",
   201  		  "expires_in":7200,
   202  		  "refresh_token":"REFRESH_TOKEN",
   203  		  "openid":"OPENID",
   204  		  "scope":"SCOPE" 
   205  		}`))
   206  	})
   207  	type args struct {
   208  		appid         string
   209  		refresh_token string
   210  	}
   211  	tests := []struct {
   212  		name                 string
   213  		args                 args
   214  		wantOauthAccessToken OauthAccessToken
   215  		wantErr              bool
   216  	}{
   217  		{name: "case1", args: args{appid: "", refresh_token: ""}, wantOauthAccessToken: OauthAccessToken{
   218  			AccessToken:  "ACCESS_TOKEN",
   219  			ExpiresIn:    7200,
   220  			RefreshToken: "REFRESH_TOKEN",
   221  			Openid:       "OPENID",
   222  			Scope:        "SCOPE",
   223  		}, wantErr: false},
   224  	}
   225  	for _, tt := range tests {
   226  		t.Run(tt.name, func(t *testing.T) {
   227  			gotOauthAccessToken, err := RefreshToken(tt.args.appid, tt.args.refresh_token)
   228  			if (err != nil) != tt.wantErr {
   229  				t.Errorf("RefreshToken() error = %v, wantErr %v", err, tt.wantErr)
   230  				return
   231  			}
   232  			if !reflect.DeepEqual(gotOauthAccessToken, tt.wantOauthAccessToken) {
   233  				t.Errorf("RefreshToken() gotOauthAccessToken = %v, want %v", gotOauthAccessToken, tt.wantOauthAccessToken)
   234  			}
   235  		})
   236  	}
   237  }