Advertisement
alkkofficial

Rust binary output

Oct 7th, 2023
171
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 12.33 KB | None | 0 0
  1. use bytemuck::cast_slice;
  2. use cozy_chess::{Color, GameStatus};
  3. use serde::Serialize;
  4. use std::{
  5.     cmp::{max, min},
  6.     fs::File,
  7.     io::{self, BufWriter, Seek, Write},
  8.     path::{Path, PathBuf},
  9. };
  10.  
  11. use crate::{
  12.     boardmanager::BoardStack,
  13.     dataformat::{Position, Simulation},
  14.     mvs::get_contents, decoder::convert_board,
  15. };
  16.  
  17. #[derive(Serialize)]
  18. struct MetaData<'a> {
  19.    game: &'a str,
  20.  
  21.     input_bool_shape: &'a [usize],
  22.    input_scalar_count: usize,
  23.    policy_shape: &'a [usize],
  24.  
  25.     game_count: usize,
  26.     position_count: usize,
  27.     includes_terminal_positions: bool,
  28.     includes_game_start_indices: bool,
  29.  
  30.     max_game_length: i32,
  31.     min_game_length: i32,
  32.     root_wdl: [f32; 3],
  33.     hit_move_limit: f32,
  34.  
  35.     scalar_names: &'static [&'static str],
  36. }
  37.  
  38. #[derive(Debug)]
  39. pub struct BinaryOutput {
  40.     game: String,
  41.     path: PathBuf,
  42.  
  43.     bin_write: BufWriter<File>,
  44.     off_write: BufWriter<File>,
  45.     json_tmp_write: BufWriter<File>,
  46.  
  47.     game_count: usize,
  48.     position_count: usize,
  49.  
  50.     max_game_length: Option<i32>,
  51.     min_game_length: Option<i32>,
  52.  
  53.     total_root_wdl: [u64; 3],
  54.     hit_move_limit_count: u64,
  55.  
  56.     next_offset: u64,
  57.     game_start_indices: Vec<u64>,
  58.  
  59.     finished: bool,
  60. }
  61.  
  62. #[derive(Debug)]
  63. struct Scalars {
  64.     game_id: usize,
  65.     pos_index: usize,
  66.     game_length: usize,
  67.     zero_visits: u64,
  68.     is_full_search: bool,
  69.     is_final_position: bool,
  70.     is_terminal: bool,
  71.     hit_move_limit: bool,
  72.     available_mv_count: usize,
  73.     played_mv: isize,
  74.     kdl_policy: f32,
  75.     final_values: f32, // z
  76.     zero_values: f32,  // q
  77.     net_values: f32,   // v
  78. }
  79.  
  80. impl BinaryOutput {
  81.     pub fn new(path: impl AsRef<Path>, game: &str) -> io::Result<Self> {
  82.         let path = path.as_ref().to_path_buf();
  83.         assert!(
  84.             path.extension().is_none(),
  85.             "Binary output path should not have an extension, .bin and .json are added automatically"
  86.         );
  87.  
  88.         //TODO try buffer sizes again
  89.         let bin_write = BufWriter::new(File::create(path.with_extension("bin"))?);
  90.         let off_write = BufWriter::new(File::create(path.with_extension("off"))?);
  91.         let json_tmp_write = BufWriter::new(File::create(path.with_extension("json.tmp"))?);
  92.  
  93.         Ok(BinaryOutput {
  94.             game: game.to_string(),
  95.  
  96.             bin_write,
  97.             off_write,
  98.             json_tmp_write,
  99.  
  100.             path,
  101.  
  102.             game_count: 0,
  103.             position_count: 0,
  104.  
  105.             max_game_length: None,
  106.             min_game_length: None,
  107.  
  108.             total_root_wdl: [0, 0, 0],
  109.             hit_move_limit_count: 0,
  110.  
  111.             next_offset: 0,
  112.             game_start_indices: vec![],
  113.  
  114.             finished: false,
  115.         })
  116.     }
  117.  
  118.     pub fn append(&mut self, simulation: &Simulation) -> io::Result<()> {
  119.         let Simulation {
  120.             positions,
  121.             final_board,
  122.         } = simulation;
  123.  
  124.         // collect metadata statistics
  125.         let game_id = self.game_count;
  126.         let game_length = positions.len();
  127.  
  128.         self.game_start_indices.push(self.position_count as u64);
  129.  
  130.         self.game_count += 1;
  131.         self.position_count += 1 + game_length;
  132.  
  133.         self.max_game_length = Some(max(game_length as i32, self.max_game_length.unwrap_or(-1)));
  134.         self.min_game_length = Some(min(
  135.             game_length as i32,
  136.             self.min_game_length.unwrap_or(i32::MAX),
  137.         ));
  138.  
  139.         // let outcome = final_board.outcome().unwrap_or(Outcome::Draw);
  140.  
  141.         let outcome: Option<Color> = match final_board.status() {
  142.             GameStatus::Drawn => None,
  143.             GameStatus::Won => Some(!final_board.board().side_to_move()),
  144.             GameStatus::Ongoing => panic!("Game is still ongoing!"),
  145.         };
  146.  
  147.         // write the positions
  148.         for (pos_index, position) in positions.iter().enumerate() {
  149.             let &Position {
  150.                 ref board,
  151.                 is_full_search,
  152.                 played_mv,
  153.                 zero_visits,
  154.                 ref zero_evaluation,
  155.                 ref net_evaluation,
  156.             } = position;
  157.  
  158.             let (available_mv_count, policy_indices) = collect_policy_indices(board);
  159.             assert_eq!(available_mv_count, zero_evaluation.policy.len());
  160.             assert_eq!(available_mv_count, net_evaluation.policy.len());
  161.  
  162.             // get all idx from mvs.rs
  163.             let all_idx = get_contents();
  164.             let played_mv_index = all_idx.iter().position(|&x| x == played_mv).unwrap();
  165.  
  166.             // let played_mv_index = self.mapper.move_to_index(board, played_mv);
  167.             // let kdl_policy = kdl_divergence(&zero_evaluation.policy, &net_evaluation.policy);
  168.             let kdl_policy = f32::NAN;
  169.             let moves_left = game_length + 1 - pos_index;
  170.             let stored_policy = &zero_evaluation.policy;
  171.             let final_values = match outcome {
  172.                 Some(Color::White) => net_evaluation.values,
  173.                 Some(Color::Black) => -net_evaluation.values,
  174.                 None => 0.0,
  175.             };
  176.             let scalars = Scalars {
  177.                 game_id,
  178.                 pos_index,
  179.                 game_length,
  180.                 zero_visits,
  181.                 is_full_search,
  182.                 is_final_position: false,
  183.                 is_terminal: false,
  184.                 hit_move_limit: false,
  185.                 available_mv_count: stored_policy.len(),
  186.                 played_mv: played_mv_index as isize,
  187.                 kdl_policy,
  188.                 final_values,
  189.                 zero_values: zero_evaluation.values,
  190.                 net_values: net_evaluation.values,
  191.             };
  192.  
  193.             self.append_position(board, &scalars, &policy_indices, stored_policy)?;
  194.         }
  195.         let final_values = match outcome {
  196.             Some(Color::White) => 1.0,
  197.             Some(Color::Black) => -1.0,
  198.             None => 0.0,
  199.         };
  200.         let scalars = Scalars {
  201.             game_id,
  202.             pos_index: game_length,
  203.             game_length,
  204.             zero_visits: 0,
  205.             is_full_search: false,
  206.             is_final_position: true,
  207.             is_terminal: final_board.is_terminal(),
  208.             hit_move_limit: !final_board.is_terminal(),
  209.             available_mv_count: 0,
  210.             played_mv: -1,
  211.             kdl_policy: f32::NAN,
  212.             final_values,
  213.             zero_values: f32::NAN,
  214.             //TODO in theory we could ask the network, but this is only really meaningful for muzero
  215.             net_values: f32::NAN,
  216.         };
  217.  
  218.         self.append_position(&final_board, &scalars, &[], &[])?;
  219.  
  220.         Ok(())
  221.     }
  222.  
  223.     fn append_position(
  224.         &mut self,
  225.         board: &BoardStack,
  226.         scalars: &Scalars,
  227.         policy_indices: &[u32],
  228.         policy_values: &[f32],
  229.     ) -> io::Result<()> {
  230.         // encode board
  231.         let board_bools: Vec<f32> = Vec::try_from(convert_board(board)).unwrap()[8*64-1..].to_vec();
  232.         let board_scalars: Vec<f32> = Vec::try_from(convert_board(board)).unwrap()[0..8*64].to_vec();
  233.         // assert_eq!(self.mapper.input_bool_len(), board_bools.len());
  234.         // assert_eq!(self.mapper.input_scalar_count(), board_scalars.len());
  235.         // assert_eq!(
  236.         //     (self.mapper.input_bool_len() + 7) / 8,
  237.         //     board_bools.storage().len()
  238.         // );
  239.  
  240.         // check that everything makes sense
  241.         let policy_len = policy_indices.len();
  242.         assert_eq!(policy_len, policy_values.len());
  243.         // assert_normalized_or_nan(scalars.zero_values.wdl.sum());
  244.         // assert_normalized_or_nan(scalars.net_values.wdl.sum());
  245.         // assert_normalized_or_nan(scalars.final_values.wdl.sum());
  246.         if policy_len != 0 {
  247.             assert_normalized_or_nan(policy_values.iter().sum());
  248.         }
  249.  
  250.         // save current offset
  251.         // we keep track of the offset ourselves because seeking/stream_position flushes the buffer and is slow
  252.         debug_assert_eq!(self.next_offset, self.bin_write.stream_position()?);
  253.         self.off_write.write_all(&self.next_offset.to_le_bytes())?;
  254.  
  255.         // actually write stuff to the bin file
  256.         let scalars = scalars.to_vec();
  257.         let data_to_write: &[&[u8]] = &[
  258.             cast_slice(&scalars),
  259.             cast_slice(&board_bools),
  260.             cast_slice(&board_scalars),
  261.             cast_slice(policy_indices),
  262.             cast_slice(policy_values),
  263.         ];
  264.         for &data in data_to_write {
  265.             self.bin_write.write_all(data)?;
  266.             self.next_offset += data.len() as u64;
  267.         }
  268.  
  269.         Ok(())
  270.     }
  271.  
  272.     pub fn finish(&mut self) -> io::Result<()> {
  273.         if self.finished {
  274.             panic!("This output is already finished")
  275.         }
  276.         self.finished = true;
  277.  
  278.         let meta = MetaData {
  279.             game: &self.game,
  280.             scalar_names: Scalars::NAMES,
  281.             input_bool_shape: &[13, 8, 8], // pieces + EP
  282.             input_scalar_count: 8, // move turn, counters, castling
  283.             policy_shape: &[1880 as usize],
  284.             game_count: self.game_count,
  285.             position_count: self.position_count,
  286.             includes_terminal_positions: true,
  287.             includes_game_start_indices: true,
  288.             max_game_length: self.max_game_length.unwrap_or(-1),
  289.             min_game_length: self.min_game_length.unwrap_or(-1),
  290.             // root_wdl: (self.total_root_wdl.cast::<f32>() / self.game_count as f32).to_slice(),
  291.             root_wdl: [0.0, 0.0, 0.0],
  292.             hit_move_limit: self.hit_move_limit_count as f32 / self.game_count as f32,
  293.         };
  294.  
  295.         serde_json::to_writer_pretty(&mut self.json_tmp_write, &meta)?;
  296.         self.off_write
  297.             .write_all(cast_slice(&self.game_start_indices))?;
  298.  
  299.         self.json_tmp_write.flush()?;
  300.         self.bin_write.flush()?;
  301.         self.off_write.flush()?;
  302.  
  303.         let path_json_tmp = self.path.with_extension("json.tmp");
  304.         let path_json = self.path.with_extension("json");
  305.         std::fs::rename(path_json_tmp, path_json)?;
  306.  
  307.         Ok(())
  308.     }
  309.  
  310.     pub fn game_count(&self) -> usize {
  311.         self.game_count
  312.     }
  313. }
  314.  
  315. fn collect_policy_indices(board: &BoardStack) -> (usize, Vec<u32>) {
  316.     match board.status() {
  317.         GameStatus::Ongoing => {
  318.             let mut policy_indices: Vec<u32> = vec![];
  319.  
  320.             let mut move_list = Vec::new();
  321.             board.board().generate_moves(|moves| {
  322.                 // Unpack dense move set into move list
  323.                 move_list.extend(moves);
  324.                 false
  325.             });
  326.  
  327.             for m in move_list.clone() {
  328.                 policy_indices.push(move_list.iter().position(|&x| x == m).unwrap() as u32);
  329.             }
  330.  
  331.             (policy_indices.len(), policy_indices)
  332.         }
  333.         GameStatus::Drawn | GameStatus::Won => (0, vec![]),
  334.     }
  335. }
  336.  
  337. fn assert_normalized_or_nan(x: f32) {
  338.     assert!(x.is_nan() || (1.0 - x).abs() < 0.001);
  339. }
  340.  
  341. impl Scalars {
  342.     const NAMES: &'static [&'static str] = &[
  343.         "game_id",
  344.         "pos_index",
  345.         "game_length",
  346.         "zero_visits",
  347.         "is_full_search",
  348.         "is_final_position",
  349.         "is_terminal",
  350.         "hit_move_limit",
  351.         "available_mv_count",
  352.         "played_mv",
  353.         "kdl_policy",
  354.         "final_v",
  355.         "final_wdl_w",
  356.         "final_wdl_d",
  357.         "final_wdl_l",
  358.         "final_moves_left",
  359.         "zero_v",
  360.         "zero_wdl_w",
  361.         "zero_wdl_d",
  362.         "zero_wdl_l",
  363.         "zero_moves_left",
  364.         "net_v",
  365.         "net_wdl_w",
  366.         "net_wdl_d",
  367.         "net_wdl_l",
  368.         "net_moves_left",
  369.     ];
  370.  
  371.     fn to_vec(&self) -> Vec<f32> {
  372.         let mut result = vec![
  373.             self.game_id as f32,
  374.             self.pos_index as f32,
  375.             self.game_length as f32,
  376.             self.zero_visits as f32,
  377.             self.is_full_search as u8 as f32,
  378.             self.is_final_position as u8 as f32,
  379.             self.is_terminal as u8 as f32,
  380.             self.hit_move_limit as u8 as f32,
  381.             self.available_mv_count as f32,
  382.             self.played_mv as f32,
  383.             self.kdl_policy as f32,
  384.         ];
  385.  
  386.         result.extend_from_slice(&[self.final_values]);
  387.         result.extend_from_slice(&[self.zero_values]);
  388.         result.extend_from_slice(&[self.net_values]);
  389.  
  390.         assert_eq!(result.len(), Self::NAMES.len());
  391.         result
  392.     }
  393. }
  394.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement