Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include "ASTNode.h"

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <iostream>
#include <utility>
- SymbolTable ASTNode::globalTable = SymbolTable();
- const std::shared_ptr<Variable> ASTNode::nullVarPtr = std::make_shared<Variable>(null);
- ASTNode::ASTNode(ASTNodeType type, std::string name): type(type), name(std::move(name)) {}
// Member-wise copy constructor.
// NOTE(review): this copies the `children` pointer list, not the nodes it points to,
// and it also copies `parent`. ~ASTNode() deletes every pointer in `children`, so
// destroying both the original and a copy deletes the same child nodes twice.
// Either deep-copy the subtree here or delete this constructor. TODO confirm whether
// any caller actually copies ASTNodes.
ASTNode::ASTNode(const ASTNode &other) = default;
- ASTNode::~ASTNode() {
- for (ASTNode* child: children) {
- delete child;
- }
- }
- std::shared_ptr<Variable> ASTNode::eval() {
- ASTNodeCleaner cleaner = ASTNodeCleaner(this);
- switch (type) {
- case ROOT:
- case PROG:
- case STMT:
- for (ASTNode* child : children) {
- if (child->type == RET) {
- return child->eval();
- } else {
- child->eval();
- }
- }
- break;
- case WHILE:
- {
- ASTNode* condition = children.at(0);
- ASTNode* block = children.at(1);
- // I have no idea why this works without making it a pointer to the std::any value like I do everywhere else.
- // Must just be black magic.
- while (std::any_cast<bool>(condition->eval()->getValue())) {
- block->eval();
- }
- }
- break;
- case FOR:
- {
- ASTNode* init = children.at(0);
- ASTNode* condition = children.at(1);
- ASTNode* iter = children.at(2);
- ASTNode* block = children.at(3);
- init->eval();
- while (true) {
- if (std::any_cast<bool>(condition->eval()->getValue())) {
- block->eval();
- iter->eval();
- } else break;
- }
- }
- break;
- case IF:
- {
- auto condition = children.front()->eval();
- if (std::any_cast<bool>(condition->getValue())) {
- if (children.size() > 1) children.at(1)->eval();
- return std::make_shared<Variable>(boole, true);
- } else {
- bool metTrueCondition = false;
- for (int i = 1; i < children.size(); i++) {
- ASTNode* child = children.at(i);
- if (child->type == ELIF) {
- std::shared_ptr<Variable> childCondition = child->eval();
- if (std::any_cast<bool>(childCondition->getValue())) {
- metTrueCondition = true;
- break;
- }
- }
- }
- if (!metTrueCondition) {
- for (int i = 2; i < children.size(); i++) {
- ASTNode* child = children.at(i);
- if (child->type == STMT) {
- child->eval();
- break;
- }
- }
- }
- }
- }
- break;
- case ELIF:
- {
- auto condition = children.front()->eval();
- if (std::any_cast<bool>(condition->getValue())) {
- if (children.size() > 1) children.at(1)->eval();
- return std::make_shared<Variable>(boole, true);
- }
- }
- break;
- case CALL:
- if (name == "io.print") {
- for (ASTNode* child : children) {
- printf(child->eval()->toString().c_str());
- }
- } else if (name == "io.println") {
- for (ASTNode* child : children) {
- printf(child->eval()->toString().c_str());
- }
- printf("\n");
- } else if (name == "sys.clock") {
- long long ms = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
- return std::make_shared<Variable>(i64, ms);
- }
- break;
- case LKP:
- if (globalTable.hasSymbol(name)) {
- return std::any_cast<std::shared_ptr<Variable>>(globalTable.getSymbol(name));
- } else {
- return std::any_cast<std::shared_ptr<Variable>>(getSymbol(name));
- }
- break;
- case OP:
- {
- auto left = children.at(0)->eval();
- if (name == "+") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->add(*right));
- } else if (name == "-") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->sub(*right));
- } else if (name == "*") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->mul(*right));
- } else if (name == "/") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->div(*right));
- } else if (name == "%") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->mod(*right));
- } else if (name == "==") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->equals(*right));
- } else if (name == "!=") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->notEquals(*right));
- } else if (name == "=") {
- auto right = children.at(1)->eval();
- left->setValue(right->getValue());
- } else if (name == "&&") {
- auto right = children.at(1)->eval();
- auto leftBool = std::any_cast<bool>(left->getValue());
- auto rightBool = std::any_cast<bool>(right->getValue());
- return std::make_shared<Variable>(boole, leftBool && rightBool);
- } else if (name == "||") {
- auto right = children.at(1)->eval();
- auto leftBool = std::any_cast<bool>(left->getValue());
- auto rightBool = std::any_cast<bool>(right->getValue());
- return std::make_shared<Variable>(boole, leftBool || rightBool);
- } else if (name == ">") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->greater(*right));
- } else if (name == "<") {
- auto right = children.at(1)->eval();
- return std::make_shared<Variable>(left->less(*right));
- } else if (name == "!") {
- return std::make_shared<Variable>(boole, !std::any_cast<bool>(left->getValue()));
- }
- break;
- }
- case DECL:
- {
- VariableType type = children.front()->eval()->getType();
- std::shared_ptr<Variable> newVarPtr = std::make_shared<Variable>(type);
- ASTNode* container = getScope();
- container->localTable.addSymbol(name, newVarPtr);
- return newVarPtr;
- }
- case TINF:
- return std::make_shared<Variable>(Variable::resolveTypeString(name));
- break;
- case RET:
- return children.front()->eval();
- break;
- }
- return nullVarPtr;
- }
- void ASTNode::addChild(ASTNode* node) {
- children.push_back(node);
- node->parent = this;
- }
- ASTNode* ASTNode::getParent() {
- return parent;
- }
- std::vector<ASTNode*> ASTNode::getChildren() {
- return children;
- }
/**
 * Debug dump of this subtree to std::cout.
 *
 * Prints one line per node in the form "<name>, <TYPE>, <child count>, <node address>".
 * Before each child is printed, an ASCII connector built from "| ", "|---" and "-----"
 * runs is written so the output reads as a tree.
 *
 * @param depth     nesting level of this node; children are printed with depth + 1.
 * @param continues depth values at which a "|" guide is drawn. It is passed by value,
 *                  so each recursive call works on its own copy. This node's depth is
 *                  appended, then removed again just before descending into the last
 *                  child.
 */
void ASTNode::printRecursive(int depth, std::vector<int> continues) {
    // Printable label for the node kind (stays empty for any kind not listed).
    std::string typeName;
    switch (type) {
        case ROOT:
            typeName = "ROOT";
            break;
        case PROG:
            typeName = "PROG";
            break;
        case STMT:
            typeName = "STMT";
            break;
        case FOR:
            typeName = "FOR";
            break;
        case WHILE:
            typeName = "WHILE";
            break;
        case IF:
            typeName = "IF";
            break;
        case ELIF:
            typeName = "ELIF";
            break;
        case CALL:
            typeName = "CALL";
            break;
        case LKP:
            typeName = "LKP";
            break;
        case DECL:
            typeName = "DECL";
            break;
        case TINF:
            typeName = "TINF";
            break;
        case DESTR:
            typeName = "DESTR";
            break;
        case OP:
            typeName = "OP";
            break;
        case RET:
            typeName = "RET";
            break;
    }
    std::cout << name << ", " << typeName << ", " << children.size() << ", " << this << std::endl;
    continues.push_back(depth);
    for (ASTNode* child : children) {
        // Column (level index) at which the connector branches off the guides.
        int startPoint = 0;
        // Deepest level that still has a guide; the connector branches there.
        int highest = 0;
        for (int i : continues) if (i > highest) highest = i;
        for (int i = 0; i < depth; i++) {
            if (std::find(continues.begin(), continues.end(), i + 1) != continues.end()) {
                if (i + 1 != highest) {
                    std::cout << "| ";
                } else {
                    std::cout << "|---";
                    if (i + 1 == depth) {
                        std::cout << ">";
                    } else {
                        std::cout << "-";
                    }
                    startPoint = i;
                    break;
                }
            } else {
                std::cout << " ";
            }
        }
        // Extend the horizontal connector from the branch point out to this depth.
        for (int i = startPoint; i < depth - 1; i++) {
            std::cout << "-----";
        }
        if (startPoint + 1 != depth) std::cout << "---->";
        // Last child: stop drawing this level's guide before descending into it.
        if (child == children.back()) {
            auto iterator = std::find(continues.begin(), continues.end(), depth);
            if (iterator != continues.end()) continues.erase(iterator);
        }
        child->printRecursive(depth + 1, continues);
    }
}
- std::any ASTNode::getSymbol(std::string name) {
- if (localTable.hasSymbol(name)) {
- return localTable.getSymbol(name);
- } else if (type != ROOT && parent != nullptr) {
- return parent->getSymbol(name);
- } else {
- return nullptr;
- }
- }
- void ASTNode::cleanLocalTable() {
- localTable.deleteSymbols();
- }
- ASTNode *ASTNode::getScope() {
- if (type <= ELIF) {
- return this;
- } else if (type != ROOT) {
- return parent->getScope();
- } else {
- return nullptr;
- }
- }
- ASTNodeCleaner::ASTNodeCleaner(ASTNode *target): target(target) {}
- ASTNodeCleaner::~ASTNodeCleaner() {
- if (target->type <= ELIF) {
- target->cleanLocalTable();
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement