github.com/grumpyhome/grumpy@v0.3.1-0.20201208125205-7b775405bdf1/grumpy-runtime-src/runtime/complex.go (about)

     1  // Copyright 2016 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grumpy
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"math"
    21  	"math/cmplx"
    22  	"reflect"
    23  	"regexp"
    24  	"strconv"
    25  	"strings"
    26  )
    27  
// ComplexType is the object representing the Python 'complex' type. It is
// built by newBasisType from the name "complex", the reflected layout of the
// Complex struct, the toComplexUnsafe conversion function and ObjectType.
var ComplexType = newBasisType("complex", reflect.TypeOf(Complex{}), toComplexUnsafe, ObjectType)
    30  
// Complex represents Python 'complex' objects.
type Complex struct {
	// Object is embedded so that a *Complex can be upcast with ToObject.
	Object
	// value is the wrapped Go complex number; it is exposed through Value.
	value complex128
}
    36  
    37  // NewComplex returns a new Complex holding the given complex value.
    38  func NewComplex(value complex128) *Complex {
    39  	return &Complex{Object{typ: ComplexType}, value}
    40  }
    41  
    42  func toComplexUnsafe(o *Object) *Complex {
    43  	return (*Complex)(o.toPointer())
    44  }
    45  
    46  // ToObject upcasts c to an Object.
    47  func (c *Complex) ToObject() *Object {
    48  	return &c.Object
    49  }
    50  
    51  // Value returns the underlying complex value held by c.
    52  func (c *Complex) Value() complex128 {
    53  	return c.value
    54  }
    55  
    56  func complexAbs(f *Frame, o *Object) (*Object, *BaseException) {
    57  	c := toComplexUnsafe(o).Value()
    58  	return NewFloat(cmplx.Abs(c)).ToObject(), nil
    59  }
    60  
    61  func complexAdd(f *Frame, v, w *Object) (*Object, *BaseException) {
    62  	return complexArithmeticOp(f, "__add__", v, w, func(lhs, rhs complex128) complex128 {
    63  		return lhs + rhs
    64  	})
    65  }
    66  
    67  func complexCompareNotSupported(f *Frame, v, w *Object) (*Object, *BaseException) {
    68  	if w.isInstance(IntType) || w.isInstance(LongType) || w.isInstance(FloatType) || w.isInstance(ComplexType) {
    69  		return nil, f.RaiseType(TypeErrorType, "no ordering relation is defined for complex numbers")
    70  	}
    71  	return NotImplemented, nil
    72  }
    73  
    74  func complexComplex(f *Frame, o *Object) (*Object, *BaseException) {
    75  	return o, nil
    76  }
    77  
    78  func complexDiv(f *Frame, v, w *Object) (*Object, *BaseException) {
    79  	return complexDivModOp(f, "__div__", v, w, func(v, w complex128) (complex128, bool) {
    80  		if w == 0 {
    81  			return 0, false
    82  		}
    83  		return v / w, true
    84  	})
    85  }
    86  
    87  func complexDivMod(f *Frame, v, w *Object) (*Object, *BaseException) {
    88  	return complexDivAndModOp(f, "__divmod__", v, w, func(v, w complex128) (complex128, complex128, bool) {
    89  		if w == 0 {
    90  			return 0, 0, false
    91  		}
    92  		return complexFloorDivOp(v, w), complexModOp(v, w), true
    93  	})
    94  }
    95  
    96  func complexEq(f *Frame, v, w *Object) (*Object, *BaseException) {
    97  	e, ok := complexCompare(toComplexUnsafe(v), w)
    98  	if !ok {
    99  		return NotImplemented, nil
   100  	}
   101  	return GetBool(e).ToObject(), nil
   102  }
   103  
   104  func complexFloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) {
   105  	return complexDivModOp(f, "__floordiv__", v, w, func(v, w complex128) (complex128, bool) {
   106  		if w == 0 {
   107  			return 0, false
   108  		}
   109  		return complexFloorDivOp(v, w), true
   110  	})
   111  }
   112  
   113  func complexHash(f *Frame, o *Object) (*Object, *BaseException) {
   114  	v := toComplexUnsafe(o).Value()
   115  	hashCombined := hashFloat(real(v)) + 1000003*hashFloat(imag(v))
   116  	if hashCombined == -1 {
   117  		hashCombined = -2
   118  	}
   119  	return NewInt(hashCombined).ToObject(), nil
   120  }
   121  
   122  func complexMod(f *Frame, v, w *Object) (*Object, *BaseException) {
   123  	return complexDivModOp(f, "__mod__", v, w, func(v, w complex128) (complex128, bool) {
   124  		if w == 0 {
   125  			return 0, false
   126  		}
   127  		return complexModOp(v, w), true
   128  	})
   129  }
   130  
   131  func complexMul(f *Frame, v, w *Object) (*Object, *BaseException) {
   132  	return complexArithmeticOp(f, "__mul__", v, w, func(lhs, rhs complex128) complex128 {
   133  		return lhs * rhs
   134  	})
   135  }
   136  
   137  func complexNE(f *Frame, v, w *Object) (*Object, *BaseException) {
   138  	e, ok := complexCompare(toComplexUnsafe(v), w)
   139  	if !ok {
   140  		return NotImplemented, nil
   141  	}
   142  	return GetBool(!e).ToObject(), nil
   143  }
   144  
   145  func complexNeg(f *Frame, o *Object) (*Object, *BaseException) {
   146  	c := toComplexUnsafe(o).Value()
   147  	return NewComplex(-c).ToObject(), nil
   148  }
   149  
   150  func complexNew(f *Frame, t *Type, args Args, _ KWArgs) (*Object, *BaseException) {
   151  	argc := len(args)
   152  	if argc == 0 {
   153  		return newObject(t), nil
   154  	}
   155  	if argc > 2 {
   156  		return nil, f.RaiseType(TypeErrorType, "'__new__' of 'complex' requires at most 2 arguments")
   157  	}
   158  	if t != ComplexType {
   159  		// Allocate a plain complex then copy it's value into an object
   160  		// of the complex subtype.
   161  		x, raised := complexNew(f, ComplexType, args, nil)
   162  		if raised != nil {
   163  			return nil, raised
   164  		}
   165  		result := toComplexUnsafe(newObject(t))
   166  		result.value = toComplexUnsafe(x).Value()
   167  		return result.ToObject(), nil
   168  	}
   169  	if complexSlot := args[0].typ.slots.Complex; complexSlot != nil && argc == 1 {
   170  		c, raised := complexConvert(complexSlot, f, args[0])
   171  		if raised != nil {
   172  			return nil, raised
   173  		}
   174  		return c.ToObject(), nil
   175  	}
   176  	if args[0].isInstance(StrType) {
   177  		if argc > 1 {
   178  			return nil, f.RaiseType(TypeErrorType, "complex() can't take second arg if first is a string")
   179  		}
   180  		s := toStrUnsafe(args[0]).Value()
   181  		result, err := parseComplex(s)
   182  		if err != nil {
   183  			return nil, f.RaiseType(ValueErrorType, "complex() arg is a malformed string")
   184  		}
   185  		return NewComplex(result).ToObject(), nil
   186  	}
   187  	if argc > 1 && args[1].isInstance(StrType) {
   188  		return nil, f.RaiseType(TypeErrorType, "complex() second arg can't be a string")
   189  	}
   190  	cr, raised := complex128Convert(f, args[0])
   191  	if raised != nil {
   192  		return nil, raised
   193  	}
   194  	var ci complex128
   195  	if argc > 1 {
   196  		ci, raised = complex128Convert(f, args[1])
   197  		if raised != nil {
   198  			return nil, raised
   199  		}
   200  	}
   201  
   202  	// Logically it should be enough to return this:
   203  	//  NewComplex(cr + ci*1i).ToObject()
   204  	// But Go complex arithmatic is not satisfying all conditions, for instance:
   205  	//  cr := complex(math.Inf(1), 0)
   206  	//  ci := complex(math.Inf(-1), 0)
   207  	//  fmt.Println(cr + ci*1i)
   208  	// Output is (NaN-Infi), instead of (+Inf-Infi).
   209  	return NewComplex(complex(real(cr)-imag(ci), imag(cr)+real(ci))).ToObject(), nil
   210  }
   211  
   212  func complexNonZero(f *Frame, o *Object) (*Object, *BaseException) {
   213  	return GetBool(toComplexUnsafe(o).Value() != 0).ToObject(), nil
   214  }
   215  
   216  func complexPos(f *Frame, o *Object) (*Object, *BaseException) {
   217  	return o, nil
   218  }
   219  
   220  func complexPow(f *Frame, v, w *Object) (*Object, *BaseException) {
   221  	return complexArithmeticOp(f, "__pow__", v, w, func(lhs, rhs complex128) complex128 {
   222  		return cmplx.Pow(lhs, rhs)
   223  	})
   224  }
   225  
   226  func complexRAdd(f *Frame, v, w *Object) (*Object, *BaseException) {
   227  	return complexArithmeticOp(f, "__radd__", v, w, func(lhs, rhs complex128) complex128 {
   228  		return lhs + rhs
   229  	})
   230  }
   231  
   232  func complexRDiv(f *Frame, v, w *Object) (*Object, *BaseException) {
   233  	return complexDivModOp(f, "__rdiv__", v, w, func(v, w complex128) (complex128, bool) {
   234  		if v == 0 {
   235  			return 0, false
   236  		}
   237  		return w / v, true
   238  	})
   239  }
   240  
   241  func complexRDivMod(f *Frame, v, w *Object) (*Object, *BaseException) {
   242  	return complexDivAndModOp(f, "__rdivmod__", v, w, func(v, w complex128) (complex128, complex128, bool) {
   243  		if v == 0 {
   244  			return 0, 0, false
   245  		}
   246  		return complexFloorDivOp(w, v), complexModOp(w, v), true
   247  	})
   248  }
   249  
   250  func complexRepr(f *Frame, o *Object) (*Object, *BaseException) {
   251  	c := toComplexUnsafe(o).Value()
   252  	rs, is := "", ""
   253  	pre, post := "", ""
   254  	sign := ""
   255  	if real(c) == 0.0 {
   256  		is = strconv.FormatFloat(imag(c), 'g', -1, 64)
   257  	} else {
   258  		pre = "("
   259  		rs = strconv.FormatFloat(real(c), 'g', -1, 64)
   260  		is = strconv.FormatFloat(imag(c), 'g', -1, 64)
   261  		if imag(c) >= 0.0 || math.IsNaN(imag(c)) {
   262  			sign = "+"
   263  		}
   264  		post = ")"
   265  	}
   266  	rs = unsignPositiveInf(strings.ToLower(rs))
   267  	is = unsignPositiveInf(strings.ToLower(is))
   268  	return NewStr(fmt.Sprintf("%s%s%s%sj%s", pre, rs, sign, is, post)).ToObject(), nil
   269  }
   270  
   271  func complexRFloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) {
   272  	return complexDivModOp(f, "__rfloordiv__", v, w, func(v, w complex128) (complex128, bool) {
   273  		if v == 0 {
   274  			return 0, false
   275  		}
   276  		return complexFloorDivOp(w, v), true
   277  	})
   278  }
   279  
   280  func complexRMod(f *Frame, v, w *Object) (*Object, *BaseException) {
   281  	return complexDivModOp(f, "__rmod__", v, w, func(v, w complex128) (complex128, bool) {
   282  		if v == 0 {
   283  			return 0, false
   284  		}
   285  		return complexModOp(w, v), true
   286  	})
   287  }
   288  
   289  func complexRMul(f *Frame, v, w *Object) (*Object, *BaseException) {
   290  	return complexArithmeticOp(f, "__rmul__", v, w, func(lhs, rhs complex128) complex128 {
   291  		return rhs * lhs
   292  	})
   293  }
   294  
   295  func complexRPow(f *Frame, v, w *Object) (*Object, *BaseException) {
   296  	return complexArithmeticOp(f, "__rpow__", v, w, func(lhs, rhs complex128) complex128 {
   297  		return cmplx.Pow(rhs, lhs)
   298  	})
   299  }
   300  
   301  func complexRSub(f *Frame, v, w *Object) (*Object, *BaseException) {
   302  	return complexArithmeticOp(f, "__rsub__", v, w, func(lhs, rhs complex128) complex128 {
   303  		return rhs - lhs
   304  	})
   305  }
   306  
   307  func complexSub(f *Frame, v, w *Object) (*Object, *BaseException) {
   308  	return complexArithmeticOp(f, "__sub__", v, w, func(lhs, rhs complex128) complex128 {
   309  		return lhs - rhs
   310  	})
   311  }
   312  
   313  func initComplexType(dict map[string]*Object) {
   314  	ComplexType.slots.Abs = &unaryOpSlot{complexAbs}
   315  	ComplexType.slots.Add = &binaryOpSlot{complexAdd}
   316  	ComplexType.slots.Complex = &unaryOpSlot{complexComplex}
   317  	ComplexType.slots.Div = &binaryOpSlot{complexDiv}
   318  	ComplexType.slots.DivMod = &binaryOpSlot{complexDivMod}
   319  	ComplexType.slots.Eq = &binaryOpSlot{complexEq}
   320  	ComplexType.slots.FloorDiv = &binaryOpSlot{complexFloorDiv}
   321  	ComplexType.slots.GE = &binaryOpSlot{complexCompareNotSupported}
   322  	ComplexType.slots.GT = &binaryOpSlot{complexCompareNotSupported}
   323  	ComplexType.slots.Hash = &unaryOpSlot{complexHash}
   324  	ComplexType.slots.LE = &binaryOpSlot{complexCompareNotSupported}
   325  	ComplexType.slots.LT = &binaryOpSlot{complexCompareNotSupported}
   326  	ComplexType.slots.Mod = &binaryOpSlot{complexMod}
   327  	ComplexType.slots.Mul = &binaryOpSlot{complexMul}
   328  	ComplexType.slots.NE = &binaryOpSlot{complexNE}
   329  	ComplexType.slots.Neg = &unaryOpSlot{complexNeg}
   330  	ComplexType.slots.New = &newSlot{complexNew}
   331  	ComplexType.slots.NonZero = &unaryOpSlot{complexNonZero}
   332  	ComplexType.slots.Pos = &unaryOpSlot{complexPos}
   333  	ComplexType.slots.Pow = &binaryOpSlot{complexPow}
   334  	ComplexType.slots.RAdd = &binaryOpSlot{complexRAdd}
   335  	ComplexType.slots.RDiv = &binaryOpSlot{complexRDiv}
   336  	ComplexType.slots.RDivMod = &binaryOpSlot{complexRDivMod}
   337  	ComplexType.slots.RFloorDiv = &binaryOpSlot{complexRFloorDiv}
   338  	ComplexType.slots.Repr = &unaryOpSlot{complexRepr}
   339  	ComplexType.slots.RMod = &binaryOpSlot{complexRMod}
   340  	ComplexType.slots.RMul = &binaryOpSlot{complexRMul}
   341  	ComplexType.slots.RPow = &binaryOpSlot{complexRPow}
   342  	ComplexType.slots.RSub = &binaryOpSlot{complexRSub}
   343  	ComplexType.slots.Sub = &binaryOpSlot{complexSub}
   344  }
   345  
   346  func complex128Convert(f *Frame, o *Object) (complex128, *BaseException) {
   347  	if complexSlot := o.typ.slots.Complex; complexSlot != nil {
   348  		c, raised := complexConvert(complexSlot, f, o)
   349  		if raised != nil {
   350  			return complex(0, 0), raised
   351  		}
   352  		return c.Value(), nil
   353  	} else if floatSlot := o.typ.slots.Float; floatSlot != nil {
   354  		result, raised := floatConvert(floatSlot, f, o)
   355  		if raised != nil {
   356  			return complex(0, 0), raised
   357  		}
   358  		return complex(result.Value(), 0), nil
   359  	} else {
   360  		return complex(0, 0), f.RaiseType(TypeErrorType, "complex() argument must be a string or a number")
   361  	}
   362  }
   363  
   364  func complexArithmeticOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) complex128) (*Object, *BaseException) {
   365  	if w.isInstance(ComplexType) {
   366  		return NewComplex(fun(toComplexUnsafe(v).Value(), toComplexUnsafe(w).Value())).ToObject(), nil
   367  	}
   368  
   369  	floatW, ok := floatCoerce(w)
   370  	if !ok {
   371  		if math.IsInf(floatW, 0) {
   372  			return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to float")
   373  		}
   374  		return NotImplemented, nil
   375  	}
   376  	return NewComplex(fun(toComplexUnsafe(v).Value(), complex(floatW, 0))).ToObject(), nil
   377  }
   378  
   379  // complexCoerce will coerce any numeric type to a complex. If all is
   380  // well, it will return the complex128 value, and true (OK). If an overflow
   381  // occurs, it will return either (+Inf, false) or (-Inf, false) depending
   382  // on whether the source value was too large or too small. Note that if the
   383  // source number is an infinite float, the result will be infinite without
   384  // overflow, (+-Inf, true).
   385  // If the input is not a number, it will return (0, false).
   386  func complexCoerce(o *Object) (complex128, bool) {
   387  	if o.isInstance(ComplexType) {
   388  		return toComplexUnsafe(o).Value(), true
   389  	}
   390  	floatO, ok := floatCoerce(o)
   391  	if !ok {
   392  		if math.IsInf(floatO, 0) {
   393  			return complex(floatO, 0.0), false
   394  		}
   395  		return 0, false
   396  	}
   397  	return complex(floatO, 0.0), true
   398  }
   399  
   400  func complexCompare(v *Complex, w *Object) (bool, bool) {
   401  	lhsr := real(v.Value())
   402  	rhs, ok := complexCoerce(w)
   403  	if !ok {
   404  		return false, false
   405  	}
   406  	return lhsr == real(rhs) && imag(v.Value()) == imag(rhs), true
   407  }
   408  
   409  func complexConvert(complexSlot *unaryOpSlot, f *Frame, o *Object) (*Complex, *BaseException) {
   410  	result, raised := complexSlot.Fn(f, o)
   411  	if raised != nil {
   412  		return nil, raised
   413  	}
   414  	if !result.isInstance(ComplexType) {
   415  		exc := fmt.Sprintf("__complex__ returned non-complex (type %s)", result.typ.Name())
   416  		return nil, f.RaiseType(TypeErrorType, exc)
   417  	}
   418  	return toComplexUnsafe(result), nil
   419  }
   420  
   421  func complexDivModOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) (complex128, bool)) (*Object, *BaseException) {
   422  	complexW, ok := complexCoerce(w)
   423  	if !ok {
   424  		if cmplx.IsInf(complexW) {
   425  			return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to complex")
   426  		}
   427  		return NotImplemented, nil
   428  	}
   429  	x, ok := fun(toComplexUnsafe(v).Value(), complexW)
   430  	if !ok {
   431  		return nil, f.RaiseType(ZeroDivisionErrorType, "complex division or modulo by zero")
   432  	}
   433  	return NewComplex(x).ToObject(), nil
   434  }
   435  
   436  func complexDivAndModOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) (complex128, complex128, bool)) (*Object, *BaseException) {
   437  	complexW, ok := complexCoerce(w)
   438  	if !ok {
   439  		if cmplx.IsInf(complexW) {
   440  			return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to complex")
   441  		}
   442  		return NotImplemented, nil
   443  	}
   444  	q, m, ok := fun(toComplexUnsafe(v).Value(), complexW)
   445  	if !ok {
   446  		return nil, f.RaiseType(ZeroDivisionErrorType, "complex division or modulo by zero")
   447  	}
   448  	return NewTuple2(NewComplex(q).ToObject(), NewComplex(m).ToObject()).ToObject(), nil
   449  }
   450  
   451  func complexFloorDivOp(v, w complex128) complex128 {
   452  	return complex(math.Floor(real(v/w)), 0)
   453  }
   454  
   455  func complexModOp(v, w complex128) complex128 {
   456  	return v - complexFloorDivOp(v, w)*w
   457  }
   458  
   459  const (
   460  	blank = iota
   461  	real1
   462  	imag1
   463  	real2
   464  	sign2
   465  	imag3
   466  	real4
   467  	sign5
   468  	onlyJ
   469  )
   470  
   471  // ParseComplex converts the string s to a complex number.
   472  // If string is well-formed (one of these forms: <float>, <float>j,
   473  // <float><signed-float>j, <float><sign>j, <sign>j or j, where <float> is
   474  // any numeric string that's acceptable by strconv.ParseFloat(s, 64)),
   475  // ParseComplex returns the respective complex128 number.
   476  func parseComplex(s string) (complex128, error) {
   477  	c := strings.Count(s, "(")
   478  	if (c > 1) || (c == 1 && strings.Count(s, ")") != 1) {
   479  		return complex(0, 0), errors.New("Malformed complex string, more than one matching parantheses")
   480  	}
   481  	ts := strings.TrimSpace(s)
   482  	ts = strings.Trim(ts, "()")
   483  	ts = strings.TrimSpace(ts)
   484  	re := `(?i)(?:(?:(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[Ee][+-]?\d+)?)|(?:infinity)|(?:nan)|(?:inf))`
   485  	fre := `[-+]?` + re
   486  	sre := `[-+]` + re
   487  	fsfj := `(?:(?P<real1>` + fre + `)(?P<imag1>` + sre + `)j)`
   488  	fsj := `(?:(?P<real2>` + fre + `)(?P<sign2>[-+])j)`
   489  	fj := `(?P<imag3>` + fre + `)j`
   490  	f := `(?P<real4>` + fre + `)`
   491  	sj := `(?P<sign5>[-+])j`
   492  	j := `(?P<onlyJ>j)`
   493  	r := regexp.MustCompile(`^(?:` + fsfj + `|` + fsj + `|` + fj + `|` + f + `|` + sj + `|` + j + `)$`)
   494  	subs := r.FindStringSubmatch(ts)
   495  	if subs == nil {
   496  		return complex(0, 0), errors.New("Malformed complex string, no mathing pattern found")
   497  	}
   498  	if subs[real1] != "" && subs[imag1] != "" {
   499  		r, _ := strconv.ParseFloat(unsignNaN(subs[real1]), 64)
   500  		i, err := strconv.ParseFloat(unsignNaN(subs[imag1]), 64)
   501  		return complex(r, i), err
   502  	}
   503  	if subs[real2] != "" && subs[sign2] != "" {
   504  		r, err := strconv.ParseFloat(unsignNaN(subs[real2]), 64)
   505  		if subs[sign2] == "-" {
   506  			return complex(r, -1), err
   507  		}
   508  		return complex(r, 1), err
   509  	}
   510  	if subs[imag3] != "" {
   511  		i, err := strconv.ParseFloat(unsignNaN(subs[imag3]), 64)
   512  		return complex(0, i), err
   513  	}
   514  	if subs[real4] != "" {
   515  		r, err := strconv.ParseFloat(unsignNaN(subs[real4]), 64)
   516  		return complex(r, 0), err
   517  	}
   518  	if subs[sign5] != "" {
   519  		if subs[sign5] == "-" {
   520  			return complex(0, -1), nil
   521  		}
   522  		return complex(0, 1), nil
   523  	}
   524  	if subs[onlyJ] != "" {
   525  		return complex(0, 1), nil
   526  	}
   527  	return complex(0, 0), errors.New("Malformed complex string")
   528  }
   529  
   530  func unsignNaN(s string) string {
   531  	ls := strings.ToLower(s)
   532  	if ls == "-nan" || ls == "+nan" {
   533  		return "nan"
   534  	}
   535  	return s
   536  }