gobot.io/x/gobot/v2@v2.1.0/examples/tello_facetracker.go (about)

     1  //go:build example
     2  // +build example
     3  
     4  //
     5  // Do not build by default.
     6  
     7  /*
     8  You must have ffmpeg and OpenCV installed in order to run this code. It will connect to the Tello
     9  and then open a window using OpenCV showing the streaming video.
    10  
    11  How to run
    12  
    13  	go run examples/tello_facetracker.go ~/Downloads/res10_300x300_ssd_iter_140000.caffemodel ~/Development/opencv/samples/dnn/face_detector/deploy.prototxt
    14  
    15  You can download the model weights from https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel
    16  and find the deploy.prototxt config file in the OpenCV samples directory (samples/dnn/face_detector).
    17  */
    18  
    19  package main
    20  
    21  import (
    22  	"fmt"
    23  	"image"
    24  	"image/color"
    25  	"io"
    26  	"math"
    27  	"os"
    28  	"os/exec"
    29  	"strconv"
    30  	"sync/atomic"
    31  	"time"
    32  
    33  	"gobot.io/x/gobot/v2"
    34  	"gobot.io/x/gobot/v2/platforms/dji/tello"
    35  	"gobot.io/x/gobot/v2/platforms/joystick"
    36  	"gocv.io/x/gocv"
    37  )
    38  
    39  type pair struct {
    40  	x float64
    41  	y float64
    42  }
    43  
    44  const (
    45  	frameX    = 400
    46  	frameY    = 300
    47  	frameSize = frameX * frameY * 3
    48  	offset    = 32767.0
    49  )
    50  
    51  var (
    52  	// ffmpeg command to decode video stream from drone
    53  	ffmpeg = exec.Command("ffmpeg", "-hwaccel", "auto", "-hwaccel_device", "opencl", "-i", "pipe:0",
    54  		"-nostats", "-flags", "low_delay", "-probesize", "32", "-fflags", "nobuffer+fastseek+flush_packets", "-analyzeduration", "0", "-af", "aresample=async=1:min_comp=0.1:first_pts=0",
    55  		"-pix_fmt", "bgr24", "-s", strconv.Itoa(frameX)+"x"+strconv.Itoa(frameY), "-f", "rawvideo", "pipe:1")
    56  	ffmpegIn, _  = ffmpeg.StdinPipe()
    57  	ffmpegOut, _ = ffmpeg.StdoutPipe()
    58  
    59  	// gocv
    60  	window = gocv.NewWindow("Tello")
    61  	net    *gocv.Net
    62  	green  = color.RGBA{0, 255, 0, 0}
    63  
    64  	// tracking
    65  	tracking                 = false
    66  	detected                 = false
    67  	detectSize               = false
    68  	distTolerance            = 0.05 * dist(0, 0, frameX, frameY)
    69  	refDistance              float64
    70  	left, top, right, bottom float64
    71  
    72  	// drone
    73  	drone      = tello.NewDriver("8890")
    74  	flightData *tello.FlightData
    75  
    76  	// joystick
    77  	joyAdaptor                   = joystick.NewAdaptor()
    78  	stick                        = joystick.NewDriver(joyAdaptor, "dualshock4")
    79  	leftX, leftY, rightX, rightY atomic.Value
    80  )
    81  
    82  func init() {
    83  	leftX.Store(float64(0.0))
    84  	leftY.Store(float64(0.0))
    85  	rightX.Store(float64(0.0))
    86  	rightY.Store(float64(0.0))
    87  
    88  	// process drone events in separate goroutine for concurrency
    89  	go func() {
    90  		handleJoystick()
    91  
    92  		if err := ffmpeg.Start(); err != nil {
    93  			fmt.Println(err)
    94  			return
    95  		}
    96  
    97  		drone.On(tello.FlightDataEvent, func(data interface{}) {
    98  			// TODO: protect flight data from race condition
    99  			flightData = data.(*tello.FlightData)
   100  		})
   101  
   102  		drone.On(tello.ConnectedEvent, func(data interface{}) {
   103  			fmt.Println("Connected")
   104  			drone.StartVideo()
   105  			drone.SetVideoEncoderRate(tello.VideoBitRateAuto)
   106  			drone.SetExposure(0)
   107  			gobot.Every(100*time.Millisecond, func() {
   108  				drone.StartVideo()
   109  			})
   110  		})
   111  
   112  		drone.On(tello.VideoFrameEvent, func(data interface{}) {
   113  			pkt := data.([]byte)
   114  			if _, err := ffmpegIn.Write(pkt); err != nil {
   115  				fmt.Println(err)
   116  			}
   117  		})
   118  
   119  		robot := gobot.NewRobot("tello",
   120  			[]gobot.Connection{joyAdaptor},
   121  			[]gobot.Device{drone, stick},
   122  		)
   123  
   124  		robot.Start()
   125  	}()
   126  }
   127  
   128  func main() {
   129  	if len(os.Args) < 5 {
   130  		fmt.Println("How to run:\ngo run facetracker.go [model] [config] ([backend] [device])")
   131  		return
   132  	}
   133  
   134  	model := os.Args[1]
   135  	config := os.Args[2]
   136  	backend := gocv.NetBackendDefault
   137  	if len(os.Args) > 3 {
   138  		backend = gocv.ParseNetBackend(os.Args[3])
   139  	}
   140  
   141  	target := gocv.NetTargetCPU
   142  	if len(os.Args) > 4 {
   143  		target = gocv.ParseNetTarget(os.Args[4])
   144  	}
   145  
   146  	n := gocv.ReadNet(model, config)
   147  	if n.Empty() {
   148  		fmt.Printf("Error reading network model from : %v %v\n", model, config)
   149  		return
   150  	}
   151  	net = &n
   152  	defer net.Close()
   153  	net.SetPreferableBackend(gocv.NetBackendType(backend))
   154  	net.SetPreferableTarget(gocv.NetTargetType(target))
   155  
   156  	for {
   157  		// get next frame from stream
   158  		buf := make([]byte, frameSize)
   159  		if _, err := io.ReadFull(ffmpegOut, buf); err != nil {
   160  			fmt.Println(err)
   161  			continue
   162  		}
   163  		img, _ := gocv.NewMatFromBytes(frameY, frameX, gocv.MatTypeCV8UC3, buf)
   164  		if img.Empty() {
   165  			continue
   166  		}
   167  
   168  		trackFace(&img)
   169  
   170  		window.IMShow(img)
   171  		if window.WaitKey(10) >= 0 {
   172  			break
   173  		}
   174  	}
   175  }
   176  
   177  func trackFace(frame *gocv.Mat) {
   178  	W := float64(frame.Cols())
   179  	H := float64(frame.Rows())
   180  
   181  	blob := gocv.BlobFromImage(*frame, 1.0, image.Pt(300, 300), gocv.NewScalar(104, 177, 123, 0), false, false)
   182  	defer blob.Close()
   183  
   184  	net.SetInput(blob, "data")
   185  
   186  	detBlob := net.Forward("detection_out")
   187  	defer detBlob.Close()
   188  
   189  	detections := gocv.GetBlobChannel(detBlob, 0, 0)
   190  	defer detections.Close()
   191  
   192  	for r := 0; r < detections.Rows(); r++ {
   193  		confidence := detections.GetFloatAt(r, 2)
   194  		if confidence < 0.5 {
   195  			continue
   196  		}
   197  
   198  		left = float64(detections.GetFloatAt(r, 3)) * W
   199  		top = float64(detections.GetFloatAt(r, 4)) * H
   200  		right = float64(detections.GetFloatAt(r, 5)) * W
   201  		bottom = float64(detections.GetFloatAt(r, 6)) * H
   202  
   203  		left = math.Min(math.Max(0.0, left), W-1.0)
   204  		right = math.Min(math.Max(0.0, right), W-1.0)
   205  		bottom = math.Min(math.Max(0.0, bottom), H-1.0)
   206  		top = math.Min(math.Max(0.0, top), H-1.0)
   207  
   208  		detected = true
   209  		rect := image.Rect(int(left), int(top), int(right), int(bottom))
   210  		gocv.Rectangle(frame, rect, green, 3)
   211  	}
   212  
   213  	if !tracking || !detected {
   214  		return
   215  	}
   216  
   217  	if detectSize {
   218  		detectSize = false
   219  		refDistance = dist(left, top, right, bottom)
   220  	}
   221  
   222  	distance := dist(left, top, right, bottom)
   223  
   224  	// x axis
   225  	switch {
   226  	case right < W/2:
   227  		drone.CounterClockwise(50)
   228  	case left > W/2:
   229  		drone.Clockwise(50)
   230  	default:
   231  		drone.Clockwise(0)
   232  	}
   233  
   234  	// y axis
   235  	switch {
   236  	case top < H/10:
   237  		drone.Up(25)
   238  	case bottom > H-H/10:
   239  		drone.Down(25)
   240  	default:
   241  		drone.Up(0)
   242  	}
   243  
   244  	// z axis
   245  	switch {
   246  	case distance < refDistance-distTolerance:
   247  		drone.Forward(20)
   248  	case distance > refDistance+distTolerance:
   249  		drone.Backward(20)
   250  	default:
   251  		drone.Forward(0)
   252  	}
   253  }
   254  
   255  func dist(x1, y1, x2, y2 float64) float64 {
   256  	return math.Sqrt((x2-x1)*(x2-x1) + (y2-y1)*(y2-y1))
   257  }
   258  
   259  func handleJoystick() {
   260  	stick.On(joystick.CirclePress, func(data interface{}) {
   261  		drone.Forward(0)
   262  		drone.Up(0)
   263  		drone.Clockwise(0)
   264  		tracking = !tracking
   265  		if tracking {
   266  			detectSize = true
   267  			println("tracking")
   268  		} else {
   269  			detectSize = false
   270  			println("not tracking")
   271  		}
   272  	})
   273  	stick.On(joystick.SquarePress, func(data interface{}) {
   274  		fmt.Println("battery:", flightData.BatteryPercentage)
   275  	})
   276  	stick.On(joystick.TrianglePress, func(data interface{}) {
   277  		drone.TakeOff()
   278  		println("Takeoff")
   279  	})
   280  	stick.On(joystick.XPress, func(data interface{}) {
   281  		drone.Land()
   282  		println("Land")
   283  	})
   284  	stick.On(joystick.LeftX, func(data interface{}) {
   285  		val := float64(data.(int16))
   286  		leftX.Store(val)
   287  	})
   288  
   289  	stick.On(joystick.LeftY, func(data interface{}) {
   290  		val := float64(data.(int16))
   291  		leftY.Store(val)
   292  	})
   293  
   294  	stick.On(joystick.RightX, func(data interface{}) {
   295  		val := float64(data.(int16))
   296  		rightX.Store(val)
   297  	})
   298  
   299  	stick.On(joystick.RightY, func(data interface{}) {
   300  		val := float64(data.(int16))
   301  		rightY.Store(val)
   302  	})
   303  	gobot.Every(50*time.Millisecond, func() {
   304  		rightStick := getRightStick()
   305  
   306  		switch {
   307  		case rightStick.y < -10:
   308  			drone.Forward(tello.ValidatePitch(rightStick.y, offset))
   309  		case rightStick.y > 10:
   310  			drone.Backward(tello.ValidatePitch(rightStick.y, offset))
   311  		default:
   312  			drone.Forward(0)
   313  		}
   314  
   315  		switch {
   316  		case rightStick.x > 10:
   317  			drone.Right(tello.ValidatePitch(rightStick.x, offset))
   318  		case rightStick.x < -10:
   319  			drone.Left(tello.ValidatePitch(rightStick.x, offset))
   320  		default:
   321  			drone.Right(0)
   322  		}
   323  	})
   324  
   325  	gobot.Every(50*time.Millisecond, func() {
   326  		leftStick := getLeftStick()
   327  		switch {
   328  		case leftStick.y < -10:
   329  			drone.Up(tello.ValidatePitch(leftStick.y, offset))
   330  		case leftStick.y > 10:
   331  			drone.Down(tello.ValidatePitch(leftStick.y, offset))
   332  		default:
   333  			drone.Up(0)
   334  		}
   335  
   336  		switch {
   337  		case leftStick.x > 20:
   338  			drone.Clockwise(tello.ValidatePitch(leftStick.x, offset))
   339  		case leftStick.x < -20:
   340  			drone.CounterClockwise(tello.ValidatePitch(leftStick.x, offset))
   341  		default:
   342  			drone.Clockwise(0)
   343  		}
   344  	})
   345  }
   346  
   347  func getLeftStick() pair {
   348  	s := pair{x: 0, y: 0}
   349  	s.x = leftX.Load().(float64)
   350  	s.y = leftY.Load().(float64)
   351  	return s
   352  }
   353  
   354  func getRightStick() pair {
   355  	s := pair{x: 0, y: 0}
   356  	s.x = rightX.Load().(float64)
   357  	s.y = rightY.Load().(float64)
   358  	return s
   359  }