code.vegaprotocol.io/vega@v0.79.0/datanode/contextutil/contextutil.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package contextutil
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  
    22  	uuid "github.com/satori/go.uuid"
    23  )
    24  
    25  type (
    26  	remoteIPAddrKey int
    27  	traceIDT        int
    28  	blockHeight     int
    29  )
    30  
    31  var (
    32  	clientRemoteIPAddrKey remoteIPAddrKey
    33  	traceIDKey            traceIDT
    34  	blockHeightKey        blockHeight
    35  
    36  	ErrBlockHeightMissing = errors.New("no or invalid block height set on context")
    37  )
    38  
    39  // WithRemoteIPAddr wrap the context into a new context
    40  // and embed the ip addr as a key.
    41  func WithRemoteIPAddr(ctx context.Context, addr string) context.Context {
    42  	return context.WithValue(ctx, clientRemoteIPAddrKey, addr)
    43  }
    44  
    45  // RemoteIPAddrFromContext returns the remote IP addr value stored in ctx, if any.
    46  func RemoteIPAddrFromContext(ctx context.Context) (string, bool) {
    47  	u, ok := ctx.Value(clientRemoteIPAddrKey).(string)
    48  	return u, ok
    49  }
    50  
    51  // TraceIDFromContext get traceID from context (add one if none is set).
    52  func TraceIDFromContext(ctx context.Context) (context.Context, string) {
    53  	tID := ctx.Value(traceIDKey)
    54  	if tID == nil {
    55  		if h, _ := BlockHeightFromContext(ctx); h == 0 {
    56  			ctx = context.WithValue(ctx, traceIDKey, "genesis")
    57  			return ctx, "genesis"
    58  		}
    59  		stID := uuid.NewV4().String()
    60  		ctx = context.WithValue(ctx, traceIDKey, stID)
    61  		return ctx, stID
    62  	}
    63  	stID, ok := tID.(string)
    64  	if !ok {
    65  		stID = uuid.NewV4().String()
    66  		ctx = context.WithValue(ctx, traceIDKey, stID)
    67  	}
    68  	return ctx, stID
    69  }
    70  
    71  func BlockHeightFromContext(ctx context.Context) (int64, error) {
    72  	hv := ctx.Value(blockHeightKey)
    73  	if hv == nil {
    74  		return 0, ErrBlockHeightMissing
    75  	}
    76  	h, ok := hv.(int64)
    77  	if !ok {
    78  		return 0, ErrBlockHeightMissing
    79  	}
    80  	return h, nil
    81  }
    82  
    83  // WithTraceID returns a context with a traceID value.
    84  func WithTraceID(ctx context.Context, tID string) context.Context {
    85  	return context.WithValue(ctx, traceIDKey, tID)
    86  }
    87  
    88  func WithBlockHeight(ctx context.Context, h int64) context.Context {
    89  	return context.WithValue(ctx, blockHeightKey, h)
    90  }