github.com/cnotch/ipchub@v1.1.0/service/apis.go (about)

     1  // Copyright (c) 2019,CAOHONGJU All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package service
     6  
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"path"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/cnotch/apirouter"
	"github.com/cnotch/ipchub/config"
	"github.com/cnotch/ipchub/media"
	"github.com/cnotch/ipchub/provider/auth"
	"github.com/cnotch/ipchub/provider/route"
	"github.com/cnotch/ipchub/stats"
)
    28  
    29  const (
    30  	usernameHeaderKey = "user_name_in_token"
    31  )
    32  
    33  var (
    34  	buffers = sync.Pool{
    35  		New: func() interface{} {
    36  			return bytes.NewBuffer(make([]byte, 0, 1024*2))
    37  		},
    38  	}
    39  	noAuthRequired = map[string]bool{
    40  		"/api/v1/login":        true,
    41  		"/api/v1/server":       true,
    42  		"/api/v1/runtime":      true,
    43  		"/api/v1/refreshtoken": true,
    44  	}
    45  )
    46  
    47  var crossdomainxml = []byte(
    48  	`<?xml version="1.0" ?><cross-domain-policy>
    49  			<allow-access-from domain="*" />
    50  			<allow-http-request-headers-from domain="*" headers="*"/>
    51  		</cross-domain-policy>`)
    52  
    53  func (s *Service) initApis(mux *http.ServeMux) {
    54  	api := apirouter.NewForGRPC(
    55  		// 系统信息类API
    56  		apirouter.POST("/api/v1/login", s.onLogin),
    57  		apirouter.GET("/api/v1/server", s.onGetServerInfo),
    58  		apirouter.GET("/api/v1/runtime", s.onGetRuntime),
    59  		apirouter.GET("/api/v1/refreshtoken", s.onRefreshToken),
    60  
    61  		// 流管理API
    62  		apirouter.GET("/api/v1/streams", s.onListStreams),
    63  		apirouter.GET("/api/v1/streams/{path=**}", s.onGetStreamInfo),
    64  		apirouter.DELETE("/api/v1/streams/{path=**}", s.onStopStream),
    65  		apirouter.DELETE("/api/v1/streams/{path=**}:consumer", s.onStopConsumer),
    66  
    67  		// 路由管理API
    68  		apirouter.GET("/api/v1/routes", s.onListRoutes),
    69  		apirouter.GET("/api/v1/routes/{pattern=**}", s.onGetRoute),
    70  		apirouter.DELETE("/api/v1/routes/{pattern=**}", s.onDelRoute),
    71  		apirouter.POST("/api/v1/routes", s.onSaveRoute),
    72  
    73  		// 用户管理API
    74  		apirouter.GET("/api/v1/users", s.onListUsers),
    75  		apirouter.GET("/api/v1/users/{userName=*}", s.onGetUser),
    76  		apirouter.DELETE("/api/v1/users/{userName=*}", s.onDelUser),
    77  		apirouter.POST("/api/v1/users", s.onSaveUser),
    78  	)
    79  
    80  	iterc := apirouter.ChainInterceptor(apirouter.PreInterceptor(s.authInterceptor),
    81  		apirouter.PreInterceptor(roleInterceptor))
    82  
    83  	// api add to mux
    84  	mux.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) {
    85  		if path.Base(r.URL.Path) == "crossdomain.xml" {
    86  			w.Header().Set("Content-Type", "application/xml")
    87  			w.Write(crossdomainxml)
    88  			return
    89  		}
    90  
    91  		path := strings.ToLower(r.URL.Path)
    92  		if _, ok := noAuthRequired[path]; ok || iterc.PreHandle(w, r) {
    93  			w.Header().Set("Access-Control-Allow-Origin", "*")
    94  			api.ServeHTTP(w, r)
    95  		}
    96  	})
    97  }
    98  
    99  // 刷新Token
   100  func (s *Service) onRefreshToken(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   101  	token := r.URL.Query().Get("token")
   102  	if token != "" {
   103  		newtoken := s.tokens.Refresh(token)
   104  		if newtoken != nil {
   105  			if err := jsonTo(w, newtoken); err != nil {
   106  				http.Error(w, err.Error(), http.StatusInternalServerError)
   107  			}
   108  			return
   109  		}
   110  	}
   111  
   112  	http.Error(w, "Token is not valid", http.StatusUnauthorized)
   113  	return
   114  }
   115  
   116  // 登录
   117  func (s *Service) onLogin(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   118  	type UserCredentials struct {
   119  		Username string `json:"username"`
   120  		Password string `json:"password"`
   121  	}
   122  
   123  	// 提取凭证
   124  	var uc UserCredentials
   125  
   126  	if err := json.NewDecoder(r.Body).Decode(&uc); err != nil {
   127  		http.Error(w, "未提供用户名或密码", http.StatusForbidden)
   128  		return
   129  	}
   130  
   131  	// 尝试 Form解析
   132  	if len(uc.Username) == 0 || len(uc.Password) == 0 {
   133  		http.Error(w, "用户名或密码错误", http.StatusForbidden)
   134  		return
   135  	}
   136  
   137  	// 验证用户和密码
   138  	u := auth.Get(uc.Username)
   139  	if u == nil || u.ValidatePassword(uc.Password) != nil {
   140  		http.Error(w, "用户名或密码错误", http.StatusForbidden)
   141  		return
   142  	}
   143  
   144  	// 新建Token,并返回
   145  	token := s.tokens.NewToken(u.Name)
   146  
   147  	if err := jsonTo(w, token); err != nil {
   148  		http.Error(w, err.Error(), http.StatusInternalServerError)
   149  	}
   150  }
   151  
   152  // 获取运行时信息
   153  func (s *Service) onGetServerInfo(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   154  	type server struct {
   155  		Vendor   string `json:"vendor"`
   156  		Name     string `json:"name"`
   157  		Version  string `json:"version"`
   158  		OS       string `json:"os"`
   159  		Arch     string `json:"arch"`
   160  		StartOn  string `json:"start_on"`
   161  		Duration string `json:"duration"`
   162  	}
   163  	srv := server{
   164  		Vendor:   config.Vendor,
   165  		Name:     config.Name,
   166  		Version:  config.Version,
   167  		OS:       strings.Title(runtime.GOOS),
   168  		Arch:     strings.ToUpper(runtime.GOARCH),
   169  		StartOn:  stats.StartingTime.Format(time.RFC3339Nano),
   170  		Duration: time.Now().Sub(stats.StartingTime).String(),
   171  	}
   172  
   173  	if err := jsonTo(w, &srv); err != nil {
   174  		http.Error(w, err.Error(), http.StatusInternalServerError)
   175  	}
   176  }
   177  
   178  // 获取运行时信息
   179  func (s *Service) onGetRuntime(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   180  	const extraKey = "extra"
   181  
   182  	type sccc struct {
   183  		SC int `json:"sources"`
   184  		CC int `json:"consumers"`
   185  	}
   186  	type runtime struct {
   187  		On      string            `json:"on"`
   188  		Proc    stats.Proc        `json:"proc"`
   189  		Streams sccc              `json:"streams"`
   190  		Rtsp    stats.ConnsSample `json:"rtsp"`
   191  		Flv     stats.ConnsSample `json:"flv"`
   192  		Extra   *stats.Runtime    `json:"extra,omitempty"`
   193  	}
   194  	sc, cc := media.Count()
   195  
   196  	rt := runtime{
   197  		On:      time.Now().Format(time.RFC3339Nano),
   198  		Proc:    stats.MeasureRuntime(),
   199  		Streams: sccc{sc, cc},
   200  		Rtsp:    stats.RtspConns.GetSample(),
   201  		Flv:     stats.FlvConns.GetSample(),
   202  	}
   203  
   204  	params := r.URL.Query()
   205  	if strings.TrimSpace(params.Get(extraKey)) == "1" {
   206  		rt.Extra = stats.MeasureFullRuntime()
   207  	}
   208  
   209  	if err := jsonTo(w, &rt); err != nil {
   210  		http.Error(w, err.Error(), http.StatusInternalServerError)
   211  	}
   212  }
   213  
   214  func (s *Service) onListStreams(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   215  	params := r.URL.Query()
   216  	pageSize, pageToken, err := listParamers(params)
   217  	if err != nil {
   218  		http.Error(w, err.Error(), http.StatusInternalServerError)
   219  		return
   220  	}
   221  
   222  	includeCS := strings.TrimSpace(params.Get("c")) == "1"
   223  
   224  	count, sinfos := media.Infos(pageToken, pageSize, includeCS)
   225  	type streamInfos struct {
   226  		Total         int                 `json:"total"`
   227  		NextPageToken string              `json:"next_page_token"`
   228  		Streams       []*media.StreamInfo `json:"streams,omitempty"`
   229  	}
   230  
   231  	list := &streamInfos{
   232  		Total:   count,
   233  		Streams: sinfos,
   234  	}
   235  	if len(sinfos) > 0 {
   236  		list.NextPageToken = sinfos[len(sinfos)-1].Path
   237  	}
   238  
   239  	if err := jsonTo(w, list); err != nil {
   240  		http.Error(w, err.Error(), http.StatusInternalServerError)
   241  	}
   242  }
   243  
   244  func (s *Service) onGetStreamInfo(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   245  	path := pathParams.ByName("path")
   246  
   247  	var rt *media.Stream
   248  
   249  	rt = media.Get(path)
   250  	if rt == nil {
   251  		http.NotFound(w, r)
   252  		return
   253  	}
   254  
   255  	params := r.URL.Query()
   256  	includeCS := strings.TrimSpace(params.Get("c")) == "1"
   257  
   258  	si := rt.Info(includeCS)
   259  
   260  	if err := jsonTo(w, si); err != nil {
   261  		http.Error(w, err.Error(), http.StatusInternalServerError)
   262  	}
   263  }
   264  
   265  func (s *Service) onStopStream(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   266  	path := pathParams.ByName("path")
   267  
   268  	var rt *media.Stream
   269  
   270  	rt = media.Get(path)
   271  	if rt != nil {
   272  		rt.Close()
   273  	}
   274  
   275  	w.WriteHeader(http.StatusOK)
   276  }
   277  
   278  func (s *Service) onStopConsumer(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   279  	path := pathParams.ByName("path")
   280  	param := r.URL.Query().Get("cid")
   281  	no, err := strconv.ParseInt(param, 10, 64)
   282  	if err != nil {
   283  		http.Error(w, err.Error(), http.StatusInternalServerError)
   284  		return
   285  	}
   286  
   287  	var rt *media.Stream
   288  	rt = media.Get(path)
   289  	if rt != nil {
   290  		rt.StopConsume(media.CID(no))
   291  	}
   292  
   293  	w.WriteHeader(http.StatusOK)
   294  }
   295  
   296  func (s *Service) onListRoutes(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   297  	params := r.URL.Query()
   298  	pageSize, pageToken, err := listParamers(params)
   299  	if err != nil {
   300  		http.Error(w, err.Error(), http.StatusInternalServerError)
   301  		return
   302  	}
   303  
   304  	routes := route.All()
   305  	sort.Slice(routes, func(i, j int) bool {
   306  		return routes[i].Pattern < routes[j].Pattern
   307  	})
   308  
   309  	begini := 0
   310  	for _, r1 := range routes {
   311  		if r1.Pattern <= pageToken {
   312  			begini++
   313  			continue
   314  		}
   315  		break
   316  	}
   317  
   318  	type routeInfos struct {
   319  		Total         int            `json:"total"`
   320  		NextPageToken string         `json:"next_page_token"`
   321  		Routes        []*route.Route `json:"routes,omitempty"`
   322  	}
   323  
   324  	list := &routeInfos{
   325  		Total:         len(routes),
   326  		NextPageToken: pageToken,
   327  		Routes:        make([]*route.Route, 0, pageSize),
   328  	}
   329  
   330  	j := 0
   331  	for i := begini; i < len(routes) && j < pageSize; i++ {
   332  		j++
   333  		list.Routes = append(list.Routes, routes[i])
   334  		list.NextPageToken = routes[i].Pattern
   335  	}
   336  
   337  	if err := jsonTo(w, list); err != nil {
   338  		http.Error(w, err.Error(), http.StatusInternalServerError)
   339  	}
   340  }
   341  
   342  func (s *Service) onGetRoute(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   343  	pattern := pathParams.ByName("pattern")
   344  	r1 := route.Get(pattern)
   345  	if r1 == nil {
   346  		http.NotFound(w, r)
   347  		return
   348  	}
   349  
   350  	if err := jsonTo(w, r1); err != nil {
   351  		http.Error(w, err.Error(), http.StatusInternalServerError)
   352  	}
   353  }
   354  
   355  func (s *Service) onDelRoute(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   356  	pattern := pathParams.ByName("pattern")
   357  	err := route.Del(pattern)
   358  	if err != nil {
   359  		http.Error(w, err.Error(), http.StatusInternalServerError)
   360  	} else {
   361  		w.WriteHeader(http.StatusOK)
   362  	}
   363  }
   364  
   365  func (s *Service) onSaveRoute(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   366  	r1 := &route.Route{}
   367  	err := json.NewDecoder(r.Body).Decode(r1)
   368  	if err != nil {
   369  		http.Error(w, err.Error(), http.StatusInternalServerError)
   370  		return
   371  	}
   372  
   373  	err = route.Save(r1)
   374  	if err != nil {
   375  		http.Error(w, err.Error(), http.StatusInternalServerError)
   376  	} else {
   377  		w.WriteHeader(http.StatusOK)
   378  	}
   379  }
   380  
   381  func (s *Service) onListUsers(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   382  	params := r.URL.Query()
   383  	pageSize, pageToken, err := listParamers(params)
   384  	if err != nil {
   385  		http.Error(w, err.Error(), http.StatusInternalServerError)
   386  		return
   387  	}
   388  
   389  	users := auth.All()
   390  	sort.Slice(users, func(i, j int) bool {
   391  		return users[i].Name < users[j].Name
   392  	})
   393  
   394  	begini := 0
   395  	for _, u := range users {
   396  		if u.Name <= pageToken {
   397  			begini++
   398  			continue
   399  		}
   400  		break
   401  	}
   402  
   403  	type userInfos struct {
   404  		Total         int         `json:"total"`
   405  		NextPageToken string      `json:"next_page_token"`
   406  		Users         []auth.User `json:"users,omitempty"`
   407  	}
   408  
   409  	list := &userInfos{
   410  		Total:         len(users),
   411  		NextPageToken: pageToken,
   412  		Users:         make([]auth.User, 0, pageSize),
   413  	}
   414  
   415  	j := 0
   416  	for i := begini; i < len(users) && j < pageSize; i++ {
   417  		j++
   418  		u := *users[i]
   419  		u.Password = ""
   420  		list.Users = append(list.Users, u)
   421  		list.NextPageToken = u.Name
   422  	}
   423  
   424  	if err := jsonTo(w, list); err != nil {
   425  		http.Error(w, err.Error(), http.StatusInternalServerError)
   426  	}
   427  }
   428  
   429  func (s *Service) onGetUser(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   430  	userName := pathParams.ByName("userName")
   431  	u := auth.Get(userName)
   432  	if u == nil {
   433  		http.NotFound(w, r)
   434  		return
   435  	}
   436  
   437  	u2 := *u
   438  	u2.Password = ""
   439  	if err := jsonTo(w, &u2); err != nil {
   440  		http.Error(w, err.Error(), http.StatusInternalServerError)
   441  	}
   442  }
   443  
   444  func (s *Service) onDelUser(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   445  	userName := pathParams.ByName("userName")
   446  	err := auth.Del(userName)
   447  	if err != nil {
   448  		http.Error(w, err.Error(), http.StatusInternalServerError)
   449  	} else {
   450  		w.WriteHeader(http.StatusOK)
   451  	}
   452  }
   453  
   454  func (s *Service) onSaveUser(w http.ResponseWriter, r *http.Request, pathParams apirouter.Params) {
   455  	u := &auth.User{}
   456  	err := json.NewDecoder(r.Body).Decode(u)
   457  	if err != nil {
   458  		http.Error(w, err.Error(), http.StatusInternalServerError)
   459  		return
   460  	}
   461  
   462  	updatePassword := r.URL.Query().Get("update_password") == "1"
   463  	err = auth.Save(u, updatePassword)
   464  	if err != nil {
   465  		http.Error(w, err.Error(), http.StatusInternalServerError)
   466  	} else {
   467  		w.WriteHeader(http.StatusOK)
   468  	}
   469  }
   470  
   471  func jsonTo(w io.Writer, o interface{}) error {
   472  	formatted := buffers.Get().(*bytes.Buffer)
   473  	formatted.Reset()
   474  	defer buffers.Put(formatted)
   475  
   476  	body, err := json.Marshal(o)
   477  	if err != nil {
   478  		return err
   479  	}
   480  
   481  	if err := json.Indent(formatted, body, "", "\t"); err != nil {
   482  		return err
   483  	}
   484  
   485  	if _, err := w.Write(formatted.Bytes()); err != nil {
   486  		return err
   487  	}
   488  	return nil
   489  }
   490  
   491  func listParamers(params url.Values) (pageSize int, pageToken string, err error) {
   492  	pageSizeStr := params.Get("page_size")
   493  	pageSize = 20
   494  	if pageSizeStr != "" {
   495  		var err error
   496  		pageSize, err = strconv.Atoi(pageSizeStr)
   497  		if err != nil {
   498  			return pageSize, pageToken, err
   499  		}
   500  	}
   501  	pageToken = params.Get("page_token")
   502  	return
   503  }
   504  
   505  // ?token=
   506  func (s *Service) authInterceptor(w http.ResponseWriter, r *http.Request) bool {
   507  	token := r.URL.Query().Get("token")
   508  	if token != "" {
   509  		username := s.tokens.AccessCheck(token)
   510  		if username != "" {
   511  			r.Header.Set(usernameHeaderKey, username)
   512  			return true // 继续执行
   513  		}
   514  	}
   515  
   516  	http.Error(w, "Token is not valid", http.StatusUnauthorized)
   517  	return false
   518  }
   519  
   520  func roleInterceptor(w http.ResponseWriter, r *http.Request) bool {
   521  	// 流查询方法,无需管理员身份
   522  	if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/api/v1/streams") {
   523  		return true
   524  	}
   525  
   526  	userName := r.Header.Get(usernameHeaderKey)
   527  	u := auth.Get(userName)
   528  	if u == nil || !u.Admin {
   529  		http.Error(w /*http.StatusText(http.StatusForbidden)*/, "访问被拒绝,请用管理员登录", http.StatusForbidden)
   530  		return false
   531  	}
   532  
   533  	return true
   534  }