github.com/cloudwego/dynamicgo@v0.2.6-0.20240519101509-707f41b6b834/conv/j2t/impl.go (about)

     1  /**
     2   * Copyright 2023 CloudWeGo Authors.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package j2t
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"runtime"
    23  	"unsafe"
    24  
    25  	"github.com/cloudwego/dynamicgo/conv"
    26  	"github.com/cloudwego/dynamicgo/http"
    27  	"github.com/cloudwego/dynamicgo/internal/json"
    28  	"github.com/cloudwego/dynamicgo/internal/native"
    29  	"github.com/cloudwego/dynamicgo/internal/native/types"
    30  	"github.com/cloudwego/dynamicgo/internal/rt"
    31  	"github.com/cloudwego/dynamicgo/meta"
    32  	"github.com/cloudwego/dynamicgo/thrift"
    33  	"github.com/cloudwego/dynamicgo/thrift/base"
    34  )
    35  
    36  const (
    37  	_GUARD_SLICE_FACTOR = 1
    38  )
    39  
    40  func (self *BinaryConv) do(ctx context.Context, src []byte, desc *thrift.TypeDescriptor, buf *[]byte, req http.RequestGetter) error {
    41  	//NOTICE: output buffer must be larger than src buffer
    42  	rt.GuardSlice(buf, len(src)*_GUARD_SLICE_FACTOR)
    43  
    44  	if self.opts.EnableThriftBase {
    45  		if f := desc.Struct().GetRequestBase(); f != nil {
    46  			if err := writeRequestBaseToThrift(ctx, buf, f); err != nil {
    47  				return err
    48  			}
    49  		}
    50  	}
    51  
    52  	if len(src) == 0 {
    53  		// empty body
    54  		if self.opts.EnableHttpMapping && req != nil {
    55  			st := desc.Struct()
    56  			var reqs = thrift.NewRequiresBitmap()
    57  			st.Requires().CopyTo(reqs)
    58  			// check if any http-mapping exists
    59  			if desc.Struct().HttpMappingFields() != nil {
    60  				if err := self.writeHttpRequestToThrift(ctx, req, st, *reqs, buf, true, true); err != nil {
    61  					return err
    62  				}
    63  			}
    64  			// check if any required field exists, if have, traceback on http
    65  			// since this case it always for top-level fields,
    66  			// we should only check opts.BackTraceRequireOrTopField to decide whether to traceback
    67  			err := reqs.HandleRequires(st, self.opts.ReadHttpValueFallback, self.opts.ReadHttpValueFallback, self.opts.ReadHttpValueFallback, func(f *thrift.FieldDescriptor) error {
    68  				val, _, enc := tryGetValueFromHttp(req, f.Alias())
    69  				if err := self.writeStringValue(ctx, buf, f, val, enc, req); err != nil {
    70  					return err
    71  				}
    72  				return nil
    73  			})
    74  			if err != nil {
    75  				return newError(meta.ErrWrite, "failed to write required field", err)
    76  			}
    77  			thrift.FreeRequiresBitmap(reqs)
    78  		}
    79  		// since there is no json data, we should add struct end into buf and return
    80  		*buf = append(*buf, byte(thrift.STOP))
    81  		return nil
    82  	}
    83  
    84  	// special case for unquoted json string
    85  	if desc.Type() == thrift.STRING && src[0] != '"' {
    86  		buf := make([]byte, 0, len(src)+2)
    87  		src = json.EncodeString(buf, rt.Mem2Str(src))
    88  	}
    89  
    90  	return self.doNative(ctx, src, desc, buf, req, true)
    91  }
    92  
    93  func (self *BinaryConv) doNative(ctx context.Context, src []byte, desc *thrift.TypeDescriptor, buf *[]byte, req http.RequestGetter, top bool) (err error) {
    94  	jp := rt.Mem2Str(src)
    95  	fsm := types.NewJ2TStateMachine()
    96  	fsm.Init(0, unsafe.Pointer(desc))
    97  
    98  exec:
    99  	ret := native.J2T_FSM(fsm, buf, &jp, self.flags)
   100  	if ret != 0 {
   101  		cont, e := self.handleError(ctx, fsm, buf, src, req, ret, top)
   102  		if cont && e == nil {
   103  			goto exec
   104  		}
   105  		err = e
   106  		goto ret
   107  	}
   108  
   109  ret:
   110  	types.FreeJ2TStateMachine(fsm)
   111  	runtime.KeepAlive(desc)
   112  	return
   113  }
   114  
   115  func isJsonString(val string) bool {
   116  	if len(val) < 2 {
   117  		return false
   118  	}
   119  	
   120  	c := json.SkipBlank(val, 0) 
   121  	if c < 0 {
   122  		return false
   123  	}
   124  	s := val[c]
   125  	e := val[len(val)-1] //FIXME: may need exist blank
   126  	return (s == '{' && e == '}') || (s == '[' && e == ']') || (s == '"' && e == '"')
   127  }
   128  
   129  func (self *BinaryConv) writeStringValue(ctx context.Context, buf *[]byte, f *thrift.FieldDescriptor, val string, enc meta.Encoding, req http.RequestGetter) error {
   130  	p := thrift.BinaryProtocol{Buf: *buf}
   131  	if val == "" {
   132  		if !self.opts.WriteRequireField && f.Required() == thrift.RequiredRequireness {
   133  			// requred field not found, return error
   134  			return newError(meta.ErrMissRequiredField, fmt.Sprintf("required field '%s' not found", f.Name()), nil)
   135  		}
   136  		if !self.opts.WriteOptionalField && f.Required() == thrift.OptionalRequireness {
   137  			return nil
   138  		}
   139  		if !self.opts.WriteDefaultField && f.Required() == thrift.DefaultRequireness {
   140  			return nil
   141  		}
   142  		if err := p.WriteFieldBegin(f.Name(), f.Type().Type(), f.ID()); err != nil {
   143  			return newError(meta.ErrWrite, fmt.Sprintf("failed to write field '%s' tag", f.Name()), err)
   144  		}
   145  		if err := p.WriteDefaultOrEmpty(f); err != nil {
   146  			return newError(meta.ErrWrite, fmt.Sprintf("failed to write empty value of field '%s'", f.Name()), err)
   147  		}
   148  	} else {
   149  		if err := p.WriteFieldBegin(f.Name(), f.Type().Type(), f.ID()); err != nil {
   150  			return newError(meta.ErrWrite, fmt.Sprintf("failed to write field '%s' tag", f.Name()), err)
   151  		}
   152  		// not http-encoded value, write directly into buf
   153  		if enc == meta.EncodingThriftBinary {
   154  			p.Buf = append(p.Buf, val...)
   155  			goto BACK
   156  		} else if enc == meta.EncodingText || !f.Type().Type().IsComplex() || !isJsonString(val) {
   157  			if err := p.WriteStringWithDesc(val, f.Type(), self.opts.DisallowUnknownField, !self.opts.NoBase64Binary); err != nil {
   158  				return newError(meta.ErrConvert, fmt.Sprintf("failed to write field '%s' value", f.Name()), err)
   159  			}
   160  		} else if enc == meta.EncodingJSON {
   161  			// for nested type, we regard it as a json string and convert it directly
   162  			if err := self.doNative(ctx, rt.Str2Mem(val), f.Type(), &p.Buf, req, false); err != nil {
   163  				return newError(meta.ErrConvert, fmt.Sprintf("failed to convert value of field '%s'", f.Name()), err)
   164  			}
   165  		// try text encoding, see thrift.EncodeText
   166  		} else {
   167  			return newError(meta.ErrConvert, fmt.Sprintf("unsupported http-mapping encoding %v for '%s'", enc, f.Name()), nil)
   168  		}
   169  	}
   170  BACK:
   171  	*buf = p.Buf
   172  	return nil
   173  }
   174  
   175  func writeRequestBaseToThrift(ctx context.Context, buf *[]byte, field *thrift.FieldDescriptor) error {
   176  	bobj := ctx.Value(conv.CtxKeyThriftReqBase)
   177  	if bobj != nil {
   178  		if b, ok := bobj.(*base.Base); ok && b != nil {
   179  			l := len(*buf)
   180  			n := b.BLength()
   181  			rt.GuardSlice(buf, 3+n)
   182  			*buf = (*buf)[:l+3]
   183  			thrift.BinaryEncoding{}.EncodeFieldBegin((*buf)[l:l+3], field.Type().Type(), field.ID())
   184  			l = len(*buf)
   185  			*buf = (*buf)[:l+n]
   186  			b.FastWrite((*buf)[l : l+n])
   187  		}
   188  	}
   189  	return nil
   190  }
   191  
   192  func (self *BinaryConv) writeHttpRequestToThrift(ctx context.Context, req http.RequestGetter, desc *thrift.StructDescriptor, reqs thrift.RequiresBitmap, buf *[]byte, nobody bool, top bool) (err error) {
   193  	if req == nil {
   194  		return newError(meta.ErrInvalidParam, "http request is nil", nil)
   195  	}
   196  	fs := desc.HttpMappingFields()
   197  	for _, f := range fs {
   198  		var ok bool
   199  		var val string
   200  		var httpEnc meta.Encoding
   201  		// loop http mapping until first non-null value
   202  		for _, hm := range f.HTTPMappings() {
   203  			v, err := hm.Request(ctx, req, f)
   204  			if err == nil {
   205  				httpEnc = hm.Encoding()
   206  				ok = true
   207  				val = v
   208  				break
   209  			}
   210  		}
   211  		if !ok {
   212  			// no json body, check if return error
   213  			if nobody {
   214  				if f.Required() == thrift.RequiredRequireness && !self.opts.WriteRequireField {
   215  					return newError(meta.ErrNotFound, fmt.Sprintf("not found http value of field %d:'%s'", f.ID(), f.Name()), nil)
   216  				}
   217  				if !self.opts.WriteDefaultField && f.Required() == thrift.DefaultRequireness {
   218  					continue
   219  				}
   220  				if !self.opts.WriteOptionalField && f.Required() == thrift.OptionalRequireness {
   221  					continue
   222  				}
   223  			} else {
   224  				// NOTICE: if no value found, tracebak on current json layeer to find value
   225  				// it must be a top level field or required field
   226  				if self.opts.ReadHttpValueFallback {
   227  					reqs.Set(f.ID(), thrift.RequiredRequireness)
   228  					continue
   229  				}
   230  			}
   231  		}
   232  
   233  		reqs.Set(f.ID(), thrift.OptionalRequireness)
   234  		if err := self.writeStringValue(ctx, buf, f, val, httpEnc, req); err != nil {
   235  			return err
   236  		}
   237  	}
   238  
   239  	// p.Recycle()
   240  	return
   241  }
   242  
   243  func (self *BinaryConv) handleUnmatchedFields(ctx context.Context, fsm *types.J2TStateMachine, desc *thrift.StructDescriptor, buf *[]byte, pos int, req http.RequestGetter, top bool) (bool, error) {
   244  	if req == nil {
   245  		return false, newError(meta.ErrInvalidParam, "http request is nil", nil)
   246  	}
   247  
   248  	// write unmatched fields
   249  	for _, id := range fsm.FieldCache {
   250  		f := desc.FieldById(thrift.FieldID(id))
   251  		if f == nil {
   252  			if self.opts.DisallowUnknownField {
   253  				return false, newError(meta.ErrConvert, fmt.Sprintf("unknown field id %d", id), nil)
   254  			}
   255  			continue
   256  		}
   257  		var val string
   258  		var enc = meta.EncodingText
   259  		if self.opts.TracebackRequredOrRootFields && (top || f.Required() == thrift.RequiredRequireness) {
   260  			// try get value from http
   261  			val, _, enc = tryGetValueFromHttp(req, f.Alias())
   262  		}
   263  		if err := self.writeStringValue(ctx, buf, f, val, enc, req); err != nil {
   264  			return false, err
   265  		}
   266  	}
   267  
   268  	// write STRUCT end
   269  	*buf = append(*buf, byte(thrift.STOP))
   270  
   271  	// clear field cache
   272  	fsm.FieldCache = fsm.FieldCache[:0]
   273  	// drop current J2T state
   274  	fsm.SP--
   275  	if fsm.SP > 0 {
   276  		// NOTICE: if j2t_exec haven't finished, we should set current position to next json
   277  		fsm.SetPos(pos)
   278  		return true, nil
   279  	} else {
   280  		return false, nil
   281  	}
   282  }
   283  
   284  // searching sequence: url -> [post] -> query -> header -> [body root]
   285  func tryGetValueFromHttp(req http.RequestGetter, key string) (string, bool, meta.Encoding) {
   286  	if req == nil {
   287  		return "", false, meta.EncodingText
   288  	}
   289  	if v := req.GetParam(key); v != "" {
   290  		return v, true, meta.EncodingJSON
   291  	}
   292  	if v := req.GetQuery(key); v != "" {
   293  		return v, true, meta.EncodingJSON
   294  	}
   295  	if v := req.GetHeader(key); v != "" {
   296  		return v, true, meta.EncodingJSON
   297  	}
   298  	if v := req.GetCookie(key); v != "" {
   299  		return v, true, meta.EncodingJSON
   300  	}
   301  	if v := req.GetMapBody(key); v != "" {
   302  		return v, true, meta.EncodingJSON
   303  	}
   304  	return "", false, meta.EncodingText
   305  }
   306  
   307  func (self *BinaryConv) handleValueMapping(ctx context.Context, fsm *types.J2TStateMachine, desc *thrift.StructDescriptor, buf *[]byte, pos int, src []byte) (bool, error) {
   308  	p := thrift.BinaryProtocol{}
   309  	p.Buf = *buf
   310  
   311  	// write unmatched fields
   312  	fv := fsm.FieldValueCache
   313  	f := desc.FieldById(thrift.FieldID(fv.FieldID))
   314  	if f == nil {
   315  		return false, newError(meta.ErrConvert, fmt.Sprintf("unknown field id %d for value-mapping", fv.FieldID), nil)
   316  	}
   317  	// write field tag
   318  	if err := p.WriteFieldBegin(f.Name(), f.Type().Type(), f.ID()); err != nil {
   319  		return false, newError(meta.ErrWrite, fmt.Sprintf("failed to write field '%s' tag", f.Name()), err)
   320  	}
   321  	// truncate src to value
   322  	if int(fv.ValEnd) >= len(src) || int(fv.ValBegin) < 0 {
   323  		return false, newError(meta.ErrConvert, "invalid value-mapping position", nil)
   324  	}
   325  	jdata := src[fv.ValBegin:fv.ValEnd]
   326  	// call ValueMapping interface
   327  	if err := f.ValueMapping().Write(ctx, &p, f, jdata); err != nil {
   328  		return false, newError(meta.ErrConvert, fmt.Sprintf("failed to convert field '%s' value", f.Name()), err)
   329  	}
   330  
   331  	*buf = p.Buf
   332  	// clear field cache
   333  	// fsm.FieldValueCache = fsm.FieldValueCache[:0]
   334  
   335  	// drop current J2T state
   336  	fsm.SP--
   337  	if fsm.SP > 0 {
   338  		// NOTICE: if j2t_exec haven't finished, we should set current position to next json
   339  		fsm.SetPos(pos)
   340  		return true, nil
   341  	} else {
   342  		return false, nil
   343  	}
   344  }