Advertisement
minh_tran_782

PPL TYPE 5

Apr 3rd, 2023
991
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.03 KB | None | 0 0
  1. class Type(ABC): pass
  2. class IntType(Type): pass
  3. class FloatType(Type): pass
  4. class BoolType(Type): pass
  5. class NoType(Type): pass
  6. class FunctionType(Type): pass
  7. class Symbol:
  8.     def __init__(self, name, typ):
  9.         self.name = name
  10.         if type(typ) is list:
  11.             self.typ = typ
  12.             return
  13.         self.typ = [typ]
  14. class Utils:
  15.     def infer(env, name, typ):
  16.         for symbol in env:
  17.                 if symbol.name == name:
  18.                     symbol.typ = [typ]
  19.                     return typ
  20. class GetType:
  21.     def checkType(s):
  22.         if s == "IntType":
  23.             return IntType()
  24.         if s == "FloatType":
  25.             return FloatType()
  26.         if s == "BoolType":
  27.             return BoolType()
  28.         return NoType()
  29.  
  30. class UtilsInMoreEnvironments:
  31.     def infer(env, name, typ):
  32.         for symbol_list in env:
  33.             for symbol in symbol_list:
  34.                 if symbol.name == name:
  35.                     symbol.typ = [typ]
  36.                     return typ
  37. class GetEnv(Visitor):
  38.     def visitProgram(self,ctx:Program,o):
  39.         o = []
  40.         for decl in ctx.decl:
  41.             o = self.visit(decl,o)
  42.         return o
  43.     def visitVarDecl(self,ctx:VarDecl,o):
  44.         for x in o:
  45.             if ctx.name == x.name:
  46.                 raise Redeclared(ctx)
  47.         o += [Symbol(ctx.name, NoType())]
  48.         return o        
  49.    
  50.     def visitFuncDecl(self,ctx:FuncDecl,o):
  51.         for x in o:
  52.             if ctx.name == x.name:
  53.                 raise Redeclared(ctx)
  54.         env = []
  55.         for decl in ctx.param:
  56.             env = self.visit(decl,env)
  57.         for decl in ctx.local:
  58.             env = self.visit(decl,env)
  59.         o += [Symbol(ctx.name,[FunctionType(),env])]
  60.         return o
  61. class StaticCheck(Visitor):
  62.     def visitProgram(self,ctx:Program,o):
  63.         env = GetEnv().visit(ctx,o)
  64.         for decl in ctx.decl:
  65.             if type(decl) is FuncDecl:
  66.                 self.visit(decl, env)
  67.         for stmt in ctx.stmts:
  68.             self.visit(stmt, env)
  69. # class FuncDecl(Decl): #name:str,param:List[VarDecl],local:List[Decl],stmts:List[Stmt]
  70.     def visitFuncDecl(self,ctx:FuncDecl,o):
  71.         for x in o:
  72.             if ctx.name == x.name:
  73.                 if type(x.typ[0]) is not FunctionType: return
  74.                 for stmt in ctx.stmts:
  75.                     if type(stmt) is CallStmt:
  76.                         self.visit(stmt,o)
  77.                     if type(stmt) is Assign:
  78.                         self.visit(stmt,x.typ[1])
  79.                 return
  80.         return
  81.     def visitAssign(self,ctx:Assign,o):
  82.         rtype = self.visit(ctx.rhs, o)
  83.         ltype = self.visit(ctx.lhs, o)
  84.        
  85.         if type(rtype) is NoType and type(ltype) is NoType:
  86.             raise TypeCannotBeInferred(ctx)
  87.            
  88.         if type(ltype) is NoType:
  89.             Utils.infer(o, ctx.lhs.name, rtype)
  90.             return
  91.            
  92.         if type(rtype) is NoType:
  93.             Utils.infer(o, ctx.rhs.name, ltype)
  94.             return
  95.            
  96.         if type(ltype) is type(rtype): return
  97.         raise TypeMismatchInStatement(ctx)
  98. # class CallStmt(Stmt): #name:str,args:List[Exp]
  99.     def visitCallStmt(self,ctx:CallStmt,o):
  100.         #args la id: No o dau ? Global or Function Scope.
  101.         #args ko co type. Chi co Literal -> Type hoac Id -> Type.
  102.         args = ctx.args
  103.         for x in o:
  104.             if ctx.name == x.name:
  105.                 if type(x.typ[0]) is not FunctionType:
  106.                     raise UndeclaredIdentifier (ctx.name)
  107.                 else:
  108.                     if len(args) != len(x.typ[1]):
  109.                         raise TypeMismatchInStatement (ctx)
  110.                     for i in range(0,len(args)):
  111.                     # x.typ[1] is Function Enviroment
  112.                     # o is global environment, which contains Function
  113.                     # up
  114.                         env = o
  115.                         type_of_args = type(self.visit(args[i],env))
  116.                         type_of_param = type(x.typ[1][i].typ[0])
  117.                         # raise UndeclaredIdentifier (str(type_of_args))
  118.                         if type_of_args is NoType and type_of_param is NoType:
  119.                             raise TypeCannotBeInferred (ctx)
  120.                            
  121.                         if type_of_args is NoType and type_of_param is not NoType:
  122.                            
  123.                             UtilsInMoreEnvironments.infer([x.typ[1],o],args[i].name,type_of_param)
  124.                        
  125.                         if type_of_args is not NoType and type_of_param is NoType:
  126.                            
  127.                             Utils.infer(x.typ[1],x.typ[1][i].name,type_of_args)
  128.                        
  129.                         if type_of_args != type_of_param:
  130.                            
  131.                             raise TypeMismatchInStatement (ctx)
  132.                     return
  133.         raise UndeclaredIdentifier(ctx.name)
  134. # Case 1: FuncVar is NoType, while Call variable Type -> Suy dien for CallVariable
  135.        
  136. # Case 2: FuncVar have Type, call var have type -> Check if they are same type
  137. # Case 3: Both is no type -> TypeCannotBeInferred
  138.  
  139.     def visitBinOp(self,ctx:BinOp,o):
  140.         e1t = self.visit(ctx.e1, o)
  141.         e2t = self.visit(ctx.e2, o)
  142.        
  143.         if ctx.op in ['+', '-', '*', '/']:
  144.             if type(e1t) is NoType:
  145.                 e1t = Utils.infer(o, ctx.e1.name, IntType())
  146.             if type(e2t) is NoType:
  147.                 e2t = Utils.infer(o, ctx.e2.name, IntType())
  148.             if type(e1t) is IntType and type(e2t) is IntType:
  149.                 return IntType()
  150.             raise TypeMismatchInExpression(ctx)
  151.            
  152.         if ctx.op in ['+.', '-.', '*.', '/.']:
  153.             if type(e1t) is NoType:
  154.                 e1t = Utils.infer(o, ctx.e1.name, FloatType())
  155.             if type(e2t) is NoType:
  156.                 e2t = Utils.infer(o, ctx.e2.name, FloatType())
  157.             if type(e2t) is NoType:
  158.                 e1t = Utils.infer(o, ctx.e1.name, FloatType())
  159.                 e2t = Utils.infer(o, ctx.e2.name, FloatType())
  160.             if type(e1t) is FloatType and type(e2t) is FloatType:
  161.                 return FloatType()
  162.             raise TypeMismatchInExpression(ctx)
  163.         if ctx.op in ['>', '=']:
  164.             if type(e1t) is NoType:
  165.                 e1t = Utils.infer(o, ctx.e1.name, IntType())
  166.             if type(e2t) is NoType and type(e1t) is IntType:
  167.                 e2t = Utils.infer(o, ctx.e2.name, IntType())
  168.             if type(e1t) is IntType and type(e2t) is IntType:
  169.                 return BoolType()
  170.         if ctx.op in ['>.', '=.']:
  171.             if type(e1t) is NoType:
  172.                 e1t = Utils.infer(o, ctx.e1.name, FloatType())
  173.             if type(e2t) is NoType:
  174.                 e2t = Utils.infer(o, ctx.e2.name, FloatType())
  175.             if type(e1t) is FloatType and type(e2t) is FloatType:
  176.                 return BoolType()
  177.             raise TypeMismatchInExpression(ctx)
  178.         if ctx.op in ['&&', '||', '>b' , '=b']:
  179.             if type(e1t) is NoType:
  180.                 e1t = Utils.infer(o, ctx.e1.name, BoolType())
  181.             if type(e2t) is NoType:
  182.                 e2t = Utils.infer(o, ctx.e2.name, BoolType())
  183.             if type(e1t) is BoolType and type(e2t) is BoolType:
  184.                 return BoolType()
  185.             raise TypeMismatchInExpression(ctx)
  186.  
  187.     def visitUnOp(self,ctx:UnOp,o):
  188.         e1t = self.visit(ctx.e, o)
  189.         if ctx.op in ['-']:
  190.             if type(e1t) is NoType:
  191.                 e1t = Utils.infer(o, ctx.e.name, IntType())
  192.             if type(e1t) is IntType:
  193.                 return IntType()
  194.             raise TypeMismatchInExpression(ctx)
  195.  
  196.         if ctx.op in ['-.']:
  197.             if type(e1t) is NoType:
  198.                 e1t = Utils.infer(o, ctx.e.name, FloatType())
  199.             if type(e1t) is FloatType:
  200.                 return FloatType()
  201.             raise TypeMismatchInExpression(ctx)
  202.  
  203.         if ctx.op in ['i2f']:
  204.             if type(e1t) is NoType:
  205.                 e1t = Utils.infer(o, ctx.e.name, IntType())
  206.             if type(e1t) == IntType:
  207.                 return FloatType()
  208.             raise TypeMismatchInExpression(ctx)
  209.  
  210.         if ctx.op in ['!']:
  211.             if type(e1t) is NoType:
  212.                 e1t = Utils.infer(o, ctx.e.name, BoolType())
  213.             if type(e1t) == BoolType:
  214.                 return BoolType()
  215.             raise TypeMismatchInExpression(ctx)
  216.    
  217.         if ctx.op in ['floor']:
  218.             if type(e1t) is NoType:
  219.                 e1t = Utils.infer(o, ctx.e.name, FloatType())
  220.             if type(e1t) is FloatType:
  221.                 return IntType()
  222.             raise TypeMismatchInExpression(ctx)
  223.    
  224.     def visitIntLit(self,ctx:IntLit,o): return IntType()
  225.  
  226.     def visitFloatLit(self,ctx,o): return FloatType()
  227.  
  228.     def visitBoolLit(self,ctx,o): return BoolType()
  229.  
  230.     def visitId(self,ctx,o):
  231.         for symbol in o:
  232.                 if ctx.name == symbol.name:
  233.                     return symbol.typ[0]
  234.         raise UndeclaredIdentifier(ctx.name)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement