github.com/ryicoh/apery-graphql@v0.0.0-20210919090814-a8c219904bee/pkg/apery/client.go (about)

     1  package apery
     2  
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"time"
)
    14  
    15  type (
    16  	AperyClient interface {
    17  		Evaluate(ctx context.Context, sfen string, moves []string, timeout time.Duration) (result *EvaluationResult, err error)
    18  	}
    19  
    20  	aperyClient struct {
    21  		bin string
    22  	}
    23  
    24  	EvaluationResult struct {
    25  		Value    int
    26  		Nodes    int
    27  		Depth    int
    28  		Bestmove string
    29  		Pv       []string
    30  	}
    31  )
    32  
    33  func NewAperyClient(bin string) AperyClient {
    34  	return &aperyClient{bin}
    35  }
    36  
    37  func (a *aperyClient) Evaluate(ctx context.Context, sfen string, moves []string, timeout time.Duration) (result *EvaluationResult, err error) {
    38  	cmd := exec.CommandContext(ctx, a.bin)
    39  	var stdout, stderr bytes.Buffer
    40  	cmd.Stdout = &stdout
    41  	cmd.Stderr = &stderr
    42  
    43  	stdin, err := cmd.StdinPipe()
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	defer stdin.Close()
    48  
    49  	if err := cmd.Start(); err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	if err := a.isReady(stdin, &stdout); err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	if err := a.setPosition(stdin, &stdout, sfen, moves); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	if err := a._go(stdin); err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	time.Sleep(timeout)
    66  
    67  	if err := a.stop(stdin); err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	result, err = a.getResult(&stdout)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return result, nil
    77  }
    78  
    79  func (a *aperyClient) isReady(stdin io.Writer, stdout io.Reader) error {
    80  	if _, err := io.WriteString(stdin, "isready\n"); err != nil {
    81  		return err
    82  	}
    83  
    84  	res, err := a.waitResponse(stdout, 1000, 100*time.Millisecond)
    85  	if err != nil || res != "readyok\n" {
    86  		return fmt.Errorf("apery がisreadyに対して%d秒以内にreadyokを返しません", 10)
    87  	}
    88  	return nil
    89  }
    90  
    91  func (a *aperyClient) waitResponse(stdout io.Reader, attemptLimit int, interval time.Duration) (string, error) {
    92  	for i := 0; i < attemptLimit; i++ {
    93  		bytes, err := io.ReadAll(stdout)
    94  		if err != nil {
    95  			return "", err
    96  		}
    97  
    98  		if len(bytes) == 0 {
    99  			time.Sleep(interval)
   100  			continue
   101  		}
   102  
   103  		return string(bytes), nil
   104  	}
   105  
   106  	return "", errors.New("attempt limit exceeded")
   107  }
   108  
   109  func (a *aperyClient) setPosition(stdin io.Writer, stdout io.Reader, sfen string, moves []string) error {
   110  	if _, err := io.WriteString(
   111  		stdin, fmt.Sprintf("position sfen %s moves %s\n", sfen, strings.Join(moves, " "))); err != nil {
   112  		return err
   113  	}
   114  	time.Sleep(100 * time.Millisecond)
   115  	bytes, err := io.ReadAll(stdout)
   116  	if err != nil {
   117  		return err
   118  	}
   119  
   120  	if len(bytes) != 0 {
   121  		return errors.New(string(bytes))
   122  	}
   123  
   124  	return nil
   125  }
   126  
   127  func (a *aperyClient) _go(stdin io.Writer) error {
   128  	if _, err := io.WriteString(stdin, "go\n"); err != nil {
   129  		return err
   130  	}
   131  	return nil
   132  }
   133  
   134  func (a *aperyClient) stop(stdin io.Writer) error {
   135  	if _, err := io.WriteString(stdin, "stop\n"); err != nil {
   136  		return err
   137  	}
   138  	return nil
   139  }
   140  
   141  func (a *aperyClient) getResult(stdout io.Reader) (result *EvaluationResult, err error) {
   142  	result = new(EvaluationResult)
   143  
   144  	res, err := a.waitResponse(stdout, 10, 100*time.Millisecond)
   145  	logs := strings.TrimRight(string(res), "\n")
   146  	lines := strings.Split(logs, "\n")
   147  
   148  	if err != nil || !strings.Contains(logs, "bestmove") || len(lines) <= 2 {
   149  		return nil, fmt.Errorf("bestmoveが得られません")
   150  	}
   151  
   152  	bestmoveline := lines[len(lines)-1]
   153  	result.Bestmove = strings.Split(bestmoveline, " ")[1]
   154  
   155  	lastInfoLine := lines[len(lines)-2]
   156  	lastInfoLineParts := strings.Split(lastInfoLine, " ")
   157  	result.Pv = make([]string, 0, 10)
   158  
   159  	for i, part := range lastInfoLineParts {
   160  		if part == "cp" {
   161  			result.Value, err = strconv.Atoi(lastInfoLineParts[i+1])
   162  			if err != nil {
   163  				return nil, err
   164  			}
   165  		}
   166  
   167  		if part == "depth" {
   168  			result.Depth, err = strconv.Atoi(lastInfoLineParts[i+1])
   169  			if err != nil {
   170  				return nil, err
   171  			}
   172  		}
   173  
   174  		if part == "nodes" {
   175  			result.Nodes, err = strconv.Atoi(lastInfoLineParts[i+1])
   176  			if err != nil {
   177  				return nil, err
   178  			}
   179  		}
   180  
   181  		if part == "pv" {
   182  			for _, p := range lastInfoLineParts[i+1:] {
   183  				result.Pv = append(result.Pv, p)
   184  			}
   185  		}
   186  	}
   187  
   188  	return result, nil
   189  }