gobot.io/x/gobot@v1.16.0/examples/tello_facetracker.go (about)

     1  // +build example
     2  //
     3  // Do not build by default.
     4  
     5  /*
     6  You must have ffmpeg and OpenCV installed in order to run this code. It will connect to the Tello
     7  and then open a window using OpenCV showing the streaming video.
     8  
     9  How to run
    10  
    11  	go run examples/tello_facetracker.go ~/Downloads/res10_300x300_ssd_iter_140000.caffemodel ~/Development/opencv/samples/dnn/face_detector/deploy.prototxt
    12  
    13  You can download the model weights from https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel
    14  and the deploy.prototxt config file is in the OpenCV samples directory (samples/dnn/face_detector).
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"fmt"
    21  	"image"
    22  	"image/color"
    23  	"io"
    24  	"math"
    25  	"os"
    26  	"os/exec"
    27  	"strconv"
    28  	"sync/atomic"
    29  	"time"
    30  
    31  	"gobot.io/x/gobot"
    32  	"gobot.io/x/gobot/platforms/dji/tello"
    33  	"gobot.io/x/gobot/platforms/joystick"
    34  	"gocv.io/x/gocv"
    35  )
    36  
    37  type pair struct {
    38  	x float64
    39  	y float64
    40  }
    41  
    42  const (
    43  	frameX    = 400
    44  	frameY    = 300
    45  	frameSize = frameX * frameY * 3
    46  	offset    = 32767.0
    47  )
    48  
var (
	// ffmpeg command to decode video stream from drone.
	// It reads the raw stream from stdin (pipe:0) and writes frameX x frameY
	// bgr24 raw frames to stdout (pipe:1); main reads them frameSize bytes at a time.
	ffmpeg = exec.Command("ffmpeg", "-hwaccel", "auto", "-hwaccel_device", "opencl", "-i", "pipe:0",
		"-nostats", "-flags", "low_delay", "-probesize", "32", "-fflags", "nobuffer+fastseek+flush_packets", "-analyzeduration", "0", "-af", "aresample=async=1:min_comp=0.1:first_pts=0",
		"-pix_fmt", "bgr24", "-s", strconv.Itoa(frameX)+"x"+strconv.Itoa(frameY), "-f", "rawvideo", "pipe:1")
	// NOTE(review): the errors from StdinPipe/StdoutPipe are discarded; if
	// either call failed, the corresponding pipe would be unusable.
	ffmpegIn, _  = ffmpeg.StdinPipe()
	ffmpegOut, _ = ffmpeg.StdoutPipe()

	// gocv
	// window shows each processed frame; it is created at package
	// initialisation, before main validates its arguments.
	window = gocv.NewWindow("Tello")
	// net is the face-detection network loaded in main.
	net    *gocv.Net
	green  = color.RGBA{0, 255, 0, 0} // colour of the face bounding boxes

	// tracking
	// tracking is toggled by the Circle button; detected is set by trackFace
	// when a detection passes its confidence threshold; detectSize asks
	// trackFace to record the next tracked face's box diagonal in refDistance.
	tracking                 = false
	detected                 = false
	detectSize               = false
	distTolerance            = 0.05 * dist(0, 0, frameX, frameY) // 5% of the frame diagonal
	refDistance              float64
	left, top, right, bottom float64 // last accepted face box, in pixels

	// drone
	drone      = tello.NewDriver("8890")
	// flightData is the latest telemetry from FlightDataEvent; nil until the
	// first event arrives.
	flightData *tello.FlightData

	// joystick
	joyAdaptor                   = joystick.NewAdaptor()
	stick                        = joystick.NewDriver(joyAdaptor, "dualshock4")
	// Latest stick axis values (float64), written by the joystick handlers and
	// read by getLeftStick/getRightStick.
	leftX, leftY, rightX, rightY atomic.Value
)
    79  
// init seeds the joystick axis values and then, in a background goroutine,
// registers the joystick handlers, starts ffmpeg, registers the drone event
// handlers and finally starts the robot.
func init() {
	// Seed every axis with a float64 so the .(float64) assertions in
	// getLeftStick/getRightStick succeed before the first stick event arrives.
	leftX.Store(float64(0.0))
	leftY.Store(float64(0.0))
	rightX.Store(float64(0.0))
	rightY.Store(float64(0.0))

	// process drone events in separate goroutine for concurrency
	go func() {
		handleJoystick()

		// If the decoder cannot start, the drone handlers are never registered
		// and the robot is never started.
		if err := ffmpeg.Start(); err != nil {
			fmt.Println(err)
			return
		}

		drone.On(tello.FlightDataEvent, func(data interface{}) {
			// TODO: protect flight data from race condition
			// (flightData is also read by the Square button handler in handleJoystick).
			flightData = data.(*tello.FlightData)
		})

		drone.On(tello.ConnectedEvent, func(data interface{}) {
			fmt.Println("Connected")
			drone.StartVideo()
			drone.SetVideoEncoderRate(tello.VideoBitRateAuto)
			drone.SetExposure(0)
			// Re-issue StartVideo every 100ms.
			// NOTE(review): presumably keeps the video stream flowing; confirm
			// against the tello driver documentation.
			gobot.Every(100*time.Millisecond, func() {
				drone.StartVideo()
			})
		})

		// Forward each raw video packet from the drone to ffmpeg's stdin.
		drone.On(tello.VideoFrameEvent, func(data interface{}) {
			pkt := data.([]byte)
			if _, err := ffmpegIn.Write(pkt); err != nil {
				fmt.Println(err)
			}
		})

		// The robot owns the joystick adaptor plus the drone and stick drivers.
		robot := gobot.NewRobot("tello",
			[]gobot.Connection{joyAdaptor},
			[]gobot.Device{drone, stick},
		)

		robot.Start()
	}()
}
   125  
   126  func main() {
   127  	if len(os.Args) < 5 {
   128  		fmt.Println("How to run:\ngo run facetracker.go [model] [config] ([backend] [device])")
   129  		return
   130  	}
   131  
   132  	model := os.Args[1]
   133  	config := os.Args[2]
   134  	backend := gocv.NetBackendDefault
   135  	if len(os.Args) > 3 {
   136  		backend = gocv.ParseNetBackend(os.Args[3])
   137  	}
   138  
   139  	target := gocv.NetTargetCPU
   140  	if len(os.Args) > 4 {
   141  		target = gocv.ParseNetTarget(os.Args[4])
   142  	}
   143  
   144  	n := gocv.ReadNet(model, config)
   145  	if n.Empty() {
   146  		fmt.Printf("Error reading network model from : %v %v\n", model, config)
   147  		return
   148  	}
   149  	net = &n
   150  	defer net.Close()
   151  	net.SetPreferableBackend(gocv.NetBackendType(backend))
   152  	net.SetPreferableTarget(gocv.NetTargetType(target))
   153  
   154  	for {
   155  		// get next frame from stream
   156  		buf := make([]byte, frameSize)
   157  		if _, err := io.ReadFull(ffmpegOut, buf); err != nil {
   158  			fmt.Println(err)
   159  			continue
   160  		}
   161  		img, _ := gocv.NewMatFromBytes(frameY, frameX, gocv.MatTypeCV8UC3, buf)
   162  		if img.Empty() {
   163  			continue
   164  		}
   165  
   166  		trackFace(&img)
   167  
   168  		window.IMShow(img)
   169  		if window.WaitKey(10) >= 0 {
   170  			break
   171  		}
   172  	}
   173  }
   174  
   175  func trackFace(frame *gocv.Mat) {
   176  	W := float64(frame.Cols())
   177  	H := float64(frame.Rows())
   178  
   179  	blob := gocv.BlobFromImage(*frame, 1.0, image.Pt(300, 300), gocv.NewScalar(104, 177, 123, 0), false, false)
   180  	defer blob.Close()
   181  
   182  	net.SetInput(blob, "data")
   183  
   184  	detBlob := net.Forward("detection_out")
   185  	defer detBlob.Close()
   186  
   187  	detections := gocv.GetBlobChannel(detBlob, 0, 0)
   188  	defer detections.Close()
   189  
   190  	for r := 0; r < detections.Rows(); r++ {
   191  		confidence := detections.GetFloatAt(r, 2)
   192  		if confidence < 0.5 {
   193  			continue
   194  		}
   195  
   196  		left = float64(detections.GetFloatAt(r, 3)) * W
   197  		top = float64(detections.GetFloatAt(r, 4)) * H
   198  		right = float64(detections.GetFloatAt(r, 5)) * W
   199  		bottom = float64(detections.GetFloatAt(r, 6)) * H
   200  
   201  		left = math.Min(math.Max(0.0, left), W-1.0)
   202  		right = math.Min(math.Max(0.0, right), W-1.0)
   203  		bottom = math.Min(math.Max(0.0, bottom), H-1.0)
   204  		top = math.Min(math.Max(0.0, top), H-1.0)
   205  
   206  		detected = true
   207  		rect := image.Rect(int(left), int(top), int(right), int(bottom))
   208  		gocv.Rectangle(frame, rect, green, 3)
   209  	}
   210  
   211  	if !tracking || !detected {
   212  		return
   213  	}
   214  
   215  	if detectSize {
   216  		detectSize = false
   217  		refDistance = dist(left, top, right, bottom)
   218  	}
   219  
   220  	distance := dist(left, top, right, bottom)
   221  
   222  	// x axis
   223  	switch {
   224  	case right < W/2:
   225  		drone.CounterClockwise(50)
   226  	case left > W/2:
   227  		drone.Clockwise(50)
   228  	default:
   229  		drone.Clockwise(0)
   230  	}
   231  
   232  	// y axis
   233  	switch {
   234  	case top < H/10:
   235  		drone.Up(25)
   236  	case bottom > H-H/10:
   237  		drone.Down(25)
   238  	default:
   239  		drone.Up(0)
   240  	}
   241  
   242  	// z axis
   243  	switch {
   244  	case distance < refDistance-distTolerance:
   245  		drone.Forward(20)
   246  	case distance > refDistance+distTolerance:
   247  		drone.Backward(20)
   248  	default:
   249  		drone.Forward(0)
   250  	}
   251  }
   252  
   253  func dist(x1, y1, x2, y2 float64) float64 {
   254  	return math.Sqrt((x2-x1)*(x2-x1) + (y2-y1)*(y2-y1))
   255  }
   256  
   257  func handleJoystick() {
   258  	stick.On(joystick.CirclePress, func(data interface{}) {
   259  		drone.Forward(0)
   260  		drone.Up(0)
   261  		drone.Clockwise(0)
   262  		tracking = !tracking
   263  		if tracking {
   264  			detectSize = true
   265  			println("tracking")
   266  		} else {
   267  			detectSize = false
   268  			println("not tracking")
   269  		}
   270  	})
   271  	stick.On(joystick.SquarePress, func(data interface{}) {
   272  		fmt.Println("battery:", flightData.BatteryPercentage)
   273  	})
   274  	stick.On(joystick.TrianglePress, func(data interface{}) {
   275  		drone.TakeOff()
   276  		println("Takeoff")
   277  	})
   278  	stick.On(joystick.XPress, func(data interface{}) {
   279  		drone.Land()
   280  		println("Land")
   281  	})
   282  	stick.On(joystick.LeftX, func(data interface{}) {
   283  		val := float64(data.(int16))
   284  		leftX.Store(val)
   285  	})
   286  
   287  	stick.On(joystick.LeftY, func(data interface{}) {
   288  		val := float64(data.(int16))
   289  		leftY.Store(val)
   290  	})
   291  
   292  	stick.On(joystick.RightX, func(data interface{}) {
   293  		val := float64(data.(int16))
   294  		rightX.Store(val)
   295  	})
   296  
   297  	stick.On(joystick.RightY, func(data interface{}) {
   298  		val := float64(data.(int16))
   299  		rightY.Store(val)
   300  	})
   301  	gobot.Every(50*time.Millisecond, func() {
   302  		rightStick := getRightStick()
   303  
   304  		switch {
   305  		case rightStick.y < -10:
   306  			drone.Forward(tello.ValidatePitch(rightStick.y, offset))
   307  		case rightStick.y > 10:
   308  			drone.Backward(tello.ValidatePitch(rightStick.y, offset))
   309  		default:
   310  			drone.Forward(0)
   311  		}
   312  
   313  		switch {
   314  		case rightStick.x > 10:
   315  			drone.Right(tello.ValidatePitch(rightStick.x, offset))
   316  		case rightStick.x < -10:
   317  			drone.Left(tello.ValidatePitch(rightStick.x, offset))
   318  		default:
   319  			drone.Right(0)
   320  		}
   321  	})
   322  
   323  	gobot.Every(50*time.Millisecond, func() {
   324  		leftStick := getLeftStick()
   325  		switch {
   326  		case leftStick.y < -10:
   327  			drone.Up(tello.ValidatePitch(leftStick.y, offset))
   328  		case leftStick.y > 10:
   329  			drone.Down(tello.ValidatePitch(leftStick.y, offset))
   330  		default:
   331  			drone.Up(0)
   332  		}
   333  
   334  		switch {
   335  		case leftStick.x > 20:
   336  			drone.Clockwise(tello.ValidatePitch(leftStick.x, offset))
   337  		case leftStick.x < -20:
   338  			drone.CounterClockwise(tello.ValidatePitch(leftStick.x, offset))
   339  		default:
   340  			drone.Clockwise(0)
   341  		}
   342  	})
   343  }
   344  
   345  func getLeftStick() pair {
   346  	s := pair{x: 0, y: 0}
   347  	s.x = leftX.Load().(float64)
   348  	s.y = leftY.Load().(float64)
   349  	return s
   350  }
   351  
   352  func getRightStick() pair {
   353  	s := pair{x: 0, y: 0}
   354  	s.x = rightX.Load().(float64)
   355  	s.y = rightY.Load().(float64)
   356  	return s
   357  }