Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import contextlib
import math
import numbers

import numba
import numpy as np
import psycopg2

import bb_utils.ids
def get_neighbour_frames(frame_id, n_frames=None, seconds=None):
    """Return the frames recorded by the same camera around ``frame_id``.

    The time window is ``[timestamp - seconds, timestamp + seconds]`` around
    the ``plotter_frame.timestamp`` of ``frame_id``. If ``seconds`` is not
    given it is derived from ``n_frames`` as ``n_frames / 3`` (presumably a
    3 fps recording rate -- TODO confirm), falling back to 5.0 seconds.

    Only frames whose frame container maps to the same camera as the
    container of ``frame_id`` are kept. The camera number is the integer
    value of the 5th character of ``plotter_framecontainer.video_name``.

    Returns:
        A list of ``(timestamp, frame_id, cam)`` tuples sorted ascending,
        i.e. primarily by timestamp.

    Raises:
        ValueError: if ``frame_id`` (or its frame container) is not found.
    """
    if seconds is None:
        seconds = n_frames / 3 if n_frames is not None else 5.0

    conn = psycopg2.connect(
        "dbname='beesbook' user='reader' host='localhost' password='reader'",
        application_name="get_neighbour_frames")
    # psycopg2's own connection context manager only ends the transaction; it
    # does not close the connection. closing() guarantees it is released.
    with contextlib.closing(conn), conn.cursor() as cursor:
        cursor.execute(
            "SELECT fc_id, timestamp FROM plotter_frame "
            "WHERE frame_id = %s LIMIT 1",
            (frame_id,))
        row = cursor.fetchone()
        if row is None:
            raise ValueError(f"Unknown frame_id: {frame_id}")
        fc_id, timestamp = row

        cursor.execute(
            "SELECT timestamp, frame_id, fc_id FROM plotter_frame "
            "WHERE timestamp >= %s AND timestamp <= %s",
            (timestamp - seconds, timestamp + seconds))
        window = cursor.fetchall()

        # Resolve the camera of every involved container in one round-trip
        # instead of one query per container.
        container_ids = {fc_id} | {row_fc for (_, _, row_fc) in window}
        cursor.execute(
            "SELECT id, CAST(SUBSTR(video_name, 5, 1) AS INT) "
            "FROM plotter_framecontainer WHERE id IN %s",
            (tuple(container_ids),))
        cam_of = dict(cursor.fetchall())

    if fc_id not in cam_of:
        raise ValueError(
            f"Frame container {fc_id} of frame {frame_id} not found")
    target_cam = cam_of[fc_id]
    matching = {c for c, cam in cam_of.items() if cam == target_cam}
    return sorted(
        (ts, fid, target_cam) for (ts, fid, c) in window if c in matching)
def _pick_candidate(candidates, previous, verbose):
    """Choose one detection among several detections of the bee in one frame.

    The reference is the most recent non-``None`` entry of ``previous``.
    Prefer the candidate with the reference's ``track_id`` (last tuple
    element). Otherwise pick the candidate closest in ``(x_pos, y_pos)`` to
    the reference, or to the origin when there is no reference.
    """
    if verbose:
        print(f"Warning: more than one candidate! ({len(candidates)})")
    reference = next((r for r in reversed(previous) if r is not None), None)

    if reference is not None:
        for cand in candidates:
            if cand[-1] == reference[-1]:  # same track_id
                if verbose:
                    print("\t..resolved via track ID.")
                return cand

    # Force a float dtype: integer positions would make the in-place
    # subtraction below fail when the reference has float coordinates.
    positions = np.array([(c[2], c[3]) for c in candidates], dtype=np.float64)
    if reference is not None:
        positions -= np.array([reference[2], reference[3]], dtype=np.float64)
    best = int(np.argmin(np.linalg.norm(positions, axis=1)))
    if verbose:
        print("\t..resolved via distance.")
    return candidates[best]


def get_neighbour_detections(frame_id, bee_id, verbose=False, **kwargs):
    """Return the detections of ``bee_id`` in the frames around ``frame_id``.

    The neighbour frames come from ``get_neighbour_frames(frame_id,
    **kwargs)``. The result has one entry per neighbour frame, in that order.
    Each entry is either ``None`` (the bee has no detection in that frame) or
    a ``(timestamp, frame_id, x_pos, y_pos, orientation, track_id)`` row of
    ``bb_detections``.

    When a frame holds several detections of the bee, the one sharing the
    ``track_id`` of the most recent earlier detection in the result is used.
    Otherwise the one closest in position to that detection is used (or the
    one closest to the origin if no earlier detection exists).

    Args:
        frame_id: frame at the centre of the window.
        bee_id: an integer bee id, or an object providing ``as_ferwar()``
            (presumably a ``bb_utils.ids`` id object -- TODO confirm).
        verbose: print a note whenever an ambiguity is resolved.
        **kwargs: forwarded to ``get_neighbour_frames`` (``n_frames``,
            ``seconds``).
    """
    if isinstance(bee_id, numbers.Integral):
        # int() also normalises numpy integers, which psycopg2 cannot adapt.
        bee_id = int(bee_id)
    else:
        bee_id = bee_id.as_ferwar()

    neighbour_frames = get_neighbour_frames(frame_id, **kwargs)
    frame_ids = tuple(f[1] for f in neighbour_frames)
    if not frame_ids:
        return []  # "IN ()" would be invalid SQL

    conn = psycopg2.connect(
        "dbname='beesbook' user='reader' host='localhost' password='reader'",
        application_name="get_neighbour_detections")
    # closing(): psycopg2's connection context manager does not close it.
    with contextlib.closing(conn), conn.cursor() as cursor:
        cursor.execute(
            "SELECT timestamp, frame_id, x_pos, y_pos, orientation, track_id "
            "FROM bb_detections WHERE frame_id IN %s AND bee_id = %s "
            "ORDER BY timestamp ASC",
            (frame_ids, bee_id))
        detections = cursor.fetchall()

    # Group by frame so matching does not depend on the two result sets
    # being ordered identically (and avoids repeated list.pop(0)).
    by_frame = {}
    for det in detections:
        by_frame.setdefault(det[1], []).append(det)

    results = []
    for (_, fid, _) in neighbour_frames:
        candidates = by_frame.get(fid)
        if not candidates:
            results.append(None)
        elif len(candidates) == 1:
            results.append(candidates[0])
        else:
            results.append(_pick_candidate(candidates, results, verbose))
    return results
@numba.njit
def short_angle_dist(a0, a1):
    """Return the signed shortest angular step from ``a0`` to ``a1``.

    Angles are in radians. The result lies in ``[-pi, pi)`` (up to
    floating-point rounding), so ``a0 + result`` points in the same direction
    as ``a1``.
    """
    full_turn = 2 * math.pi
    delta = (a1 - a0) % full_turn
    return (2 * delta) % full_turn - delta
@numba.njit
def angle_lerp(a0, a1, t):
    """Linearly interpolate from angle ``a0`` towards ``a1`` along the shorter arc.

    ``t == 0`` gives ``a0`` and ``t == 1`` gives an angle equivalent to
    ``a1``. The result is not wrapped into any fixed range.
    """
    step = short_angle_dist(a0, a1)
    return a0 + step * t
def get_trajectory_around(frame_id, bee_id, **kwargs):
    """Return the ``(x, y, alpha)`` trajectory of ``bee_id`` around ``frame_id``.

    There is one row per neighbour frame returned by
    ``get_neighbour_detections`` (same order). Columns are the detection's
    ``x_pos``, ``y_pos`` and ``orientation``. Frames without a detection are
    rows of NaN. The result is always a 2-D ``(n, 3)`` float32 array, even
    when ``n == 0``.

    ``**kwargs`` are forwarded to ``get_neighbour_detections``.
    """
    detections = get_neighbour_detections(frame_id, bee_id, **kwargs)
    # Detection rows are (timestamp, frame_id, x, y, alpha, track_id).
    rows = [
        (np.nan, np.nan, np.nan) if det is None else (det[2], det[3], det[4])
        for det in detections
    ]
    # reshape keeps the (n, 3) shape for an empty result. Without it,
    # np.array([]) is 1-D and breaks interpolate_trajectory's 2-D signature.
    return np.array(rows, dtype=np.float32).reshape(-1, 3)
@numba.njit(numba.float32[:](numba.float32[:, :]))
def interpolate_trajectory(trajectory):
    """Fill the NaN rows of ``trajectory`` in place by linear interpolation.

    ``trajectory`` is an ``(n, 3)`` float32 array of ``(x, y, alpha)`` rows.
    A NaN in column 0 marks a missing row. Each missing row is computed from
    two valid rows:

    - the valid rows bracketing it, when they exist;
    - otherwise the first two valid rows, which extrapolates to the left;
    - or the last two valid rows, which extrapolates to the right.

    ``x`` and ``y`` are interpolated linearly. ``alpha`` is interpolated
    along the shortest arc via ``short_angle_dist``.

    Returns:
        A float32 array of length ``n``. It holds 1.0 for rows that were
        valid before interpolation and 0.0 for rows that were filled. If
        fewer than two rows are valid, nothing is filled and an all-zero
        array is returned.
    """
    n_rows = trajectory.shape[0]
    valid = ~np.isnan(trajectory[:, 0])
    valid_idx = np.where(valid)[0]
    if len(valid_idx) < 2:
        return np.zeros(n_rows, dtype=np.float32)

    last_pair = len(valid_idx) - 2
    for row in np.where(~valid)[0]:
        # Position in valid_idx of the last valid row before `row`. It is
        # clamped so that `begin` and `begin + 1` are both usable; this gives
        # extrapolation at either end.
        begin = np.searchsorted(valid_idx, row) - 1
        begin = min(max(begin, 0), last_pair)
        lo = valid_idx[begin]
        hi = valid_idx[begin + 1]
        t = (row - lo) / (hi - lo)

        a = trajectory[lo]
        b = trajectory[hi]
        trajectory[row, 0] = a[0] + (b[0] - a[0]) * t
        trajectory[row, 1] = a[1] + (b[1] - a[1]) * t
        # Column 2 is the angle: step along the shortest arc between the two
        # angles (not between the x coordinates).
        trajectory[row, 2] = a[2] + short_angle_dist(a[2], b[2]) * t
    return valid.astype(np.float32)
def get_interpolated_trajectory(frame_id, bee_id, interpolate=True, **kwargs):
    """Return the trajectory of ``bee_id`` around ``frame_id``.

    The trajectory comes from ``get_trajectory_around`` (``**kwargs`` are
    forwarded to it). When ``interpolate`` is true, its NaN rows are filled
    in place by ``interpolate_trajectory``; that function's validity mask is
    discarded.
    """
    traj = get_trajectory_around(frame_id, bee_id, **kwargs)
    if not interpolate:
        return traj
    interpolate_trajectory(traj)
    return traj
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement