#include <iostream>
#include <fstream>
#include <sstream>
#include <unordered_map>
#include <map>
#include <vector>
#include <string>
#include <algorithm>
#include <random>
#include <cctype>
#include <stdexcept>
// ANDREW JUSTIN SOLESA
// =======================
// Utility: String Helpers
// =======================
std::string toLower(const std::string& s) {
    std::string out = s;
    std::transform(out.begin(), out.end(), out.begin(),
                   [](unsigned char c) { return std::tolower(c); });
    return out;
}
// Strip common punctuation at the start/end of a token
std::string stripPunct(const std::string& s) {
    size_t start = 0;
    size_t end = s.size();
    while (start < end && std::ispunct(static_cast<unsigned char>(s[start])) &&
           s[start] != '\'' && s[start] != '-') {
        ++start;
    }
    while (end > start && std::ispunct(static_cast<unsigned char>(s[end - 1])) &&
           s[end - 1] != '\'' && s[end - 1] != '-') {
        --end;
    }
    if (start >= end) return "";
    return s.substr(start, end - start);
}
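// Example (illustrative): stripPunct("\"hello,\"") yields "hello", while
// "don't" and "well-known" pass through unchanged because apostrophes and
// hyphens are deliberately preserved.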
std::vector<std::string> tokenize(const std::string& text) {
    std::istringstream iss(text);
    std::vector<std::string> tokens;
    std::string token;
    while (iss >> token) {
        token = stripPunct(token);
        if (!token.empty()) {
            tokens.push_back(token);
        }
    }
    return tokens;
}
std::string joinTokens(const std::vector<std::string>& v) {
    std::string out;
    for (size_t i = 0; i < v.size(); ++i) {
        if (i > 0) out += " ";
        out += v[i];
    }
    return out;
}
// Reverse the token order in a text
std::string reverseTextTokens(const std::string& text) {
    auto tokens = tokenize(text);
    std::reverse(tokens.begin(), tokens.end());
    return joinTokens(tokens);
}
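// Example (illustrative): reverseTextTokens("the quick brown fox") returns
// "fox brown quick the". Training a model on reversed text is what turns
// "predict the next word" into "predict the previous word".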
// =======================
// Trigram Language Model
// (forward or backward)
// =======================
class TrigramLanguageModel {
public:
    void train(const std::string& text) {
        auto tokens = tokenize(text);
        if (tokens.empty()) return;
        // Lowercase sentinel: normalize() leaves it unchanged, and tokenize()
        // strips '<' and '>', so no real token can collide with it.
        const std::string BOS = "<bos>";
        std::vector<std::string> seq;
        seq.reserve(tokens.size() + 1);
        seq.push_back(BOS);
        for (auto& t : tokens) {
            seq.push_back(normalize(t));
        }
        for (size_t i = 0; i < seq.size(); ++i) {
            const std::string& w = seq[i];
            if (w.empty()) continue;
            unigramCounts_[w] += 1;
            totalUnigrams_ += 1;
            if (i + 1 < seq.size()) {
                const std::string& w2 = seq[i + 1];
                if (!w2.empty()) {
                    bigramCounts_[w][w2] += 1;
                }
            }
            if (i + 2 < seq.size()) {
                const std::string& w2 = seq[i + 1];
                const std::string& w3 = seq[i + 2];
                if (!w2.empty() && !w3.empty()) {
                    std::string key = trigramKey(w, w2);
                    trigramCounts_[key][w3] += 1;
                }
            }
        }
        trained_ = true;
    }
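    // Worked example (illustrative): training on "the cat sat on the mat"
    // produces
    //   unigrams: the:2, cat:1, sat:1, on:1, mat:1 (plus <bos>:1)
    //   bigrams:  "the" -> {cat:1, mat:1}, "cat" -> {sat:1}, ...
    //   trigrams: ("the","cat") -> {sat:1}, ("cat","sat") -> {on:1}, ...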
    bool isTrained() const { return trained_; }
    // ========= Prediction (for generation) =========
    // Deterministic prediction with backoff: trigram -> bigram -> unigram
    std::string predictNextGreedy(const std::string& w1,
                                  const std::string& w2) const {
        std::string key = trigramKey(normalize(w1), normalize(w2));
        std::string next = bestFromMap(trigramCounts_, key);
        if (!next.empty()) return next;
        // backoff to bigram using w2 as context
        next = bestFromMap(bigramCounts_, normalize(w2));
        if (!next.empty()) return next;
        // backoff to most common unigram
        return bestUnigram();
    }
    // Random, weighted by counts with backoff
    std::string predictNextRandom(const std::string& w1,
                                  const std::string& w2,
                                  std::mt19937& rng) const {
        std::string key = trigramKey(normalize(w1), normalize(w2));
        std::string next = randomFromMap(trigramCounts_, key, rng);
        if (!next.empty()) return next;
        next = randomFromMap(bigramCounts_, normalize(w2), rng);
        if (!next.empty()) return next;
        return randomUnigram(rng);
    }
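    // Backoff example (illustrative): with the tiny corpus above, the context
    // ("the", "cat") hits the trigram table and yields "sat". An unseen
    // context such as ("a", "cat") misses, backs off to the bigram table for
    // "cat", and still yields "sat". If "cat" were unseen too, the most
    // frequent unigram ("the") would be returned.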
    // Generate deterministically (forward direction)
    std::string generateDeterministic(const std::string& prompt,
                                      size_t numWords) const {
        if (!trained_) return "[ERROR: model not trained]";
        if (numWords == 0) return prompt;
        auto promptTokens = tokenize(prompt);
        std::string result = joinTokens(promptTokens);
        // choose context from the last two prompt tokens
        std::string w1, w2;
        const std::string BOS = "<bos>";
        if (promptTokens.size() >= 2) {
            w1 = normalize(promptTokens[promptTokens.size() - 2]);
            w2 = normalize(promptTokens[promptTokens.size() - 1]);
        } else if (promptTokens.size() == 1) {
            w1 = BOS;
            w2 = normalize(promptTokens.back());
        } else {
            // no prompt: start from BOS + best unigram
            w1 = BOS;
            w2 = bestUnigram();
            if (!result.empty()) result += " ";
            result += w2;
        }
        for (size_t i = 0; i < numWords; ++i) {
            std::string next = predictNextGreedy(w1, w2);
            if (next.empty()) break;
            result += " " + next;
            w1 = w2;
            w2 = next;
        }
        return result;
    }
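    // Note: greedy generation is fully deterministic, so once a two-word
    // context recurs the output repeats the same continuation from then on.
    // generateRandom below avoids this by sampling proportionally to counts.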
    // Generate randomly (forward direction)
    std::string generateRandom(const std::string& prompt,
                               size_t numWords,
                               unsigned int seed = std::random_device{}()) const {
        if (!trained_) return "[ERROR: model not trained]";
        if (numWords == 0) return prompt;
        std::mt19937 rng(seed);
        auto promptTokens = tokenize(prompt);
        std::string result = joinTokens(promptTokens);
        std::string w1, w2;
        const std::string BOS = "<bos>";
        if (promptTokens.size() >= 2) {
            w1 = normalize(promptTokens[promptTokens.size() - 2]);
            w2 = normalize(promptTokens[promptTokens.size() - 1]);
        } else if (promptTokens.size() == 1) {
            w1 = BOS;
            w2 = normalize(promptTokens.back());
        } else {
            w1 = BOS;
            w2 = randomUnigram(rng);
            if (!result.empty()) result += " ";
            result += w2;
        }
        for (size_t i = 0; i < numWords; ++i) {
            std::string next = predictNextRandom(w1, w2, rng);
            if (next.empty()) break;
            result += " " + next;
            w1 = w2;
            w2 = next;
        }
        return result;
    }
private:
    std::string normalize(const std::string& w) const {
        return toLower(w);
    }
    std::string trigramKey(const std::string& w1, const std::string& w2) const {
        // tokenize() splits on whitespace, so '\n' can never appear inside a
        // token; the key is therefore unambiguous.
        static const char SEP = '\n';
        return w1 + SEP + w2;
    }
    template<typename MapType>
    std::string bestFromMap(
        const std::unordered_map<std::string, MapType>& outer,
        const std::string& key) const {
        auto it = outer.find(key);
        if (it == outer.end()) return "";
        const MapType& inner = it->second;
        int best = -1;
        std::string bestWord;
        for (const auto& kv : inner) {
            if (kv.second > best) {
                best = kv.second;
                bestWord = kv.first;
            }
        }
        return bestWord;
    }
    template<typename MapType>
    std::string randomFromMap(
        const std::unordered_map<std::string, MapType>& outer,
        const std::string& key,
        std::mt19937& rng) const {
        auto it = outer.find(key);
        if (it == outer.end()) return "";
        const MapType& inner = it->second;
        // roulette-wheel selection: draw r in [1, total], walk the counts
        int total = 0;
        for (const auto& kv : inner) {
            total += kv.second;
        }
        if (total <= 0) return "";
        std::uniform_int_distribution<int> dist(1, total);
        int r = dist(rng);
        for (const auto& kv : inner) {
            r -= kv.second;
            if (r <= 0) {
                return kv.first;
            }
        }
        return "";
    }
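    // Sampling example (illustrative): for counts {sat:3, ran:1}, total = 4
    // and r is drawn uniformly from [1, 4]; r in {1, 2, 3} selects "sat" and
    // r = 4 selects "ran", so each word is chosen with probability
    // proportional to its count.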
    std::string bestUnigram() const {
        if (unigramCounts_.empty()) return "";
        int best = -1;
        std::string bestWord;
        for (const auto& kv : unigramCounts_) {
            if (kv.first == "<bos>") continue;  // never emit the sentinel
            if (kv.second > best) {
                best = kv.second;
                bestWord = kv.first;
            }
        }
        return bestWord;
    }
    std::string randomUnigram(std::mt19937& rng) const {
        if (unigramCounts_.empty()) return "";
        int total = 0;
        for (const auto& kv : unigramCounts_) {
            if (kv.first == "<bos>") continue;
            total += kv.second;
        }
        if (total <= 0) return "";
        std::uniform_int_distribution<int> dist(1, total);
        int r = dist(rng);
        for (const auto& kv : unigramCounts_) {
            if (kv.first == "<bos>") continue;
            r -= kv.second;
            if (r <= 0) return kv.first;
        }
        return "";
    }
private:
    // unigram
    std::unordered_map<std::string, int> unigramCounts_;
    int totalUnigrams_ = 0;
    // bigram: w1 -> (w2 -> count)
    std::unordered_map<std::string, std::map<std::string, int>> bigramCounts_;
    // trigram: (w1,w2) -> (w3 -> count)
    std::unordered_map<std::string, std::map<std::string, int>> trigramCounts_;
    bool trained_ = false;
};
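// Usage sketch (illustrative, using the tiny corpus from the examples above):
//   TrigramLanguageModel model;
//   model.train("the cat sat on the mat");
//   model.generateDeterministic("the cat", 3);  // -> "the cat sat on the"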
// =======================
// Helper: read entire file
// =======================
std::string readFile(const std::string& path) {
    std::ifstream ifs(path);
    if (!ifs) {
        throw std::runtime_error("Failed to open file: " + path);
    }
    std::ostringstream oss;
    oss << ifs.rdbuf();
    return oss.str();
}
// =======================
// Past-prediction helpers
// =======================
// Generate past text using a backward model:
// 1. Reverse the prompt tokens
// 2. Generate a "future" in the reversed space (= the past in normal order)
// 3. Strip the original prompt part
// 4. Reverse the new tokens back into normal order
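// Worked example (illustrative): for the prompt "brown fox", the reversed
// prompt is "fox brown". If the backward model extends it to
// "fox brown quick the", the new tokens "quick the" are reversed back to
// "the quick": the predicted past that precedes "brown fox".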
std::string generatePastDeterministic(const TrigramLanguageModel& backwardModel,
                                      const std::string& prompt,
                                      size_t numWords) {
    auto promptTokens = tokenize(prompt);
    if (promptTokens.empty() || numWords == 0) {
        return "";
    }
    std::vector<std::string> reversedPromptTokens = promptTokens;
    std::reverse(reversedPromptTokens.begin(), reversedPromptTokens.end());
    std::string reversedPrompt = joinTokens(reversedPromptTokens);
    std::string reversedFull = backwardModel.generateDeterministic(reversedPrompt, numWords);
    auto fullTokens = tokenize(reversedFull);
    if (fullTokens.size() <= reversedPromptTokens.size()) {
        return "";
    }
    std::vector<std::string> newTokens(
        fullTokens.begin() + static_cast<long>(reversedPromptTokens.size()),
        fullTokens.end());
    std::reverse(newTokens.begin(), newTokens.end());
    return joinTokens(newTokens);
}
std::string generatePastRandom(const TrigramLanguageModel& backwardModel,
                               const std::string& prompt,
                               size_t numWords) {
    auto promptTokens = tokenize(prompt);
    if (promptTokens.empty() || numWords == 0) {
        return "";
    }
    std::vector<std::string> reversedPromptTokens = promptTokens;
    std::reverse(reversedPromptTokens.begin(), reversedPromptTokens.end());
    std::string reversedPrompt = joinTokens(reversedPromptTokens);
    std::string reversedFull = backwardModel.generateRandom(reversedPrompt, numWords);
    auto fullTokens = tokenize(reversedFull);
    if (fullTokens.size() <= reversedPromptTokens.size()) {
        return "";
    }
    std::vector<std::string> newTokens(
        fullTokens.begin() + static_cast<long>(reversedPromptTokens.size()),
        fullTokens.end());
    std::reverse(newTokens.begin(), newTokens.end());
    return joinTokens(newTokens);
}
// =======
// main()
// =======
//
// Usage:
//   ./time_machine_trigram_bidirectional article.txt
//
int main(int argc, char* argv[]) {
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " <article.txt>\n";
        return 1;
    }
    const std::string articlePath = argv[1];
    std::string articleText;
    try {
        articleText = readFile(articlePath);
    } catch (const std::exception& ex) {
        std::cerr << "Error reading article: " << ex.what() << "\n";
        return 1;
    }
    // Forward model (predicts the future)
    TrigramLanguageModel forwardModel;
    forwardModel.train(articleText);
    if (!forwardModel.isTrained()) {
        std::cerr << "Forward model failed to train (article too short?).\n";
        return 1;
    }
    // Backward model (predicts the past): trained on token-reversed text
    std::string reversedArticle = reverseTextTokens(articleText);
    TrigramLanguageModel backwardModel;
    backwardModel.train(reversedArticle);
    if (!backwardModel.isTrained()) {
        std::cerr << "Backward model failed to train.\n";
        return 1;
    }
    std::cout << "Time Machine Text Predictor (Bidirectional)\n";
    std::cout << "Trained on: " << articlePath << "\n\n";
    std::cout << "Enter a prompt (a snippet from the article or similar text):\n> ";
    std::string prompt;
    std::getline(std::cin, prompt);
    if (prompt.empty()) {
        std::cout << "Empty prompt, using default 'the'.\n";
        prompt = "the";
    }
    std::size_t numWords = 20;
    std::cout << "How many words to predict (future/past)? [default 20]: ";
    {
        std::string line;
        std::getline(std::cin, line);
        if (!line.empty()) {
            try {
                numWords = static_cast<std::size_t>(std::stoul(line));
            } catch (...) {
                std::cout << "Invalid number, using default 20.\n";
                numWords = 20;
            }
        }
    }
    std::cout << "\n=== FUTURE (forward model, deterministic) ===\n";
    std::string futureDet = forwardModel.generateDeterministic(prompt, numWords);
    std::cout << futureDet << "\n\n";
    std::cout << "=== FUTURE (forward model, random) ===\n";
    std::string futureRnd = forwardModel.generateRandom(prompt, numWords);
    std::cout << futureRnd << "\n\n";
    std::cout << "=== PAST (backward model, deterministic) ===\n";
    std::string pastDet = generatePastDeterministic(backwardModel, prompt, numWords);
    std::cout << pastDet << "\n\n";
    std::cout << "=== PAST (backward model, random) ===\n";
    std::string pastRnd = generatePastRandom(backwardModel, prompt, numWords);
    std::cout << pastRnd << "\n";
    return 0;
}
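// Example session (illustrative; actual output depends on the training text):
//   $ ./time_machine_trigram_bidirectional article.txt
//   Time Machine Text Predictor (Bidirectional)
//   Trained on: article.txt
//
//   Enter a prompt (a snippet from the article or similar text):
//   > the committee
//   How many words to predict (future/past)? [default 20]: 5
//
//   === FUTURE (forward model, deterministic) ===
//   the committee voted to approve the plan
//   ...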