Shrooms

BIDIRECTIONAL TIME MACHINE

Dec 9th, 2025 (edited)
8
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 15.34 KB | None | 0 0
  1. #include <iostream>
  2. #include <fstream>
  3. #include <sstream>
  4. #include <unordered_map>
  5. #include <map>
  6. #include <vector>
  7. #include <string>
  8. #include <algorithm>
  9. #include <random>
  10. #include <cctype>
  11. #include <stdexcept>
  12. #include <limits>
  13. #include <cmath> // for log, exp
  14.  
  15. // ANDREW JUSTIN SOLESA
  16.  
  17. // =======================
  18. // Utility: String Helpers
  19. // =======================
  20.  
  21. std::string toLower(const std::string& s) {
  22. std::string out = s;
  23. std::transform(out.begin(), out.end(), out.begin(),
  24. [](unsigned char c) { return std::tolower(c); });
  25. return out;
  26. }
  27.  
  28. // Strip common punctuation at start/end of a token
  29. std::string stripPunct(const std::string& s) {
  30. size_t start = 0;
  31. size_t end = s.size();
  32.  
  33. while (start < end && std::ispunct(static_cast<unsigned char>(s[start])) &&
  34. s[start] != '\'' && s[start] != '-') {
  35. ++start;
  36. }
  37. while (end > start && std::ispunct(static_cast<unsigned char>(s[end - 1])) &&
  38. s[end - 1] != '\'' && s[end - 1] != '-') {
  39. --end;
  40. }
  41. if (start >= end) return "";
  42. return s.substr(start, end - start);
  43. }
  44.  
  45. std::vector<std::string> tokenize(const std::string& text) {
  46. std::istringstream iss(text);
  47. std::vector<std::string> tokens;
  48. std::string token;
  49. while (iss >> token) {
  50. token = stripPunct(token);
  51. if (!token.empty()) {
  52. tokens.push_back(token);
  53. }
  54. }
  55. return tokens;
  56. }
  57.  
  58. std::string joinTokens(const std::vector<std::string>& v) {
  59. std::string out;
  60. for (size_t i = 0; i < v.size(); ++i) {
  61. if (i > 0) out += " ";
  62. out += v[i];
  63. }
  64. return out;
  65. }
  66.  
  67. // Reverse the token order in a text
  68. std::string reverseTextTokens(const std::string& text) {
  69. auto tokens = tokenize(text);
  70. std::reverse(tokens.begin(), tokens.end());
  71. return joinTokens(tokens);
  72. }
  73.  
  74. // =======================
  75. // Trigram Language Model
  76. // (forward or backward)
  77. // =======================
  78.  
// Count-based trigram language model with deterministic (greedy) and
// count-weighted random generation. Predictions back off
// trigram -> bigram -> unigram. Train it on normal text for a forward
// ("future") model, or on token-reversed text for a backward ("past")
// model.
class TrigramLanguageModel {
public:
    // Build unigram/bigram/trigram counts from `text`. Tokens are
    // lower-cased via normalize() and the sequence is prefixed with a
    // <BOS> marker so single-word contexts still have a left neighbor.
    void train(const std::string& text) {
        auto tokens = tokenize(text);
        if (tokens.empty()) return;

        const std::string BOS = "<BOS>";
        std::vector<std::string> seq;
        seq.reserve(tokens.size() + 1);
        seq.push_back(BOS);
        for (auto& t : tokens) {
            seq.push_back(normalize(t));
        }

        if (seq.size() < 2) return;

        for (size_t i = 0; i < seq.size(); ++i) {
            const std::string& w = seq[i];
            if (w.empty()) continue;
            unigramCounts_[w] += 1;
            totalUnigrams_ += 1;

            // bigram: w followed by seq[i+1]
            if (i + 1 < seq.size()) {
                const std::string& w2 = seq[i + 1];
                if (!w2.empty()) {
                    bigramCounts_[w][w2] += 1;
                }
            }

            // trigram: context (w, seq[i+1]) followed by seq[i+2]
            if (i + 2 < seq.size()) {
                const std::string& w2 = seq[i + 1];
                const std::string& w3 = seq[i + 2];
                if (!w2.empty() && !w3.empty()) {
                    std::string key = trigramKey(w, w2);
                    trigramCounts_[key][w3] += 1;
                }
            }
        }
        trained_ = true;
    }

    // True once train() has processed at least one token.
    bool isTrained() const { return trained_; }

    // ========= Prediction (for generation) =========

    // Deterministic prediction with backoff: trigram -> bigram -> unigram
    // Returns "" only when the model has no usable counts at all.
    std::string predictNextGreedy(const std::string& w1,
                                  const std::string& w2) const {
        std::string key = trigramKey(normalize(w1), normalize(w2));
        std::string next = bestFromMap(trigramCounts_, key);
        if (!next.empty()) return next;

        // backoff to bigram using w2 as context
        next = bestFromMap(bigramCounts_, normalize(w2));
        if (!next.empty()) return next;

        // backoff to most common unigram
        return bestUnigram();
    }

    // Random, weighted by counts with backoff
    std::string predictNextRandom(const std::string& w1,
                                  const std::string& w2,
                                  std::mt19937& rng) const {
        std::string key = trigramKey(normalize(w1), normalize(w2));
        std::string next = randomFromMap(trigramCounts_, key, rng);
        if (!next.empty()) return next;

        next = randomFromMap(bigramCounts_, normalize(w2), rng);
        if (!next.empty()) return next;

        return randomUnigram(rng);
    }

    // Generate deterministically (forward direction).
    // Echoes the (re-joined) prompt, then appends up to `numWords`
    // greedy predictions.
    std::string generateDeterministic(const std::string& prompt,
                                      size_t numWords) const {
        if (!trained_) return "[ERROR: model not trained]";
        if (numWords == 0) return prompt;

        auto promptTokens = tokenize(prompt);
        std::string result = joinTokens(promptTokens);

        // choose context: the last two prompt tokens when available,
        // padding with <BOS> for shorter prompts
        std::string w1, w2;
        if (promptTokens.size() >= 2) {
            w1 = normalize(promptTokens[promptTokens.size() - 2]);
            w2 = normalize(promptTokens[promptTokens.size() - 1]);
        } else if (promptTokens.size() == 1) {
            const std::string BOS = "<BOS>";
            w1 = BOS;
            w2 = normalize(promptTokens.back());
        } else {
            // no prompt, start from BOS + best unigram
            const std::string BOS = "<BOS>";
            w1 = BOS;
            w2 = bestUnigram();
            if (!result.empty()) result += " ";
            result += w2;
        }

        for (size_t i = 0; i < numWords; ++i) {
            std::string next = predictNextGreedy(w1, w2);
            if (next.empty()) break;
            result += " " + next;
            // slide the bigram context window forward
            w1 = w2;
            w2 = next;
        }

        return result;
    }

    // Generate randomly (forward direction). `seed` defaults to a fresh
    // random_device draw each call; pass an explicit seed for
    // reproducible output.
    std::string generateRandom(const std::string& prompt,
                               size_t numWords,
                               unsigned int seed = std::random_device{}()) const {
        if (!trained_) return "[ERROR: model not trained]";
        if (numWords == 0) return prompt;

        std::mt19937 rng(seed);
        auto promptTokens = tokenize(prompt);
        std::string result = joinTokens(promptTokens);

        // same context selection as generateDeterministic
        std::string w1, w2;
        if (promptTokens.size() >= 2) {
            w1 = normalize(promptTokens[promptTokens.size() - 2]);
            w2 = normalize(promptTokens[promptTokens.size() - 1]);
        } else if (promptTokens.size() == 1) {
            const std::string BOS = "<BOS>";
            w1 = BOS;
            w2 = normalize(promptTokens.back());
        } else {
            const std::string BOS = "<BOS>";
            w1 = BOS;
            w2 = randomUnigram(rng);
            if (!result.empty()) result += " ";
            result += w2;
        }

        for (size_t i = 0; i < numWords; ++i) {
            std::string next = predictNextRandom(w1, w2, rng);
            if (next.empty()) break;
            result += " " + next;
            w1 = w2;
            w2 = next;
        }

        return result;
    }

private:
    // Canonical token form used for all counting and lookups.
    std::string normalize(const std::string& w) const {
        return toLower(w);
    }

    // Join a word pair into a single flat map key.
    std::string trigramKey(const std::string& w1, const std::string& w2) const {
        // Use a separator that won't appear in normal text
        static const char SEP = '\n';
        return w1 + SEP + w2;
    }

    // Highest-count continuation for `key`, or "" if `key` is unseen.
    // Ties resolve to the first entry in the inner map's iteration
    // order (lexicographic, since the inner maps are std::map).
    template<typename MapType>
    std::string bestFromMap(
        const std::unordered_map<std::string, MapType>& outer,
        const std::string& key) const {

        auto it = outer.find(key);
        if (it == outer.end()) return "";

        const MapType& inner = it->second;
        int best = -1;
        std::string bestWord;
        for (const auto& kv : inner) {
            if (kv.second > best) {
                best = kv.second;
                bestWord = kv.first;
            }
        }
        return bestWord;
    }

    // Count-weighted random continuation for `key`, or "" if unseen.
    template<typename MapType>
    std::string randomFromMap(
        const std::unordered_map<std::string, MapType>& outer,
        const std::string& key,
        std::mt19937& rng) const {

        auto it = outer.find(key);
        if (it == outer.end()) return "";

        const MapType& inner = it->second;
        int total = 0;
        for (const auto& kv : inner) {
            total += kv.second;
        }
        if (total <= 0) return "";

        // Draw r in [1, total] and walk the counts until it is used up.
        std::uniform_int_distribution<int> dist(1, total);
        int r = dist(rng);
        for (const auto& kv : inner) {
            r -= kv.second;
            if (r <= 0) {
                return kv.first;
            }
        }
        return "";
    }

    // Most frequent real word; the <BOS> marker is excluded.
    std::string bestUnigram() const {
        if (unigramCounts_.empty()) return "";

        int best = -1;
        std::string bestWord;
        for (const auto& kv : unigramCounts_) {
            if (kv.first == "<BOS>") continue;
            if (kv.second > best) {
                best = kv.second;
                bestWord = kv.first;
            }
        }
        return bestWord;
    }

    // Count-weighted random word; the <BOS> marker is excluded.
    std::string randomUnigram(std::mt19937& rng) const {
        if (unigramCounts_.empty()) return "";

        int total = 0;
        for (const auto& kv : unigramCounts_) {
            if (kv.first == "<BOS>") continue;
            total += kv.second;
        }
        if (total <= 0) return "";

        std::uniform_int_distribution<int> dist(1, total);
        int r = dist(rng);
        for (const auto& kv : unigramCounts_) {
            if (kv.first == "<BOS>") continue;
            r -= kv.second;
            if (r <= 0) return kv.first;
        }
        return "";
    }

private:
    // unigram counts; totalUnigrams_ tracks their running sum and is
    // currently only bookkeeping (no reader in this class)
    std::unordered_map<std::string, int> unigramCounts_;
    int totalUnigrams_ = 0;

    // bigram: w1 -> (w2 -> count)
    std::unordered_map<std::string, std::map<std::string, int>> bigramCounts_;

    // trigram: (w1,w2) -> (w3 -> count)
    std::unordered_map<std::string, std::map<std::string, int>> trigramCounts_;

    bool trained_ = false;
};
  335.  
  336. // =======================
  337. // Helper: read entire file
  338. // =======================
  339.  
  340. std::string readFile(const std::string& path) {
  341. std::ifstream ifs(path);
  342. if (!ifs) {
  343. throw std::runtime_error("Failed to open file: " + path);
  344. }
  345. std::ostringstream oss;
  346. oss << ifs.rdbuf();
  347. return oss.str();
  348. }
  349.  
  350. // =======================
  351. // Past-prediction helpers
  352. // =======================
  353.  
  354. // Generate past text deterministically using a backward model.
  355. // 1. Reverse prompt tokens
  356. // 2. Generate future in reversed space
  357. // 3. Strip original prompt part
  358. // 4. Reverse new tokens back into normal order
  359. std::string generatePastDeterministic(const TrigramLanguageModel& backwardModel,
  360. const std::string& prompt,
  361. size_t numWords) {
  362. auto promptTokens = tokenize(prompt);
  363. if (promptTokens.empty() || numWords == 0) {
  364. return "";
  365. }
  366.  
  367. std::vector<std::string> reversedPromptTokens = promptTokens;
  368. std::reverse(reversedPromptTokens.begin(), reversedPromptTokens.end());
  369. std::string reversedPrompt = joinTokens(reversedPromptTokens);
  370.  
  371. std::string reversedFull = backwardModel.generateDeterministic(reversedPrompt, numWords);
  372.  
  373. auto fullTokens = tokenize(reversedFull);
  374. if (fullTokens.size() <= reversedPromptTokens.size()) {
  375. return "";
  376. }
  377.  
  378. std::vector<std::string> newTokens(
  379. fullTokens.begin() + static_cast<long>(reversedPromptTokens.size()),
  380. fullTokens.end()
  381. );
  382.  
  383. std::reverse(newTokens.begin(), newTokens.end());
  384. return joinTokens(newTokens);
  385. }
  386.  
  387. std::string generatePastRandom(const TrigramLanguageModel& backwardModel,
  388. const std::string& prompt,
  389. size_t numWords) {
  390. auto promptTokens = tokenize(prompt);
  391. if (promptTokens.empty() || numWords == 0) {
  392. return "";
  393. }
  394.  
  395. std::vector<std::string> reversedPromptTokens = promptTokens;
  396. std::reverse(reversedPromptTokens.begin(), reversedPromptTokens.end());
  397. std::string reversedPrompt = joinTokens(reversedPromptTokens);
  398.  
  399. std::string reversedFull = backwardModel.generateRandom(reversedPrompt, numWords);
  400.  
  401. auto fullTokens = tokenize(reversedFull);
  402. if (fullTokens.size() <= reversedPromptTokens.size()) {
  403. return "";
  404. }
  405.  
  406. std::vector<std::string> newTokens(
  407. fullTokens.begin() + static_cast<long>(reversedPromptTokens.size()),
  408. fullTokens.end()
  409. );
  410.  
  411. std::reverse(newTokens.begin(), newTokens.end());
  412. return joinTokens(newTokens);
  413. }
  414.  
  415. // =======
  416. // main()
  417. // =======
  418. //
  419. // Usage:
  420. // ./time_machine_trigram_bidirectional article.txt
  421. //
  422.  
  423. int main(int argc, char* argv[]) {
  424. if (argc < 2) {
  425. std::cerr << "Usage: " << argv[0] << " <article.txt>\n";
  426. return 1;
  427. }
  428.  
  429. const std::string articlePath = argv[1];
  430. std::string articleText;
  431.  
  432. try {
  433. articleText = readFile(articlePath);
  434. } catch (const std::exception& ex) {
  435. std::cerr << "Error reading article: " << ex.what() << "\n";
  436. return 1;
  437. }
  438.  
  439. // Forward model (predicts future)
  440. TrigramLanguageModel forwardModel;
  441. forwardModel.train(articleText);
  442.  
  443. if (!forwardModel.isTrained()) {
  444. std::cerr << "Forward model failed to train (article too short?).\n";
  445. return 1;
  446. }
  447.  
  448. // Backward model (predicts past)
  449. std::string reversedArticle = reverseTextTokens(articleText);
  450. TrigramLanguageModel backwardModel;
  451. backwardModel.train(reversedArticle);
  452.  
  453. if (!backwardModel.isTrained()) {
  454. std::cerr << "Backward model failed to train.\n";
  455. return 1;
  456. }
  457.  
  458. std::cout << "Time Machine Text Predictor (Bidirectional)\n";
  459. std::cout << "Trained on: " << articlePath << "\n\n";
  460.  
  461. std::cout << "Enter a prompt (some snippet from the article or similar text):\n> ";
  462. std::string prompt;
  463. std::getline(std::cin, prompt);
  464.  
  465. if (prompt.empty()) {
  466. std::cout << "Empty prompt, using default 'the'.\n";
  467. prompt = "the";
  468. }
  469.  
  470. std::size_t numWords = 20;
  471. std::cout << "How many words to predict (future/past)? [default 20]: ";
  472. {
  473. std::string line;
  474. std::getline(std::cin, line);
  475. if (!line.empty()) {
  476. try {
  477. numWords = static_cast<std::size_t>(std::stoul(line));
  478. } catch (...) {
  479. std::cout << "Invalid number, using default 20.\n";
  480. numWords = 20;
  481. }
  482. }
  483. }
  484.  
  485. std::cout << "\n=== FUTURE (forward model, deterministic) ===\n";
  486. std::string futureDet = forwardModel.generateDeterministic(prompt, numWords);
  487. std::cout << futureDet << "\n\n";
  488.  
  489. std::cout << "=== FUTURE (forward model, random) ===\n";
  490. std::string futureRnd = forwardModel.generateRandom(prompt, numWords);
  491. std::cout << futureRnd << "\n\n";
  492.  
  493. std::cout << "=== PAST (backward model, deterministic) ===\n";
  494. std::string pastDet = generatePastDeterministic(backwardModel, prompt, numWords);
  495. std::cout << pastDet << "\n\n";
  496.  
  497. std::cout << "=== PAST (backward model, random) ===\n";
  498. std::string pastRnd = generatePastRandom(backwardModel, prompt, numWords);
  499. std::cout << pastRnd << "\n";
  500.  
  501. return 0;
  502. }
  503.  
Advertisement
Add Comment
Please, Sign In to add comment