github.com/TrueCloudLab/frostfs-api-go/v2@v2.0.0-20230228134343-196241c4e79a/signature/sign.go (about)

     1  package signature
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/TrueCloudLab/frostfs-api-go/v2/accounting"
     9  	"github.com/TrueCloudLab/frostfs-api-go/v2/container"
    10  	"github.com/TrueCloudLab/frostfs-api-go/v2/netmap"
    11  	"github.com/TrueCloudLab/frostfs-api-go/v2/object"
    12  	"github.com/TrueCloudLab/frostfs-api-go/v2/refs"
    13  	"github.com/TrueCloudLab/frostfs-api-go/v2/reputation"
    14  	"github.com/TrueCloudLab/frostfs-api-go/v2/session"
    15  	"github.com/TrueCloudLab/frostfs-api-go/v2/util/signature"
    16  )
    17  
    18  type serviceRequest interface {
    19  	GetMetaHeader() *session.RequestMetaHeader
    20  	GetVerificationHeader() *session.RequestVerificationHeader
    21  	SetVerificationHeader(*session.RequestVerificationHeader)
    22  }
    23  
    24  type serviceResponse interface {
    25  	GetMetaHeader() *session.ResponseMetaHeader
    26  	GetVerificationHeader() *session.ResponseVerificationHeader
    27  	SetVerificationHeader(*session.ResponseVerificationHeader)
    28  }
    29  
    30  type stableMarshaler interface {
    31  	StableMarshal([]byte) []byte
    32  	StableSize() int
    33  }
    34  
    35  type StableMarshalerWrapper struct {
    36  	SM stableMarshaler
    37  }
    38  
    39  type metaHeader interface {
    40  	stableMarshaler
    41  	getOrigin() metaHeader
    42  }
    43  
    44  type verificationHeader interface {
    45  	stableMarshaler
    46  
    47  	GetBodySignature() *refs.Signature
    48  	SetBodySignature(*refs.Signature)
    49  	GetMetaSignature() *refs.Signature
    50  	SetMetaSignature(*refs.Signature)
    51  	GetOriginSignature() *refs.Signature
    52  	SetOriginSignature(*refs.Signature)
    53  
    54  	setOrigin(stableMarshaler)
    55  	getOrigin() verificationHeader
    56  }
    57  
    58  type requestMetaHeader struct {
    59  	*session.RequestMetaHeader
    60  }
    61  
    62  type responseMetaHeader struct {
    63  	*session.ResponseMetaHeader
    64  }
    65  
    66  type requestVerificationHeader struct {
    67  	*session.RequestVerificationHeader
    68  }
    69  
    70  type responseVerificationHeader struct {
    71  	*session.ResponseVerificationHeader
    72  }
    73  
    74  func (h *requestMetaHeader) getOrigin() metaHeader {
    75  	return &requestMetaHeader{
    76  		RequestMetaHeader: h.GetOrigin(),
    77  	}
    78  }
    79  
    80  func (h *responseMetaHeader) getOrigin() metaHeader {
    81  	return &responseMetaHeader{
    82  		ResponseMetaHeader: h.GetOrigin(),
    83  	}
    84  }
    85  
    86  func (h *requestVerificationHeader) getOrigin() verificationHeader {
    87  	if origin := h.GetOrigin(); origin != nil {
    88  		return &requestVerificationHeader{
    89  			RequestVerificationHeader: origin,
    90  		}
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  func (h *requestVerificationHeader) setOrigin(m stableMarshaler) {
    97  	if m != nil {
    98  		h.SetOrigin(m.(*session.RequestVerificationHeader))
    99  	}
   100  }
   101  
   102  func (r *responseVerificationHeader) getOrigin() verificationHeader {
   103  	if origin := r.GetOrigin(); origin != nil {
   104  		return &responseVerificationHeader{
   105  			ResponseVerificationHeader: origin,
   106  		}
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  func (r *responseVerificationHeader) setOrigin(m stableMarshaler) {
   113  	if m != nil {
   114  		r.SetOrigin(m.(*session.ResponseVerificationHeader))
   115  	}
   116  }
   117  
   118  func (s StableMarshalerWrapper) ReadSignedData(buf []byte) ([]byte, error) {
   119  	if s.SM != nil {
   120  		return s.SM.StableMarshal(buf), nil
   121  	}
   122  
   123  	return nil, nil
   124  }
   125  
   126  func (s StableMarshalerWrapper) SignedDataSize() int {
   127  	if s.SM != nil {
   128  		return s.SM.StableSize()
   129  	}
   130  
   131  	return 0
   132  }
   133  
   134  func SignServiceMessage(key *ecdsa.PrivateKey, msg interface{}) error {
   135  	var (
   136  		body, meta, verifyOrigin stableMarshaler
   137  		verifyHdr                verificationHeader
   138  		verifyHdrSetter          func(verificationHeader)
   139  	)
   140  
   141  	switch v := msg.(type) {
   142  	case nil:
   143  		return nil
   144  	case serviceRequest:
   145  		body = serviceMessageBody(v)
   146  		meta = v.GetMetaHeader()
   147  		verifyHdr = &requestVerificationHeader{new(session.RequestVerificationHeader)}
   148  		verifyHdrSetter = func(h verificationHeader) {
   149  			v.SetVerificationHeader(h.(*requestVerificationHeader).RequestVerificationHeader)
   150  		}
   151  
   152  		if h := v.GetVerificationHeader(); h != nil {
   153  			verifyOrigin = h
   154  		}
   155  	case serviceResponse:
   156  		body = serviceMessageBody(v)
   157  		meta = v.GetMetaHeader()
   158  		verifyHdr = &responseVerificationHeader{new(session.ResponseVerificationHeader)}
   159  		verifyHdrSetter = func(h verificationHeader) {
   160  			v.SetVerificationHeader(h.(*responseVerificationHeader).ResponseVerificationHeader)
   161  		}
   162  
   163  		if h := v.GetVerificationHeader(); h != nil {
   164  			verifyOrigin = h
   165  		}
   166  	default:
   167  		panic(fmt.Sprintf("unsupported session message %T", v))
   168  	}
   169  
   170  	if verifyOrigin == nil {
   171  		// sign session message body
   172  		if err := signServiceMessagePart(key, body, verifyHdr.SetBodySignature); err != nil {
   173  			return fmt.Errorf("could not sign body: %w", err)
   174  		}
   175  	}
   176  
   177  	// sign meta header
   178  	if err := signServiceMessagePart(key, meta, verifyHdr.SetMetaSignature); err != nil {
   179  		return fmt.Errorf("could not sign meta header: %w", err)
   180  	}
   181  
   182  	// sign verification header origin
   183  	if err := signServiceMessagePart(key, verifyOrigin, verifyHdr.SetOriginSignature); err != nil {
   184  		return fmt.Errorf("could not sign origin of verification header: %w", err)
   185  	}
   186  
   187  	// wrap origin verification header
   188  	verifyHdr.setOrigin(verifyOrigin)
   189  
   190  	// update matryoshka verification header
   191  	verifyHdrSetter(verifyHdr)
   192  
   193  	return nil
   194  }
   195  
   196  func signServiceMessagePart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*refs.Signature)) error {
   197  	var sig *refs.Signature
   198  
   199  	// sign part
   200  	if err := signature.SignDataWithHandler(
   201  		key,
   202  		&StableMarshalerWrapper{part},
   203  		func(s *refs.Signature) {
   204  			sig = s
   205  		},
   206  	); err != nil {
   207  		return err
   208  	}
   209  
   210  	// write part signature
   211  	sigWrite(sig)
   212  
   213  	return nil
   214  }
   215  
   216  func VerifyServiceMessage(msg interface{}) error {
   217  	var (
   218  		meta   metaHeader
   219  		verify verificationHeader
   220  	)
   221  
   222  	switch v := msg.(type) {
   223  	case nil:
   224  		return nil
   225  	case serviceRequest:
   226  		meta = &requestMetaHeader{
   227  			RequestMetaHeader: v.GetMetaHeader(),
   228  		}
   229  
   230  		verify = &requestVerificationHeader{
   231  			RequestVerificationHeader: v.GetVerificationHeader(),
   232  		}
   233  	case serviceResponse:
   234  		meta = &responseMetaHeader{
   235  			ResponseMetaHeader: v.GetMetaHeader(),
   236  		}
   237  
   238  		verify = &responseVerificationHeader{
   239  			ResponseVerificationHeader: v.GetVerificationHeader(),
   240  		}
   241  	default:
   242  		panic(fmt.Sprintf("unsupported session message %T", v))
   243  	}
   244  
   245  	body := serviceMessageBody(msg)
   246  	size := body.StableSize()
   247  	if sz := meta.StableSize(); sz > size {
   248  		size = sz
   249  	}
   250  	if sz := verify.StableSize(); sz > size {
   251  		size = sz
   252  	}
   253  
   254  	buf := make([]byte, 0, size)
   255  	return verifyMatryoshkaLevel(body, meta, verify, buf)
   256  }
   257  
   258  func verifyMatryoshkaLevel(body stableMarshaler, meta metaHeader, verify verificationHeader, buf []byte) error {
   259  	if err := verifyServiceMessagePart(meta, verify.GetMetaSignature, buf); err != nil {
   260  		return fmt.Errorf("could not verify meta header: %w", err)
   261  	}
   262  
   263  	origin := verify.getOrigin()
   264  
   265  	if err := verifyServiceMessagePart(origin, verify.GetOriginSignature, buf); err != nil {
   266  		return fmt.Errorf("could not verify origin of verification header: %w", err)
   267  	}
   268  
   269  	if origin == nil {
   270  		if err := verifyServiceMessagePart(body, verify.GetBodySignature, buf); err != nil {
   271  			return fmt.Errorf("could not verify body: %w", err)
   272  		}
   273  
   274  		return nil
   275  	}
   276  
   277  	if verify.GetBodySignature() != nil {
   278  		return errors.New("body signature at the matryoshka upper level")
   279  	}
   280  
   281  	return verifyMatryoshkaLevel(body, meta.getOrigin(), origin, buf)
   282  }
   283  
   284  func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *refs.Signature, buf []byte) error {
   285  	return signature.VerifyDataWithSource(
   286  		&StableMarshalerWrapper{part},
   287  		sigRdr,
   288  		signature.WithBuffer(buf),
   289  	)
   290  }
   291  
   292  func serviceMessageBody(req interface{}) stableMarshaler {
   293  	switch v := req.(type) {
   294  	default:
   295  		panic(fmt.Sprintf("unsupported session message %T", req))
   296  
   297  		/* Accounting */
   298  	case *accounting.BalanceRequest:
   299  		return v.GetBody()
   300  	case *accounting.BalanceResponse:
   301  		return v.GetBody()
   302  
   303  		/* Session */
   304  	case *session.CreateRequest:
   305  		return v.GetBody()
   306  	case *session.CreateResponse:
   307  		return v.GetBody()
   308  
   309  		/* Container */
   310  	case *container.PutRequest:
   311  		return v.GetBody()
   312  	case *container.PutResponse:
   313  		return v.GetBody()
   314  	case *container.DeleteRequest:
   315  		return v.GetBody()
   316  	case *container.DeleteResponse:
   317  		return v.GetBody()
   318  	case *container.GetRequest:
   319  		return v.GetBody()
   320  	case *container.GetResponse:
   321  		return v.GetBody()
   322  	case *container.ListRequest:
   323  		return v.GetBody()
   324  	case *container.ListResponse:
   325  		return v.GetBody()
   326  	case *container.SetExtendedACLRequest:
   327  		return v.GetBody()
   328  	case *container.SetExtendedACLResponse:
   329  		return v.GetBody()
   330  	case *container.GetExtendedACLRequest:
   331  		return v.GetBody()
   332  	case *container.GetExtendedACLResponse:
   333  		return v.GetBody()
   334  	case *container.AnnounceUsedSpaceRequest:
   335  		return v.GetBody()
   336  	case *container.AnnounceUsedSpaceResponse:
   337  		return v.GetBody()
   338  
   339  		/* Object */
   340  	case *object.PutRequest:
   341  		return v.GetBody()
   342  	case *object.PutResponse:
   343  		return v.GetBody()
   344  	case *object.GetRequest:
   345  		return v.GetBody()
   346  	case *object.GetResponse:
   347  		return v.GetBody()
   348  	case *object.HeadRequest:
   349  		return v.GetBody()
   350  	case *object.HeadResponse:
   351  		return v.GetBody()
   352  	case *object.SearchRequest:
   353  		return v.GetBody()
   354  	case *object.SearchResponse:
   355  		return v.GetBody()
   356  	case *object.DeleteRequest:
   357  		return v.GetBody()
   358  	case *object.DeleteResponse:
   359  		return v.GetBody()
   360  	case *object.GetRangeRequest:
   361  		return v.GetBody()
   362  	case *object.GetRangeResponse:
   363  		return v.GetBody()
   364  	case *object.GetRangeHashRequest:
   365  		return v.GetBody()
   366  	case *object.GetRangeHashResponse:
   367  		return v.GetBody()
   368  
   369  		/* Netmap */
   370  	case *netmap.LocalNodeInfoRequest:
   371  		return v.GetBody()
   372  	case *netmap.LocalNodeInfoResponse:
   373  		return v.GetBody()
   374  	case *netmap.NetworkInfoRequest:
   375  		return v.GetBody()
   376  	case *netmap.NetworkInfoResponse:
   377  		return v.GetBody()
   378  	case *netmap.SnapshotRequest:
   379  		return v.GetBody()
   380  	case *netmap.SnapshotResponse:
   381  		return v.GetBody()
   382  
   383  		/* Reputation */
   384  	case *reputation.AnnounceLocalTrustRequest:
   385  		return v.GetBody()
   386  	case *reputation.AnnounceLocalTrustResponse:
   387  		return v.GetBody()
   388  	case *reputation.AnnounceIntermediateResultRequest:
   389  		return v.GetBody()
   390  	case *reputation.AnnounceIntermediateResultResponse:
   391  		return v.GetBody()
   392  	}
   393  }