Advertisement
RedSlimeballYT

FIR zero-cross W2M

Mar 21st, 2025 (edited)
532
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.92 KB | Software | 0 0
  1. import numpy as np
  2. import mido
  3. import argparse
  4. import os
  5. from scipy.signal import firwin, filtfilt
  6. from scipy.io.wavfile import read
  7. from tqdm import tqdm
  8.  
  9. class FrequencyBand:
  10.     def __init__(self, midi_note, sr):
  11.         self.note = midi_note
  12.         self.freq = 440 * (2 ** ((midi_note - 69) / 12))
  13.         self.sr = sr
  14.         self.filter = self.create_filter()
  15.        
  16.     def create_filter(self):
  17.         numtaps = 2049
  18.         cutoff = [self.freq * 0.95, self.freq * 1.05]
  19.         return firwin(numtaps, cutoff, fs=self.sr, pass_zero=False)
  20.  
  21. class AudioAnalyzer:
  22.     def __init__(self, file_path, target_sr=44100):
  23.         self.sr, self.audio = read(file_path)
  24.         if self.audio.ndim > 1:
  25.             self.audio = self.audio.mean(axis=1)
  26.         self.audio = self.audio.astype(np.float32)
  27.         self.audio /= np.max(np.abs(self.audio))
  28.        
  29.         if self.sr != target_sr:
  30.             self.resample(target_sr)
  31.             self.sr = target_sr
  32.            
  33.     def resample(self, target_sr):
  34.         ratio = target_sr / self.sr
  35.         self.audio = np.interp(
  36.             np.arange(0, len(self.audio), ratio),
  37.             np.arange(0, len(self.audio)),
  38.             self.audio
  39.         )
  40.  
  41. class MidiConverter:
  42.     def __init__(self, ppqn=22050, bpm=120*(88200/1920)):
  43.         self.ppqn = ppqn
  44.         self.bpm = bpm
  45.         self.ticks_per_sec = 44100
  46.        
  47.     def convert_events(self, all_events):
  48.         tracks = {note: mido.MidiTrack() for note in range(128)}
  49.         sorted_events = sorted(all_events, key=lambda x: x['time'])
  50.        
  51.         last_times = {note: 0 for note in range(128)}
  52.         for event in tqdm(sorted_events, desc="Placing notes"):
  53.             note = event['note']
  54.             track = tracks[note]
  55.            
  56.             ticks = int(event['time'] * self.ticks_per_sec)
  57.             delta = ticks - last_times[note]
  58.            
  59.             track.append(mido.Message(
  60.                 event['type'],
  61.                 note=note,
  62.                 velocity=event.get('velocity', 64),
  63.                 time=delta
  64.             ))
  65.             last_times[note] = ticks
  66.        
  67.         return [mido.MidiTrack()] + list(tracks.values())
  68.  
  69. def analyze_band(args):
  70.     band, audio = args
  71.     try:
  72.         filtered = filtfilt(band.filter, [1.0], audio)
  73.     except ValueError:
  74.         return []
  75.    
  76.     crossings = []
  77.     state = 0
  78.     HYSTERESIS = 0.01
  79.    
  80.     for i in tqdm(range(len(filtered)), desc=f"Note {band.note:03d}", leave=False):
  81.         current = filtered[i]
  82.         if state == 0 and current > HYSTERESIS:
  83.             state = 1
  84.             crossings.append(i)
  85.         elif state == 1 and current < -HYSTERESIS:
  86.             state = 0
  87.             crossings.append(i)
  88.    
  89.     events = []
  90.     for i in tqdm(range(2, len(crossings), 2), desc=f"Cycles {band.note:03d}", leave=False):
  91.         start = crossings[i-2]
  92.         end = crossings[i]
  93.        
  94.         if end - start < 2:
  95.             continue
  96.            
  97.         segment = filtered[start:end]
  98.         velocity = int(np.clip(np.max(np.abs(segment) ** 0.5) * 127, 1, 127))
  99.        
  100.         start_time = start / band.sr
  101.         end_time = end / band.sr
  102.  
  103.         wavelength = 1.0 / band.freq
  104.         max_duration = 1.5 * wavelength
  105.         actual_duration = end_time - start_time
  106.  
  107.         if actual_duration > max_duration:
  108.             end_time = start_time + wavelength
  109.  
  110.         events.append({'type': 'note_on', 'note': band.note, 'time': start_time, 'velocity': velocity})
  111.         events.append({'type': 'note_off', 'note': band.note, 'time': end_time, 'velocity': 0})
  112.    
  113.     return events
  114.  
  115. def colorize_midi(input_path, output_path, num_tracks=31):
  116.     mid = mido.MidiFile(input_path)
  117.     tracks = [[] for _ in range(num_tracks)]
  118.     active_notes = {}
  119.     tempo_messages = []
  120.  
  121.     for track in mid.tracks:
  122.         current_time = 0
  123.         for msg in track:
  124.             current_time += msg.time
  125.             if msg.type == 'set_tempo':
  126.                 tempo_messages.append((current_time, msg))
  127.             elif msg.type == 'note_on' and msg.velocity > 0:
  128.                 track_index = min(msg.velocity // (128 // num_tracks), num_tracks - 1)
  129.                 tracks[track_index].append(('on', msg.note, current_time, msg.velocity))
  130.                 active_notes[msg.note] = (track_index, current_time)
  131.             elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
  132.                 if msg.note in active_notes:
  133.                     track_index, start_time = active_notes[msg.note]
  134.                     tracks[track_index].append(('off', msg.note, current_time, 0))
  135.                     del active_notes[msg.note]
  136.  
  137.     new_mid = mido.MidiFile(ticks_per_beat=mid.ticks_per_beat)
  138.     master_track = mido.MidiTrack()
  139.    
  140.     prev_time = 0
  141.     for abs_time, tempo_msg in sorted(tempo_messages, key=lambda x: x[0]):
  142.         delta = abs_time - prev_time
  143.         tempo_msg.time = delta
  144.         master_track.append(tempo_msg)
  145.         prev_time = abs_time
  146.    
  147.     new_mid.tracks.append(master_track)
  148.  
  149.     for track in tracks:
  150.         if not track:
  151.             continue
  152.         sorted_track = sorted(track, key=lambda x: x[2])
  153.         midi_track = mido.MidiTrack()
  154.         prev_time = 0
  155.         for event in sorted_track:
  156.             delta = event[2] - prev_time
  157.             if event[0] == 'on':
  158.                 msg = mido.Message('note_on', note=event[1], velocity=event[3], time=delta)
  159.             else:
  160.                 msg = mido.Message('note_off', note=event[1], velocity=0, time=delta)
  161.             midi_track.append(msg)
  162.             prev_time = event[2]
  163.         new_mid.tracks.append(midi_track)
  164.  
  165.     new_mid.save(output_path)
  166.  
  167. def main(file_path, output_path):
  168.     audio_analyzer = AudioAnalyzer(file_path)
  169.     converter = MidiConverter()
  170.    
  171.     bands = [FrequencyBand(n, audio_analyzer.sr) for n in tqdm(range(128), desc="Creating bands")]
  172.    
  173.     all_events = []
  174.     for band in tqdm(bands, desc="Analyzing bands"):
  175.         events = analyze_band((band, audio_analyzer.audio))
  176.         all_events.extend(events)
  177.    
  178.     midi = mido.MidiFile(type=1)
  179.     midi.tracks = converter.convert_events(all_events)
  180.     midi.tracks[0].append(mido.MetaMessage('set_tempo', tempo=mido.bpm2tempo(converter.bpm)))
  181.    
  182.     base_path = os.path.splitext(output_path)[0]
  183.     intermediate_path = f"{base_path}_intermediate.mid"
  184.     colored_path = f"{base_path}-colored.mid"
  185.    
  186.     midi.save(intermediate_path)
  187.     colorize_midi(intermediate_path, colored_path)
  188.     os.remove(intermediate_path)
  189.     print(f"\nSuccessfully saved colored MIDI to {colored_path}")
  190.  
  191. if __name__ == "__main__":
  192.     parser = argparse.ArgumentParser(description='audio to midi converter with velocity corresponded to track colors')
  193.     parser.add_argument('input', help='Input WAV file')
  194.     parser.add_argument('output', help='Output MIDI file')
  195.     args = parser.parse_args()
  196.    
  197.     main(args.input, args.output)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement