github.com/ryicoh/apery-graphql@v0.0.0-20210919090814-a8c219904bee/pkg/apery/client.go (about) 1 package apery 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "os/exec" 10 "strconv" 11 "strings" 12 "time" 13 ) 14 15 type ( 16 AperyClient interface { 17 Evaluate(ctx context.Context, sfen string, moves []string, timeout time.Duration) (result *EvaluationResult, err error) 18 } 19 20 aperyClient struct { 21 bin string 22 } 23 24 EvaluationResult struct { 25 Value int 26 Nodes int 27 Depth int 28 Bestmove string 29 Pv []string 30 } 31 ) 32 33 func NewAperyClient(bin string) AperyClient { 34 return &aperyClient{bin} 35 } 36 37 func (a *aperyClient) Evaluate(ctx context.Context, sfen string, moves []string, timeout time.Duration) (result *EvaluationResult, err error) { 38 cmd := exec.CommandContext(ctx, a.bin) 39 var stdout, stderr bytes.Buffer 40 cmd.Stdout = &stdout 41 cmd.Stderr = &stderr 42 43 stdin, err := cmd.StdinPipe() 44 if err != nil { 45 return nil, err 46 } 47 defer stdin.Close() 48 49 if err := cmd.Start(); err != nil { 50 return nil, err 51 } 52 53 if err := a.isReady(stdin, &stdout); err != nil { 54 return nil, err 55 } 56 57 if err := a.setPosition(stdin, &stdout, sfen, moves); err != nil { 58 return nil, err 59 } 60 61 if err := a._go(stdin); err != nil { 62 return nil, err 63 } 64 65 time.Sleep(timeout) 66 67 if err := a.stop(stdin); err != nil { 68 return nil, err 69 } 70 71 result, err = a.getResult(&stdout) 72 if err != nil { 73 return nil, err 74 } 75 76 return result, nil 77 } 78 79 func (a *aperyClient) isReady(stdin io.Writer, stdout io.Reader) error { 80 if _, err := io.WriteString(stdin, "isready\n"); err != nil { 81 return err 82 } 83 84 res, err := a.waitResponse(stdout, 1000, 100*time.Millisecond) 85 if err != nil || res != "readyok\n" { 86 return fmt.Errorf("apery がisreadyに対して%d秒以内にreadyokを返しません", 10) 87 } 88 return nil 89 } 90 91 func (a *aperyClient) waitResponse(stdout io.Reader, attemptLimit int, interval time.Duration) (string, error) { 92 for i := 0; i < attemptLimit; i++ { 93 bytes, err := io.ReadAll(stdout) 94 if err != nil { 95 return "", err 96 } 97 98 if len(bytes) == 0 { 99 time.Sleep(interval) 100 continue 101 } 102 103 return string(bytes), nil 104 } 105 106 return "", errors.New("attempt limit exceeded") 107 } 108 109 func (a *aperyClient) setPosition(stdin io.Writer, stdout io.Reader, sfen string, moves []string) error { 110 if _, err := io.WriteString( 111 stdin, fmt.Sprintf("position sfen %s moves %s\n", sfen, strings.Join(moves, " "))); err != nil { 112 return err 113 } 114 time.Sleep(100 * time.Millisecond) 115 bytes, err := io.ReadAll(stdout) 116 if err != nil { 117 return err 118 } 119 120 if len(bytes) != 0 { 121 return errors.New(string(bytes)) 122 } 123 124 return nil 125 } 126 127 func (a *aperyClient) _go(stdin io.Writer) error { 128 if _, err := io.WriteString(stdin, "go\n"); err != nil { 129 return err 130 } 131 return nil 132 } 133 134 func (a *aperyClient) stop(stdin io.Writer) error { 135 if _, err := io.WriteString(stdin, "stop\n"); err != nil { 136 return err 137 } 138 return nil 139 } 140 141 func (a *aperyClient) getResult(stdout io.Reader) (result *EvaluationResult, err error) { 142 result = new(EvaluationResult) 143 144 res, err := a.waitResponse(stdout, 10, 100*time.Millisecond) 145 logs := strings.TrimRight(string(res), "\n") 146 lines := strings.Split(logs, "\n") 147 148 if err != nil || !strings.Contains(logs, "bestmove") || len(lines) <= 2 { 149 return nil, fmt.Errorf("bestmoveが得られません") 150 } 151 152 bestmoveline := lines[len(lines)-1] 153 result.Bestmove = strings.Split(bestmoveline, " ")[1] 154 155 lastInfoLine := lines[len(lines)-2] 156 lastInfoLineParts := strings.Split(lastInfoLine, " ") 157 result.Pv = make([]string, 0, 10) 158 159 for i, part := range lastInfoLineParts { 160 if part == "cp" { 161 result.Value, err = strconv.Atoi(lastInfoLineParts[i+1]) 162 if err != nil { 163 return nil, err 164 } 165 } 166 167 if part == "depth" { 168 result.Depth, err = strconv.Atoi(lastInfoLineParts[i+1]) 169 if err != nil { 170 return nil, err 171 } 172 } 173 174 if part == "nodes" { 175 result.Nodes, err = strconv.Atoi(lastInfoLineParts[i+1]) 176 if err != nil { 177 return nil, err 178 } 179 } 180 181 if part == "pv" { 182 for _, p := range lastInfoLineParts[i+1:] { 183 result.Pv = append(result.Pv, p) 184 } 185 } 186 } 187 188 return result, nil 189 }