Advertisement
Guest User

Tap Trainer

a guest
Mar 17th, 2023
46
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 27.94 KB | Source Code | 0 0
  1. import ctypes
  2. import datetime
  3. import json
  4. import pyaudio
  5. from pynput import keyboard
  6. import math
  7. import matplotlib.pyplot as pyplot
  8. from matplotlib.colors import LinearSegmentedColormap
  9. import numpy
  10. import queue
  11. import struct
  12. import time
  13. import uuid
  14.  
  15. # Enable support for ANSI escape codes to move the terminal cursor
  16. ctypes.windll.kernel32.SetConsoleMode(ctypes.windll.kernel32.GetStdHandle(-11), 7)
  17.  
  18. def get_focused_window_handle():
  19.   return ctypes.windll.user32.GetForegroundWindow()
  20.  
  21. log_filename = "tap_trainer_log.txt"
  22.  
  23. this_window_handle = get_focused_window_handle()
  24. keyboard_queue = queue.SimpleQueue()
  25. PRESS = 0
  26. RELEASE = 1
  27.  
  28. pressed_keys = set()
  29.  
  30. # Mimics fret order from low to high frets, used to detect if a "lower" fret is pressed before a "higher" fret
  31. key_rows = [
  32.   "poiuytrewq",
  33.   "lkjhgfdsa",
  34.   "mnbvcxz",
  35. ]
  36.  
  37. valid_characters = "".join(key_rows)
  38.  
  39. def get_next_keyboard_event(block = True, timeout = None):
  40.   try:
  41.     # Loop to avoid returning held keys
  42.     while True:
  43.       (event_type, key, seconds) = keyboard_queue.get(block, timeout)
  44.       if event_type == PRESS and key not in pressed_keys:
  45.         pressed_keys.add(key)
  46.         return (event_type, key, seconds)
  47.       elif event_type == RELEASE and key in pressed_keys:
  48.         pressed_keys.remove(key)
  49.         return (event_type, key, seconds)
  50.   except queue.Empty:
  51.     return None
  52.  
  53. def on_press(key):
  54.   if this_window_handle == get_focused_window_handle():
  55.     seconds = time.perf_counter()
  56.     keyboard_queue.put((PRESS, key, seconds))
  57.  
  58. def on_release(key):
  59.   if this_window_handle == get_focused_window_handle():
  60.     seconds = time.perf_counter()
  61.     keyboard_queue.put((RELEASE, key, seconds))
  62.  
  63. def drain_keyboard_queue():
  64.   try:
  65.     get_next_keyboard_event(False)
  66.   except queue.Empty:
  67.     return
  68.  
  69. def wait_for_keypress(expected_keys):
  70.   while True:
  71.     (event_type, key, seconds) = get_next_keyboard_event()
  72.     if event_type == PRESS and (expected_keys is None or key in expected_keys):
  73.       return key
  74.  
  75. def is_valid_sequence_character(key):
  76.   return isinstance(key, keyboard.KeyCode) and key.char is not None and key.char in valid_characters
  77.  
  78. def input_sequence():
  79.   sequence = []
  80.  
  81.   while True:
  82.     (event_type, key, seconds) = get_next_keyboard_event()
  83.     if event_type == PRESS:
  84.       if key is keyboard.Key.esc:
  85.         print("")
  86.         return None
  87.  
  88.       if key is keyboard.Key.space:
  89.         print("")
  90.         if len(sequence) > 0 and all(v == sequence[0] for v in sequence):
  91.           print("A sequence cannot consist of only a single key")
  92.           return None
  93.         return sequence
  94.  
  95.       if key is keyboard.Key.backspace or key is keyboard.Key.delete:
  96.         if len(sequence) > 0:
  97.           sequence.pop()
  98.           print("\b \b", end = "", flush = True)
  99.       elif is_valid_sequence_character(key):
  100.         sequence.append(key)
  101.         print(key.char, end = "", flush = True)
  102.  
  103. def detect_sequence(match_sequence):
  104.   # Try to determine what the sequence is - requires three perfect reps
  105.   required_repetitions_to_match = 3
  106.   min_sequence_length = 3
  107.   max_sequence_length = 8
  108.   for sequence_length in range(min_sequence_length, max_sequence_length + 1):
  109.     if len(match_sequence) >= sequence_length * required_repetitions_to_match:
  110.       match = True
  111.       for i in range(sequence_length, len(match_sequence)):
  112.         if match_sequence[i] != match_sequence[i % sequence_length]:
  113.           match = False
  114.           break
  115.  
  116.       if match:
  117.         return match_sequence[:sequence_length]
  118.  
  119.   if len(match_sequence) < max_sequence_length * required_repetitions_to_match:
  120.     return None # No match found yet, keep adding input
  121.   else:
  122.     return [] # Could not find a match
  123.  
  124. class MultiLineMessage:
  125.   def __init__(self):
  126.     self.message = ""
  127.  
  128.   def add(self, text):
  129.     self.message += text
  130.  
  131.   def add_line(self, line):
  132.     self.message += line + "\n"
  133.  
  134.   def clear_current_line(self):
  135.     self.message += "\033[K"
  136.  
  137.   def goto_previous_line(self, count = 1):
  138.     for i in range(count):
  139.       self.message += "\033[F"
  140.  
  141.   def print(self):
  142.     print(self.message, end = "", flush = True)
  143.     self.message = ""
  144.  
  145. HIGHLIGHT_BACKGROUND = 0
  146. HIGHLIGHT_TEXT = 1
  147.  
  148. class CaptureResult:
  149.   def __init__(self):
  150.     self.sequence = []
  151.     self.captured_length = 0
  152.     self.captured_duration = 0
  153.     self.notes_per_second = 0
  154.     self.timing_variance = 0
  155.     self.error_count = 0
  156.     self.captured_notes = []
  157.  
  158. class CaptureData:
  159.   def __init__(self):
  160.     # Array of (index in sequence, time)
  161.     self.notes = []
  162.  
  163.   def add_sample(self, sequence_index, note_time):
  164.     self.notes.append((sequence_index, note_time))
  165.  
  166.   def valid_note_count(self):
  167.     return len([v for v in self.notes if v[0] >= 0])
  168.  
  169.   def get_nps_and_timing_variance(self, recent_sample_count = None):
  170.     valid_note_times = [v[1] for v in self.notes if v[0] >= 0]
  171.     total_sample_count = len(valid_note_times)
  172.     if recent_sample_count is None:
  173.       recent_sample_count = total_sample_count
  174.     sample_count = min(recent_sample_count, total_sample_count)
  175.     start_index = total_sample_count - sample_count
  176.  
  177.     if sample_count < 2:
  178.         return None
  179.  
  180.     # Use running least squares to determine BPM, i.e. we will solve
  181.     # y = A + Bx + err
  182.     # where y is note time and x is note index (0, 1, 2, 3, ...)
  183.     # B = (N * (sum xy) - (sum x) * (sum y)) / (N * (sum x^2) - (sum x)^2)
  184.     # A = (sum y - B * sum x) / N
  185.     sum_x = 0.0
  186.     sum_x2 = 0.0
  187.     sum_y = 0.0
  188.     sum_xy = 0.0
  189.     for i in range(start_index, total_sample_count):
  190.       sum_x += i
  191.       sum_x2 += i * i
  192.       sum_y += valid_note_times[i]
  193.       sum_xy += i * valid_note_times[i]
  194.  
  195.     slope = (sample_count * sum_xy - sum_x * sum_y) / (sample_count * sum_x2 - sum_x * sum_x)
  196.     # intercept = (sum_y - intercept * sum_x) / sample_count # Not needed
  197.  
  198.     # Calculate average timing error
  199.     expected_seconds_per_note = slope
  200.     mean_abs_error_sum = 0
  201.     for i in range(start_index + 1, total_sample_count):
  202.       seconds_per_note = valid_note_times[i] - valid_note_times[i - 1]
  203.       mean_abs_error_sum += abs(seconds_per_note - expected_seconds_per_note)
  204.     mean_abs_error = mean_abs_error_sum / (sample_count - 1)
  205.  
  206.     notes_per_second = 1.0 / slope
  207.     timing_variance = mean_abs_error / expected_seconds_per_note
  208.  
  209.     return (notes_per_second, timing_variance)
  210.  
  211.   def analyze_timing(self, sequence_length, recent_sample_count, most_recent_only, highlight_sequence_index = None, highlight_style = HIGHLIGHT_BACKGROUND):
  212.     total_sample_count = len(self.notes)
  213.     if recent_sample_count is None:
  214.       recent_sample_count = total_sample_count
  215.     sample_count = min(recent_sample_count, total_sample_count)
  216.     start_index = total_sample_count - sample_count
  217.  
  218.     note_times = [0] * sequence_length
  219.     note_counts = [0] * sequence_length
  220.     # Iterate in reverse so most_recent_only works. Note that start_index is explicitly not included in this loop.
  221.     for i in range(start_index + sample_count - 1, start_index, -1):
  222.       sequence_index = self.notes[i - 1][0]
  223.       if sequence_index < 0:
  224.         continue
  225.       if not most_recent_only or note_counts[sequence_index] == 0:
  226.         note_time = self.notes[i][1] - self.notes[i - 1][1]
  227.         note_times[sequence_index] += note_time
  228.         note_counts[sequence_index] += 1
  229.  
  230.     if any(v == 0 for v in note_counts):
  231.       return "" # Not all notes have been captured
  232.  
  233.     for i in range(sequence_length):
  234.       note_times[i] /= note_counts[i]
  235.  
  236.     chart_width = 128
  237.     chart = ["|"] + ["_"] * chart_width + [" "] * 8 # Add padding to be safe against overflows
  238.  
  239.     total_note_times = sum(note_times)
  240.     current_note_offset = 0
  241.     previous_chart_position = 0
  242.     highlight_position = None
  243.     for i in range(sequence_length):
  244.       current_note_offset += note_times[i]
  245.       chart_position = round(current_note_offset * chart_width / total_note_times)
  246.       chart[chart_position] = "|"
  247.       note_time_percent = f"{100 * note_times[i] * sequence_length / total_note_times:.0f}%"
  248.       percent_chart_position = round((previous_chart_position + chart_position - len(note_time_percent)) / 2)
  249.       for t, c in enumerate(note_time_percent):
  250.         chart[percent_chart_position + t] = c
  251.  
  252.       if i == highlight_sequence_index:
  253.         highlight_position = (previous_chart_position, chart_position)
  254.  
  255.       previous_chart_position = chart_position
  256.  
  257.     if highlight_position is not None:
  258.       start_index, end_index = highlight_position
  259.       start_color_code = "\u001b[41m" if highlight_style == HIGHLIGHT_BACKGROUND else "\u001b[38;5;255m" # Red background
  260.       end_color_code = "\u001b[0m" if highlight_style == HIGHLIGHT_BACKGROUND else "\u001b[38;5;242m" # Reset
  261.       chart = chart[:end_index + 1] + list(end_color_code) + chart[end_index + 1:]
  262.       chart = chart[:start_index] + list(start_color_code) + chart[start_index:]
  263.       if highlight_style == HIGHLIGHT_TEXT:
  264.         chart = list("\u001b[38;5;242m") + chart + list("\u001b[0m")
  265.  
  266.     return "".join(chart)
  267.  
  268. def capture_sequence(sequence):
  269.   if sequence is None:
  270.     sequence_repeat_counts = None
  271.   else:
  272.     # Determine how many times any keys repeat. Note that a sequence should never all be the same key.
  273.     sequence_repeat_counts = [1] * len(sequence)
  274.     for i in range(len(sequence)):
  275.       for j in range(1, len(sequence) - 1):
  276.         if sequence[i] == sequence[(i + j) % len(sequence)]:
  277.           sequence_repeat_counts[i] += 1
  278.         else:
  279.           break
  280.  
  281.   match_sequence = []
  282.  
  283.   capture_pressed_keys = set()
  284.   next_sequence_index = 0
  285.   prev_sequence_index = None
  286.   start_time = None
  287.   capture_data = CaptureData()
  288.   error_count = 0
  289.   running_error_count = 0
  290.   last_key = None
  291.  
  292.   message = MultiLineMessage()
  293.   message.add_line("")
  294.   message.add_line("")
  295.   message.add_line("Capture active (...):")
  296.   message.print()
  297.  
  298.   previous_timing_chart_immediate_no_color = None
  299.   while True:
  300.     return_result = False
  301.  
  302.     # Grab events which all occur very close in time and only process the most recent one
  303.     data_to_process = None
  304.     pending_data = get_next_keyboard_event()
  305.     while True:
  306.       (event_type, key, seconds) = pending_data
  307.  
  308.       if event_type == PRESS:
  309.         if key not in capture_pressed_keys:
  310.           capture_pressed_keys.add(key)
  311.  
  312.         if key is keyboard.Key.esc:
  313.           # Discard this capture
  314.           print("")
  315.           print("Discarding capture")
  316.           return None
  317.  
  318.         if key is keyboard.Key.space:
  319.           if sequence is None or capture_data.valid_note_count() < 16:
  320.             print("")
  321.             print("Not enough samples, discarding capture")
  322.             return None
  323.           return_result = True
  324.           break
  325.  
  326.         if is_valid_sequence_character(key):
  327.           commit_pending_data = True
  328.  
  329.           # Make sure there is not a note on a "higher fret" that is already pressed
  330.           row = next(r for r in key_rows if key.char in r)
  331.           for index in range(row.index(key.char) + 1, len(row)):
  332.             key_to_check = keyboard.KeyCode.from_char(row[index])
  333.             if key_to_check in pressed_keys:
  334.               # This key is pressed
  335.               commit_pending_data = False
  336.               break
  337.  
  338.           if commit_pending_data:
  339.             data_to_process = (key, seconds)
  340.       else:
  341.         assert event_type == RELEASE
  342.  
  343.         # $TODO add release timing variance?
  344.         if key in capture_pressed_keys:
  345.           capture_pressed_keys.remove(key)
  346.  
  347.         if is_valid_sequence_character(key):
  348.           # Make sure there is not a note on a "higher fret" that is already pressed
  349.           blocked = False
  350.           row = next(r for r in key_rows if key.char in r)
  351.           for index in range(row.index(key.char) + 1, len(row)):
  352.             key_to_check = keyboard.KeyCode.from_char(row[index])
  353.             if key_to_check in pressed_keys:
  354.               # This key is pressed
  355.               blocked = True
  356.               break
  357.  
  358.           if not blocked:
  359.             # Determine if there is a note on a "lower fret" that is now pressed
  360.             for index in range(row.index(key.char) - 1, -1, -1):
  361.               key_to_check = keyboard.KeyCode.from_char(row[index])
  362.               if key_to_check in pressed_keys:
  363.                 # This key is pressed
  364.                 data_to_process = (key_to_check, seconds)
  365.                 break
  366.  
  367.       pending_data = get_next_keyboard_event(timeout = 0.005)
  368.       if pending_data is None:
  369.         break
  370.  
  371.     did_add_sample = False
  372.     if data_to_process is not None:
  373.       (key, seconds) = data_to_process
  374.       last_key = key
  375.  
  376.       if sequence is None:
  377.         if not is_valid_sequence_character(key):
  378.           print("")
  379.           print("Invalid sequence specified, discarding capture")
  380.           return None
  381.  
  382.         match_sequence.append(key)
  383.         sequence = detect_sequence(match_sequence)
  384.         if sequence is not None:
  385.           if len(sequence) == 0:
  386.             print("")
  387.             print("Unable to determine the sequence, discarding capture")
  388.             return None
  389.           else:
  390.             sequence_repeat_counts = [1] * len(sequence)
  391.             capture_data.notes = [(i % len(sequence), v[1]) for i, v in enumerate(capture_data.notes)]
  392.  
  393.       if sequence is None:
  394.         sequence_index = -1
  395.         repeat_count = 1
  396.       else:
  397.         repeat_count = 1 if prev_sequence_index is None else sequence_repeat_counts[prev_sequence_index]
  398.         expected_key = sequence[next_sequence_index]
  399.         if running_error_count > 0:
  400.           # An error was previously made. This could be one of three types of errors:
  401.           # (1) The wrong key was simply pressed
  402.           # (2) An extra key was pressed
  403.           # (3) A key was skipped
  404.           # We will try to handle all of these cases but the results will depend on the particular sequence.
  405.           if key == expected_key:
  406.             # The previous key was wrong but this one is right - we hit the wrong key once
  407.             sequence_index = next_sequence_index
  408.             running_error_count = 0
  409.           else:
  410.             next_expected_key = sequence[(next_sequence_index + 1) % len(sequence)]
  411.             previous_expected_key = sequence[(next_sequence_index - 1) % len(sequence)] # Python's % operator works for this
  412.  
  413.             # $TODO figure out repeat count recovery here?
  414.             if key == previous_expected_key:
  415.               # We pressed an extra key so we're 1 behind
  416.               sequence_index = (next_sequence_index - 1) % len(sequence)
  417.               prev_sequence_index = next_sequence_index
  418.               next_sequence_index = (next_sequence_index - 1) % len(sequence)
  419.               running_error_count = 0
  420.             elif key == next_expected_key:
  421.               # We skipped a key so we're 1 ahead
  422.               sequence_index = (next_sequence_index + 1) % len(sequence)
  423.               prev_sequence_index = next_sequence_index
  424.               next_sequence_index = (next_sequence_index + 1) % len(sequence)
  425.               running_error_count = 0
  426.             else:
  427.               sequence_index = -1
  428.               # We made another error that doesn't fall into one of the 3 cases that we handle
  429.               if running_error_count < 2:
  430.                 # If we make consecutive errors, don't keep penalizing - record up to 2 and then wait to get back in sync
  431.                 error_count += 1
  432.               running_error_count += 1
  433.         elif key == expected_key:
  434.           # Correct entry
  435.           sequence_index = next_sequence_index
  436.         else:
  437.           sequence_index = -1
  438.           error_count += 1
  439.           running_error_count += 1
  440.  
  441.         # Advance in the sequence unless we've made too many consecutive errors - then just wait to get back in sync
  442.         if running_error_count < 2:
  443.           prev_sequence_index = next_sequence_index
  444.           next_sequence_index = (next_sequence_index + sequence_repeat_counts[next_sequence_index]) % len(sequence)
  445.  
  446.       if start_time is None:
  447.         start_time = seconds
  448.       note_time = seconds - start_time
  449.  
  450.       # Record timing info for the key event (and any repeated previous events)
  451.       prev_note_time = capture_data.notes[-1][1] if len(capture_data.notes) > 0 else 0.0
  452.       for repeat_index in range(1, repeat_count + 1):
  453.         # Lerp across repeated notes, evenly spreading the total note duration
  454.         u = repeat_index / repeat_count
  455.         lerped_note_time = prev_note_time * (1 - u) + note_time * u
  456.         lerped_sequence_index = -1 if sequence is None else (sequence_index - repeat_count + repeat_index) % len(sequence)
  457.         capture_data.add_sample(lerped_sequence_index, lerped_note_time)
  458.         did_add_sample = True
  459.  
  460.     message.goto_previous_line(3)
  461.     message.clear_current_line()
  462.  
  463.     timing_chart_immediate = None
  464.     timing_chart_immediate_no_color = None
  465.     if sequence is not None and len(capture_data.notes) > 0:
  466.       last_note_sequence_index = capture_data.notes[-1][0]
  467.       highlight_sequence_index = None if return_result else last_note_sequence_index
  468.       timing_chart = capture_data.analyze_timing(len(sequence), recent_sample_count, False, highlight_sequence_index, HIGHLIGHT_BACKGROUND)
  469.       if did_add_sample:
  470.         timing_chart_immediate = capture_data.analyze_timing(len(sequence), recent_sample_count, True, highlight_sequence_index, HIGHLIGHT_BACKGROUND)
  471.         timing_chart_immediate_no_color = capture_data.analyze_timing(len(sequence), recent_sample_count, True, highlight_sequence_index, HIGHLIGHT_TEXT)
  472.     else:
  473.       timing_chart = ""
  474.  
  475.     # Add the latest timing data whenever there is an update so we get a continuous "stream" of timing charts
  476.     if timing_chart_immediate is not None:
  477.       # Overwrite the previous timing chart so that only the latest one has coloring
  478.       if previous_timing_chart_immediate_no_color is not None:
  479.         message.goto_previous_line()
  480.         message.add_line(f"  {previous_timing_chart_immediate_no_color}")
  481.  
  482.       previous_timing_chart_immediate_no_color = timing_chart_immediate_no_color
  483.       message.add_line(f"  {timing_chart_immediate}")
  484.     elif return_result and previous_timing_chart_immediate_no_color is not None:
  485.       # Rewrite the last line without background color
  486.       message.goto_previous_line()
  487.       message.add_line(f"  {previous_timing_chart_immediate_no_color}")
  488.  
  489.     message.clear_current_line()
  490.     message.add_line("")
  491.     message.clear_current_line()
  492.     message.add_line(f"  {timing_chart}")
  493.  
  494.     message.clear_current_line()
  495.     message.add("Capture active (...): " if sequence is None else f"Capture active ({''.join(x.char for x in sequence)}): ")
  496.  
  497.     recent_sample_count = None if return_result or sequence is None else len(sequence) * 8
  498.     note_count = len(capture_data.notes)
  499.     nps_and_timing_variance = capture_data.get_nps_and_timing_variance(recent_sample_count)
  500.     if nps_and_timing_variance is not None:
  501.       notes_per_second, timing_variance = nps_and_timing_variance
  502.       bpm_16ths = notes_per_second * (60 / 4)
  503.  
  504.       message.add(f"{notes_per_second:.1f}nps ({bpm_16ths:.0f}bpm 16ths), ")
  505.       message.add(f"{100 * timing_variance:.0f}% timing variance, ")
  506.  
  507.     if note_count >= 1:
  508.       message.add(f"{error_count} error(s) ({100 * error_count / note_count:.0f}% error rate) ")
  509.  
  510.     if last_key is not None:
  511.       message.add(f"{last_key.char} ")
  512.  
  513.     message.add_line(f"{'' if running_error_count == 0 else '(!)'}")
  514.  
  515.     # Backtrack so we can overwrite the previous message next time
  516.     message.print()
  517.  
  518.     if return_result:
  519.       assert note_count >= 2
  520.  
  521.       result = CaptureResult()
  522.       result.sequence = sequence
  523.       result.captured_length = len(capture_data.notes)
  524.       result.captured_duration = capture_data.notes[-1][1] - capture_data.notes[0][1]
  525.       result.notes_per_second = notes_per_second
  526.       result.timing_variance = timing_variance
  527.       result.error_count = error_count
  528.       result.captured_notes = capture_data.notes
  529.       return result
  530.  
  531. def format_duration(total_seconds):
  532.   rounded_seconds = math.floor(total_seconds)
  533.   seconds = rounded_seconds % 60
  534.   total_minutes = math.floor(rounded_seconds / 60)
  535.   minutes = total_minutes % 60
  536.   total_hours = math.floor(total_minutes / 60)
  537.   return f"{total_hours}:{minutes:02d}:{seconds:02d}"
  538.  
  539. def list_logged_sequences():
  540.   sequences = {}
  541.   with open(log_filename, "r") as log_file:
  542.     lines = log_file.readlines()
  543.     for line in lines:
  544.       entry = json.loads(line)
  545.       sequence = entry["sequence"]
  546.       value = sequences[sequence] if sequence in sequences else (0, 0.0)
  547.       count, duration = value
  548.       count += 1
  549.       if "captured_duration" in entry:
  550.         duration += entry["captured_duration"]
  551.       else:
  552.         duration += entry["captured_length"] / entry["notes_per_second"] # Handle old log entry
  553.       sequences[sequence] = (count, duration)
  554.  
  555.   sorted_sequences = sorted(sequences, key = lambda k: sequences[k][0], reverse = True)
  556.   total_duration = sum(v[1] for v in sequences.values())
  557.   print(f"  Total duration: {format_duration(total_duration)}")
  558.   for sequence in sorted_sequences:
  559.     count, duration = sequences[sequence]
  560.     print(f"  {sequence} - {count}, {format_duration(duration)}")
  561.   print("")
  562.  
  563. PLOT_TYPE_INDEX_BPM_TIMING_VARIANCE = 1
  564. PLOT_TYPE_INDEX_BPM_TIMING_VARIANCE_MS = 2
  565. PLOT_TYPE_INDEX_BPM_ERROR_RATE = 3
  566. PLOT_TYPE_DATE_BPM_TIMING_VARIANCE = 4
  567. PLOT_TYPE_DATE_BPM_TIMING_VARIANCE_MS = 5
  568. PLOT_TYPE_DATE_BPM_ERROR_RATE = 6
  569. PLOT_TYPE_BPM_TIMING_VARIANCE = 7
  570. PLOT_TYPE_BPM_TIMING_VARIANCE_MS = 8
  571. PLOT_TYPE_BPM_ERROR_RATE = 9
  572.  
  573. def plot_logged_sequence(sequence, plot_type):
  574.   sequence = "".join(x.char for x in sequence)
  575.   entries = []
  576.   with open(log_filename, "r") as log_file:
  577.     lines = log_file.readlines()
  578.     for line in lines:
  579.       entry = json.loads(line)
  580.       if entry["sequence"] == sequence:
  581.         entries.append(entry)
  582.  
  583.   if len(entries) == 0:
  584.     print("No logged entries found")
  585.     return
  586.  
  587.   entries.sort(key = lambda entry: datetime.datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S.%f"))
  588.  
  589.   indices = []
  590.   dates = []
  591.   bpms_16ths = []
  592.   timing_variances = []
  593.   timing_variances_ms = []
  594.   error_rates = []
  595.   for i, entry in enumerate(entries):
  596.     indices.append(i)
  597.     date = datetime.datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S.%f")
  598.     dates.append(date)
  599.  
  600.     notes_per_second = entry["notes_per_second"]
  601.     bpm_16ths = notes_per_second * (60 / 4)
  602.     bpms_16ths.append(bpm_16ths)
  603.  
  604.     timing_variance = entry["timing_variance"]
  605.     timing_variances.append(100 * timing_variance)
  606.  
  607.     note_length_ms = 1000.0 / notes_per_second
  608.     timing_variances_ms.append(timing_variance * note_length_ms)
  609.  
  610.     error_rates.append(100 * entry["error_count"] / entry["captured_length"])
  611.  
  612.   rotate_x_labels = False
  613.   fit = False
  614.   if plot_type == PLOT_TYPE_INDEX_BPM_TIMING_VARIANCE:
  615.     x = indices
  616.     y = bpms_16ths
  617.     colors = timing_variances
  618.   elif plot_type == PLOT_TYPE_INDEX_BPM_TIMING_VARIANCE_MS:
  619.     x = indices
  620.     y = bpms_16ths
  621.     colors = timing_variances_ms
  622.   elif plot_type == PLOT_TYPE_INDEX_BPM_ERROR_RATE:
  623.     x = indices
  624.     y = bpms_16ths
  625.     colors = error_rates
  626.   elif plot_type == PLOT_TYPE_DATE_BPM_TIMING_VARIANCE:
  627.     x = dates
  628.     y = bpms_16ths
  629.     colors = timing_variances
  630.     rotate_x_labels = True
  631.   elif plot_type == PLOT_TYPE_DATE_BPM_TIMING_VARIANCE_MS:
  632.     x = dates
  633.     y = bpms_16ths
  634.     colors = timing_variances_ms
  635.     rotate_x_labels = True
  636.   elif plot_type == PLOT_TYPE_DATE_BPM_ERROR_RATE:
  637.     x = dates
  638.     y = bpms_16ths
  639.     colors = error_rates
  640.     rotate_x_labels = True
  641.   elif plot_type == PLOT_TYPE_BPM_TIMING_VARIANCE:
  642.     x = bpms_16ths
  643.     y = timing_variances
  644.     colors = timing_variances
  645.     fit = True
  646.   elif plot_type == PLOT_TYPE_BPM_TIMING_VARIANCE_MS:
  647.     x = bpms_16ths
  648.     y = timing_variances_ms
  649.     colors = timing_variances_ms
  650.     fit = True
  651.   elif plot_type == PLOT_TYPE_BPM_ERROR_RATE:
  652.     x = bpms_16ths
  653.     y = error_rates
  654.     colors = error_rates
  655.     fit = True
  656.   else:
  657.     assert False
  658.  
  659.   if fit:
  660.     trend = numpy.polyfit(x, y, deg = 1)
  661.  
  662.   cmap = LinearSegmentedColormap.from_list("gor", ["g", "#ff8000", "r"], N = 256)
  663.   pyplot.scatter(x, y, c = colors, cmap = cmap)
  664.   if fit:
  665.     trendpoly = numpy.poly1d(trend)
  666.     pyplot.plot(x, trendpoly(x))
  667.   if rotate_x_labels:
  668.     pyplot.xticks(rotation = 60)
  669.   pyplot
  670.   pyplot.colorbar()
  671.   pyplot.tight_layout()
  672.   pyplot.show()
  673.  
  674. session_id = uuid.uuid4()
  675. with keyboard.Listener(on_press = on_press, on_release = on_release) as listener:
  676.   # Wait for a moment and drain the queue to clear out any initial events (e.g. the enter key release)
  677.   time.sleep(0.1)
  678.   drain_keyboard_queue()
  679.  
  680.   # Note: annoyingly, input into stdin builds up in the background. I'm not sure how to avoid this.
  681.   while True:
  682.     print('Press <space> to start, <L> to view logged sequences, and <P> to plot sequence history')
  683.     key = wait_for_keypress([keyboard.Key.space, keyboard.KeyCode.from_char('l'), keyboard.KeyCode.from_char('p')])
  684.     if key == keyboard.Key.space:
  685.       print("Enter sequence and press <space> (empty sequence will auto-detect on capture): ", end = "", flush = True)
  686.       sequence = input_sequence()
  687.       if sequence is None:
  688.         continue
  689.  
  690.       capture_result = capture_sequence(sequence if len(sequence) > 0 else None)
  691.       print("")
  692.       if capture_result is None:
  693.         continue
  694.  
  695.       # Write each log file line as its own JSON string for parsing convenience
  696.       result_json = json.dumps(
  697.         {
  698.           "session": str(session_id),
  699.           "date": str(datetime.datetime.now()),
  700.           "sequence": "".join(x.char for x in capture_result.sequence),
  701.           "captured_length": capture_result.captured_length,
  702.           "captured_duration": capture_result.captured_duration,
  703.           "notes_per_second": capture_result.notes_per_second,
  704.           "timing_variance": capture_result.timing_variance,
  705.           "error_count": capture_result.error_count,
  706.           "captured_notes": capture_result.captured_notes,
  707.         })
  708.  
  709.       with open(log_filename, "a") as log_file:
  710.         log_file.write(result_json + "\n")
  711.         log_file.flush()
  712.     elif key == keyboard.KeyCode.from_char("l"):
  713.       list_logged_sequences()
  714.     elif key == keyboard.KeyCode.from_char("p"):
  715.       print("Enter sequence and press <space>: ", end = "", flush = True)
  716.       sequence = input_sequence()
  717.       if sequence is None:
  718.         continue
  719.  
  720.       print("Enter plot type:")
  721.       print("  <1>: entry index - BPM (16ths) - timing variance %")
  722.       print("  <2>: entry index - BPM (16ths) - timing variance ms")
  723.       print("  <3>: entry index - BPM (16ths) - error rate %")
  724.       print("  <4>: entry date - BPM (16ths) - timing variance %")
  725.       print("  <5>: entry date - BPM (16ths) - timing variance ms")
  726.       print("  <6>: entry date - BPM (16ths) - error rate %")
  727.       print("  <7>: BPM (16ths) - timing variance %")
  728.       print("  <8>: BPM (16ths) - timing variance ms")
  729.       print("  <9>: BPM (16ths) - error rate")
  730.       plot_type_key = wait_for_keypress(None)
  731.       if plot_type_key == keyboard.Key.esc:
  732.         continue
  733.  
  734.       plot_types = {
  735.         keyboard.KeyCode.from_char("1"): PLOT_TYPE_INDEX_BPM_TIMING_VARIANCE,
  736.         keyboard.KeyCode.from_char("2"): PLOT_TYPE_INDEX_BPM_TIMING_VARIANCE_MS,
  737.         keyboard.KeyCode.from_char("3"): PLOT_TYPE_INDEX_BPM_ERROR_RATE,
  738.         keyboard.KeyCode.from_char("4"): PLOT_TYPE_DATE_BPM_TIMING_VARIANCE,
  739.         keyboard.KeyCode.from_char("5"): PLOT_TYPE_DATE_BPM_TIMING_VARIANCE_MS,
  740.         keyboard.KeyCode.from_char("6"): PLOT_TYPE_DATE_BPM_ERROR_RATE,
  741.         keyboard.KeyCode.from_char("7"): PLOT_TYPE_BPM_TIMING_VARIANCE,
  742.         keyboard.KeyCode.from_char("8"): PLOT_TYPE_BPM_TIMING_VARIANCE_MS,
  743.         keyboard.KeyCode.from_char("9"): PLOT_TYPE_BPM_ERROR_RATE,
  744.       }
  745.  
  746.       if plot_type_key not in plot_types:
  747.         print("Invalid plot type")
  748.         continue
  749.  
  750.       plot_logged_sequence(sequence, plot_types[plot_type_key])
  751.       print("")
  752.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement