dubbo.apache.org/dubbo-go/v3@v3.1.1/protocol/rest/rest_protocol.go (about)

     1  /*
     2   * Licensed to the Apache Software Foundation (ASF) under one or more
     3   * contributor license agreements.  See the NOTICE file distributed with
     4   * this work for additional information regarding copyright ownership.
     5   * The ASF licenses this file to You under the Apache License, Version 2.0
     6   * (the "License"); you may not use this file except in compliance with
     7   * the License.  You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   */
    17  
    18  package rest
    19  
    20  import (
    21  	"sync"
    22  	"time"
    23  )
    24  
    25  import (
    26  	"github.com/dubbogo/gost/log/logger"
    27  )
    28  
    29  import (
    30  	"dubbo.apache.org/dubbo-go/v3/common"
    31  	"dubbo.apache.org/dubbo-go/v3/common/constant"
    32  	"dubbo.apache.org/dubbo-go/v3/common/extension"
    33  	"dubbo.apache.org/dubbo-go/v3/protocol"
    34  	"dubbo.apache.org/dubbo-go/v3/protocol/rest/client"
    35  	_ "dubbo.apache.org/dubbo-go/v3/protocol/rest/client/client_impl"
    36  	rest_config "dubbo.apache.org/dubbo-go/v3/protocol/rest/config"
    37  	_ "dubbo.apache.org/dubbo-go/v3/protocol/rest/config/reader"
    38  	"dubbo.apache.org/dubbo-go/v3/protocol/rest/server"
    39  	_ "dubbo.apache.org/dubbo-go/v3/protocol/rest/server/server_impl"
    40  )
    41  
// restProtocol holds the instance returned by GetRestProtocol; it stays nil
// until GetRestProtocol is first called.
var restProtocol *RestProtocol

// REST is the name under which this protocol is registered with the
// extension registry in init.
const REST = "rest"

// init registers GetRestProtocol as the factory for the "rest" protocol.
// nolint
func init() {
	extension.SetProtocol(REST, GetRestProtocol)
}
    50  
// RestProtocol is the "rest" protocol implementation. Export deploys service
// methods on RestServer instances and Refer builds RestInvokers backed by
// RestClient instances; both server and client instances are cached here.
// nolint
type RestProtocol struct {
	protocol.BaseProtocol
	// serverLock is taken by getServer while creating and storing a server.
	serverLock sync.Mutex
	// serverMap caches one started RestServer per url.Location.
	serverMap  map[string]server.RestServer
	// clientLock is taken by getClient while creating and storing a client.
	clientLock sync.Mutex
	// clientMap caches one RestClient per RestOptions value.
	// NOTE(review): the key does not include the client type, so two service
	// configs with identical options but different client types share the
	// first client created — confirm this is intended.
	clientMap  map[client.RestOptions]client.RestClient
}
    59  
    60  // NewRestProtocol returns a RestProtocol
    61  func NewRestProtocol() *RestProtocol {
    62  	return &RestProtocol{
    63  		BaseProtocol: protocol.NewBaseProtocol(),
    64  		serverMap:    make(map[string]server.RestServer, 8),
    65  		clientMap:    make(map[client.RestOptions]client.RestClient, 8),
    66  	}
    67  }
    68  
    69  // Export export rest service
    70  func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter {
    71  	url := invoker.GetURL()
    72  	serviceKey := url.ServiceKey()
    73  	exporter := NewRestExporter(serviceKey, invoker, rp.ExporterMap())
    74  	id := url.GetParam(constant.BeanNameKey, "")
    75  	restServiceConfig := rest_config.GetRestProviderServiceConfig(id)
    76  	if restServiceConfig == nil {
    77  		logger.Errorf("%s service doesn't has provider config", url.Path)
    78  		return nil
    79  	}
    80  	rp.SetExporterMap(serviceKey, exporter)
    81  	restServer := rp.getServer(url, restServiceConfig.Server)
    82  	for _, methodConfig := range restServiceConfig.RestMethodConfigsMap {
    83  		restServer.Deploy(methodConfig, server.GetRouteFunc(invoker, methodConfig))
    84  	}
    85  	return exporter
    86  }
    87  
    88  // Refer create rest service reference
    89  func (rp *RestProtocol) Refer(url *common.URL) protocol.Invoker {
    90  	// create rest_invoker
    91  	// todo fix timeout config
    92  	// start
    93  	requestTimeout := time.Duration(3 * time.Second)
    94  	requestTimeoutStr := url.GetParam(constant.TimeoutKey, "3s")
    95  	connectTimeout := requestTimeout // config.GetConsumerConfig().ConnectTimeout
    96  	// end
    97  	if t, err := time.ParseDuration(requestTimeoutStr); err == nil {
    98  		requestTimeout = t
    99  	}
   100  	id := url.GetParam(constant.BeanNameKey, "")
   101  	restServiceConfig := rest_config.GetRestConsumerServiceConfig(id)
   102  	if restServiceConfig == nil {
   103  		logger.Errorf("%s service doesn't has consumer config", url.Path)
   104  		return nil
   105  	}
   106  	restOptions := client.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout}
   107  	restClient := rp.getClient(restOptions, restServiceConfig.Client)
   108  	invoker := NewRestInvoker(url, &restClient, restServiceConfig.RestMethodConfigsMap)
   109  	rp.SetInvokers(invoker)
   110  	return invoker
   111  }
   112  
   113  // nolint
   114  func (rp *RestProtocol) getServer(url *common.URL, serverType string) server.RestServer {
   115  	restServer, ok := rp.serverMap[url.Location]
   116  	if ok {
   117  		return restServer
   118  	}
   119  	_, ok = rp.ExporterMap().Load(url.ServiceKey())
   120  	if !ok {
   121  		panic("[RestProtocol]" + url.ServiceKey() + "is not existing")
   122  	}
   123  	rp.serverLock.Lock()
   124  	defer rp.serverLock.Unlock()
   125  	restServer, ok = rp.serverMap[url.Location]
   126  	if ok {
   127  		return restServer
   128  	}
   129  	restServer = extension.GetNewRestServer(serverType)
   130  	restServer.Start(url)
   131  	rp.serverMap[url.Location] = restServer
   132  	return restServer
   133  }
   134  
   135  // nolint
   136  func (rp *RestProtocol) getClient(restOptions client.RestOptions, clientType string) client.RestClient {
   137  	restClient, ok := rp.clientMap[restOptions]
   138  	if ok {
   139  		return restClient
   140  	}
   141  	rp.clientLock.Lock()
   142  	defer rp.clientLock.Unlock()
   143  	restClient, ok = rp.clientMap[restOptions]
   144  	if ok {
   145  		return restClient
   146  	}
   147  	restClient = extension.GetNewRestClient(clientType, &restOptions)
   148  	rp.clientMap[restOptions] = restClient
   149  	return restClient
   150  }
   151  
   152  // Destroy destroy rest service
   153  func (rp *RestProtocol) Destroy() {
   154  	// destroy rest_server
   155  	rp.BaseProtocol.Destroy()
   156  	for key, tmpServer := range rp.serverMap {
   157  		tmpServer.Destroy()
   158  		delete(rp.serverMap, key)
   159  	}
   160  	for key := range rp.clientMap {
   161  		delete(rp.clientMap, key)
   162  	}
   163  }
   164  
   165  // GetRestProtocol get a rest protocol
   166  func GetRestProtocol() protocol.Protocol {
   167  	if restProtocol == nil {
   168  		restProtocol = NewRestProtocol()
   169  	}
   170  	return restProtocol
   171  }