Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- ./scripts/split_large_csv.py import sys
- import random
- from pathlib import Path
def count_lines(file_path):
    """Return the zero-based index of the last line in *file_path*.

    Despite the name, the result is ``number_of_lines - 1``. For a CSV whose
    first line is a header, that equals the number of data rows.

    Args:
        file_path: Path of the text file to scan.

    Returns:
        The index of the final line, or ``-1`` when the file has no lines.
    """
    # Pre-bind the result so an empty file yields -1 instead of raising
    # UnboundLocalError when the loop body never runs.
    last_index = -1
    with open(file_path, 'r') as f:
        # Stream the file; only the running index is kept in memory.
        for last_index, _ in enumerate(f):
            pass
    return last_index
def split_large_csv(input_file, output_test_file, test_frac=0.2, seed=42):
    """Split a CSV in place into a train file and a separate test file.

    A random ``test_frac`` share of the data rows is streamed to
    ``output_test_file``; the remaining rows replace ``input_file``. Both
    outputs start with a copy of the input's first line (the header). The
    input is processed line by line, so only the set of sampled row indices
    is held in memory, never the rows themselves.

    Rows are identified by physical lines, so quoted fields that contain
    embedded newlines are not supported.

    Args:
        input_file: Path of the CSV to split. It is overwritten with the
            train split.
        output_test_file: Path that receives the test split.
        test_frac: Fraction of data rows moved to the test split, in [0, 1].
        seed: Seed for the row sampler, making the split reproducible.

    Raises:
        ValueError: If ``test_frac`` is outside [0, 1], or if
            ``output_test_file`` points at the input file or at the
            temporary train file.
    """
    # Fail early with a clear message instead of letting random.sample
    # raise a less specific ValueError later.
    if not 0 <= test_frac <= 1:
        raise ValueError(
            f"test_frac must be between 0 and 1, got {test_frac!r}"
        )

    input_file = Path(input_file)
    output_test_file = Path(output_test_file)
    output_train_file = input_file  # overwrite the input with the train split
    # Append ".tmp" rather than swapping the suffix: with_suffix(".tmp") maps
    # an input already named "*.tmp" onto itself, which would truncate the
    # input while it is still being read.
    tmp_train_file = output_train_file.with_name(
        output_train_file.name + ".tmp"
    )

    # Opening the test output for writing would truncate whichever of these
    # files it aliases, destroying data mid-split.
    if output_test_file.resolve() in (
        input_file.resolve(),
        tmp_train_file.resolve(),
    ):
        raise ValueError(
            "output_test_file must differ from the input file "
            f"and its temporary file: {output_test_file}"
        )

    print(f"Counting total rows in {input_file}...")
    # count_lines returns the last zero-based line index, i.e. the number of
    # lines that follow the header.
    last_line_index = count_lines(input_file)
    print(f"Total lines (including header): {last_line_index + 1}")
    num_data_rows = max(last_line_index, 0)  # exclude header

    num_test = int(num_data_rows * test_frac)
    print(f"Sampling {num_test} rows for test set...")
    # A private generator draws the same sample as random.seed(seed) followed
    # by random.sample, without reseeding the process-wide RNG.
    rng = random.Random(seed)
    test_indices = set(rng.sample(range(num_data_rows), num_test))

    print("Splitting file into train and test...")
    try:
        # newline="" disables newline translation so each row keeps its
        # original line ending (e.g. CRLF) in both outputs.
        with input_file.open("r", newline="") as f_in, \
                tmp_train_file.open("w", newline="") as f_train, \
                output_test_file.open("w", newline="") as f_test:
            header = f_in.readline()
            f_train.write(header)
            f_test.write(header)

            for idx, line in enumerate(f_in):
                if idx in test_indices:
                    f_test.write(line)
                else:
                    f_train.write(line)

        print("Replacing original file with train split...")
        tmp_train_file.replace(output_train_file)
    except BaseException:
        # Do not leave a half-written temporary file behind; the original
        # input is untouched until replace() succeeds. Re-raised unchanged.
        if tmp_train_file.exists():
            tmp_train_file.unlink()
        raise
    print("Done.")
if __name__ == "__main__":
    # Script entry point: expects exactly two positional arguments, the CSV
    # to split in place and the destination for the test split.
    cli_args = sys.argv[1:]
    if len(cli_args) != 2:
        print("Usage: python split_large_csv.py <input_csv> <output_test_csv>")
        sys.exit(1)

    input_csv, output_test_csv = cli_args
    split_large_csv(input_csv, output_test_csv)
Advertisement
Add Comment
Please, Sign In to add comment