github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/transform/staticlark/control_flow_test.go (about) 1 package staticlark 2 3 import ( 4 "github.com/google/go-cmp/cmp" 5 "go.starlark.net/syntax" 6 "testing" 7 ) 8 9 func TestControlFlowIf(t *testing.T) { 10 // a simple function with an if/else 11 funcmap := mustReadScriptFunctionMap(t, "testdata/some_funcs.star") 12 cf, err := newControlFlowFromFunc(funcmap["use_branch"]) 13 if err != nil { 14 t.Fatal(err) 15 } 16 17 expect := `0: [set! a 1] 18 [set! b 2] 19 out: 1 20 1: [if [< a b]] 21 out: 2,3, join: 4 22 2: [set! c [+ b 1]] 23 out: 4 24 3: [set! c [+ a 1]] 25 out: 4 26 4: [print [% '%d' c]] 27 out: - 28 ` 29 actual := cf.stringify() 30 if diff := cmp.Diff(expect, actual); diff != "" { 31 t.Errorf("mismatch (-want +got):\n%s", diff) 32 } 33 34 // a function with if, but no else 35 cf, err = newControlFlowFromFunc(funcmap["branch_no_else"]) 36 if err != nil { 37 t.Fatal(err) 38 } 39 40 expect = `0: [set! a 1] 41 [set! b 2] 42 out: 1 43 1: [if [< a b]] 44 out: 2,3, join: 3 45 2: [set! c [+ b 1]] 46 [print [% '%d' c]] 47 out: 3 48 3: [print [% '%d' b]] 49 out: - 50 ` 51 52 actual = cf.stringify() 53 if diff := cmp.Diff(expect, actual); diff != "" { 54 t.Errorf("mismatch (-want +got):\n%s", diff) 55 } 56 57 // a function with if nested within an if 58 cf, err = newControlFlowFromFunc(funcmap["branch_nested"]) 59 if err != nil { 60 t.Fatal(err) 61 } 62 63 expect = `0: [set! a 1] 64 [set! b 2] 65 out: 1 66 1: [if [< a b]] 67 out: 2,5, join: 6 68 2: [set! c [+ b 1]] 69 [set! d a] 70 out: 3 71 3: [if [> d c]] 72 out: 4,6, join: 6 73 4: [set! c [+ d 2]] 74 out: 6 75 5: [set! c [+ a 1]] 76 [print c] 77 [set! e [+ c 2]] 78 out: 6 79 6: [print [% '%d' e]] 80 out: - 81 ` 82 actual = cf.stringify() 83 if diff := cmp.Diff(expect, actual); diff != "" { 84 t.Errorf("mismatch (-want +got):\n%s", diff) 85 } 86 87 // a function with if and elif and else 88 funcmap = mustReadScriptFunctionMap(t, "testdata/some_funcs.star") 89 cf, err = newControlFlowFromFunc(funcmap["branch_elses"]) 90 if err != nil { 91 t.Fatal(err) 92 } 93 94 expect = `0: [set! a 1] 95 [set! b 2] 96 out: 1 97 1: [if [< a b]] 98 out: 2,8, join: 9 99 2: [set! c [+ b 1]] 100 out: 3 101 3: [if [< c 1]] 102 out: 4,5, join: 9 103 4: [print 'small'] 104 out: 9 105 5: [if [< c 5]] 106 out: 6,7, join: 9 107 6: [print 'medium'] 108 out: 9 109 7: [print 'large'] 110 out: 9 111 8: [print 'ok'] 112 out: 9 113 9: [print 'done'] 114 out: - 115 ` 116 actual = cf.stringify() 117 if diff := cmp.Diff(expect, actual); diff != "" { 118 t.Errorf("mismatch (-want +got):\n%s", diff) 119 } 120 121 // this function has another statement (block 8) which ensures 122 // that the inner if statement is completely contained in the outer 123 funcmap = mustReadScriptFunctionMap(t, "testdata/some_funcs.star") 124 cf, err = newControlFlowFromFunc(funcmap["branch_elses_contained"]) 125 if err != nil { 126 t.Fatal(err) 127 } 128 129 expect = `0: [set! a 1] 130 [set! b 2] 131 out: 1 132 1: [if [< a b]] 133 out: 2,9, join: 10 134 2: [set! c [+ b 1]] 135 out: 3 136 3: [if [< c 1]] 137 out: 4,5, join: 8 138 4: [print 'small'] 139 out: 8 140 5: [if [< c 5]] 141 out: 6,7, join: 8 142 6: [print 'medium'] 143 out: 8 144 7: [print 'large'] 145 out: 8 146 8: [print 'sized'] 147 out: 10 148 9: [print 'ok'] 149 out: 10 150 10: [print 'done'] 151 out: - 152 ` 153 actual = cf.stringify() 154 if diff := cmp.Diff(expect, actual); diff != "" { 155 t.Errorf("mismatch (-want +got):\n%s", diff) 156 } 157 } 158 159 func TestControlFlowSimpleLoop(t *testing.T) { 160 funcmap := mustReadScriptFunctionMap(t, "testdata/loop_funcs.star") 161 162 cf, err := newControlFlowFromFunc(funcmap["stddev"]) 163 if err != nil { 164 t.Fatal(err) 165 } 166 167 expect := `0: [set! total 0] 168 out: 1 169 1: [for x ls] 170 out: 2,3 171 2: [set! total [+= total x]] 172 out: 1 173 3: [set! n [len ls]] 174 [set! mean [/ total n]] 175 [set! result 0] 176 out: 4 177 4: [for x ls] 178 out: 5,6 179 5: [set! diff [- x mean]] 180 [set! result [+= result [* diff diff]]] 181 out: 4 182 6: [set! variance [/ result n]] 183 [return [math.sqrt variance]] 184 out: return 185 ` 186 actual := cf.stringify() 187 if diff := cmp.Diff(expect, actual); diff != "" { 188 t.Errorf("mismatch (-want +got):\n%s", diff) 189 } 190 } 191 192 func TestControlFlowLoopWithBreak(t *testing.T) { 193 funcmap := mustReadScriptFunctionMap(t, "testdata/loop_funcs.star") 194 195 cf, err := newControlFlowFromFunc(funcmap["gcd_debug"]) 196 if err != nil { 197 t.Fatal(err) 198 } 199 200 // TODO(dustmop): `break` should instead be 9 201 expect := `0: [print "gcd starting"] 202 out: 1 203 1: [for n [range 20]] 204 out: 2,9 205 2: [print "gcd a = %d, b = %d" a b] 206 out: 3 207 3: [if [== a b]] 208 out: 4,5 209 4: [print "gcd break at step %d" n] 210 [break] 211 out: break 212 5: [print "still going"] 213 out: 6 214 6: [if [> a b]] 215 out: 7,8 216 7: [set! a [- a b]] 217 out: 1 218 8: [set! b [- b a]] 219 out: 1 220 9: [print "gcd returns %d" a] 221 [return a] 222 out: return 223 ` 224 actual := cf.stringify() 225 if diff := cmp.Diff(expect, actual); diff != "" { 226 t.Errorf("mismatch (-want +got):\n%s", diff) 227 } 228 } 229 230 func mustReadScriptFunctionMap(t *testing.T, filename string) map[string]*funcNode { 231 f, err := syntax.Parse(filename, nil, 0) 232 if err != nil { 233 t.Fatal(err) 234 } 235 // Collect function definitions and top level function calls 236 funcs, _, err := collectFuncDefsTopLevelCalls(f.Stmts) 237 if err != nil { 238 t.Fatal(err) 239 } 240 fmap := make(map[string]*funcNode) 241 for _, f := range funcs { 242 fmap[f.name] = f 243 } 244 return fmap 245 } 246 247 func TestUnitBasic(t *testing.T) { 248 root := unit{ 249 atom: "set!", 250 tail: []*unit{ 251 &unit{atom: "a"}, 252 &unit{atom: "b"}, 253 }, 254 } 255 actual := root.String() 256 expect := `[set! a b]` 257 if diff := cmp.Diff(expect, actual); diff != "" { 258 t.Errorf("mismatch (-want +got):\n%s", diff) 259 } 260 261 actualSrc := root.DataSources() 262 expectSrc := []string{"b"} 263 if diff := cmp.Diff(expectSrc, actualSrc); diff != "" { 264 t.Errorf("sources mismatch (-want +got):\n%s", diff) 265 } 266 }