# kalman.py
  1. #!/usr/bin/python
  2. """
  3. Tracking of rotating point.
  4. Rotation speed is constant.
  5. Both state and measurements vectors are 1D (a point angle),
  6. Measurement is the real point angle + gaussian noise.
  7. The real and the estimated points are connected with yellow line segment,
  8. the real and the measured points are connected with red line segment.
  9. (if Kalman filter works correctly,
  10. the yellow segment should be shorter than the red one).
  11. Pressing any key (except ESC) will reset the tracking with a different speed.
  12. Pressing ESC will stop the program.
  13. """
  14. # Python 2/3 compatibility
  15. import sys
  16. PY3 = sys.version_info[0] == 3
  17. if PY3:
  18. long = int
  19. import cv2
  20. from math import cos, sin
  21. import numpy as np
  22. if __name__ == "__main__":
  23. img_height = 500
  24. img_width = 500
  25. kalman = cv2.KalmanFilter(2, 1, 0)
  26. code = long(-1)
  27. cv2.namedWindow("Kalman")
  28. while True:
  29. state = 0.1 * np.random.randn(2, 1)
  30. kalman.transitionMatrix = np.array([[1., 1.], [0., 1.]])
  31. kalman.measurementMatrix = 1. * np.ones((1, 2))
  32. kalman.processNoiseCov = 1e-5 * np.eye(2)
  33. kalman.measurementNoiseCov = 1e-1 * np.ones((1, 1))
  34. kalman.errorCovPost = 1. * np.ones((2, 2))
  35. kalman.statePost = 0.1 * np.random.randn(2, 1)
  36. while True:
  37. def calc_point(angle):
  38. return (np.around(img_width/2 + img_width/3*cos(angle), 0).astype(int),
  39. np.around(img_height/2 - img_width/3*sin(angle), 1).astype(int))
  40. state_angle = state[0, 0]
  41. state_pt = calc_point(state_angle)
  42. prediction = kalman.predict()
  43. predict_angle = prediction[0, 0]
  44. predict_pt = calc_point(predict_angle)
  45. measurement = kalman.measurementNoiseCov * np.random.randn(1, 1)
  46. # generate measurement
  47. measurement = np.dot(kalman.measurementMatrix, state) + measurement
  48. measurement_angle = measurement[0, 0]
  49. measurement_pt = calc_point(measurement_angle)
  50. # plot points
  51. def draw_cross(center, color, d):
  52. cv2.line(img,
  53. (center[0] - d, center[1] - d), (center[0] + d, center[1] + d),
  54. color, 1, cv2.LINE_AA, 0)
  55. cv2.line(img,
  56. (center[0] + d, center[1] - d), (center[0] - d, center[1] + d),
  57. color, 1, cv2.LINE_AA, 0)
  58. img = np.zeros((img_height, img_width, 3), np.uint8)
  59. draw_cross(np.int32(state_pt), (255, 255, 255), 3)
  60. draw_cross(np.int32(measurement_pt), (0, 0, 255), 3)
  61. draw_cross(np.int32(predict_pt), (0, 255, 0), 3)
  62. cv2.line(img, state_pt, measurement_pt, (0, 0, 255), 3, cv2.LINE_AA, 0)
  63. cv2.line(img, state_pt, predict_pt, (0, 255, 255), 3, cv2.LINE_AA, 0)
  64. kalman.correct(measurement)
  65. process_noise = kalman.processNoiseCov * np.random.randn(2, 1)
  66. state = np.dot(kalman.transitionMatrix, state) + process_noise
  67. cv2.imshow("Kalman", img)
  68. code = cv2.waitKey(100) % 0x100
  69. if code != -1:
  70. break
  71. if code in [27, ord('q'), ord('Q')]:
  72. break
  73. cv2.destroyWindow("Kalman")