github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/accessor/handler.go (about)

     1  package accessor
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  
     7  	"code.cloudfoundry.org/lager"
     8  	"github.com/pf-qiu/concourse/v6/atc/auditor"
     9  )
    10  
    11  //go:generate counterfeiter net/http.Handler
    12  
    13  //go:generate counterfeiter . AccessFactory
    14  
    15  type AccessFactory interface {
    16  	Create(req *http.Request, role string) (Access, error)
    17  }
    18  
    19  func NewHandler(
    20  	logger lager.Logger,
    21  	action string,
    22  	handler http.Handler,
    23  	accessFactory AccessFactory,
    24  	auditor auditor.Auditor,
    25  	customRoles map[string]string,
    26  ) http.Handler {
    27  	return &accessorHandler{
    28  		logger:        logger,
    29  		handler:       handler,
    30  		accessFactory: accessFactory,
    31  		action:        action,
    32  		auditor:       auditor,
    33  		customRoles:   customRoles,
    34  	}
    35  }
    36  
    37  type accessorHandler struct {
    38  	logger        lager.Logger
    39  	action        string
    40  	handler       http.Handler
    41  	accessFactory AccessFactory
    42  	auditor       auditor.Auditor
    43  	customRoles   map[string]string
    44  }
    45  
    46  func (h *accessorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    47  	requiredRole := h.customRoles[h.action]
    48  	if requiredRole == "" {
    49  		requiredRole = DefaultRoles[h.action]
    50  	}
    51  
    52  	acc, err := h.accessFactory.Create(r, requiredRole)
    53  	if err != nil {
    54  		h.logger.Error("failed-to-construct-accessor", err)
    55  		w.WriteHeader(http.StatusInternalServerError)
    56  		return
    57  	}
    58  
    59  	claims := acc.Claims()
    60  
    61  	ctx := context.WithValue(r.Context(), "accessor", acc)
    62  
    63  	h.auditor.Audit(h.action, claims.UserName, r)
    64  	h.handler.ServeHTTP(w, r.WithContext(ctx))
    65  }
    66  
    67  func GetAccessor(r *http.Request) Access {
    68  	accessor := r.Context().Value("accessor")
    69  	if accessor != nil {
    70  		return accessor.(Access)
    71  	}
    72  
    73  	return &access{}
    74  }