From cb30ad70401ce1436b59c4cb6cf2c2dae2970ce6 Mon Sep 17 00:00:00 2001 From: Andrew Segavac Date: Wed, 12 May 2021 06:40:11 -0600 Subject: [PATCH] added type checking --- boring/main.py | 27 +++--- boring/parse.py | 89 +++++++---------- boring/type_checker.py | 133 +++++++++++++++++++------- boring/type_checking.py | 205 ++++++++++++++++++++++++++++++++++++++++ boring/typing.py | 64 ------------- 5 files changed, 355 insertions(+), 163 deletions(-) create mode 100644 boring/type_checking.py delete mode 100644 boring/typing.py diff --git a/boring/main.py b/boring/main.py index 0579543..c5db87f 100644 --- a/boring/main.py +++ b/boring/main.py @@ -1,24 +1,29 @@ import sys from typing import List from boring.parse import boring_parser, TreeToBoring, pretty_print -from boring.type_checker import TypeCheckTableBuilder, TypeComparison, check_types +from boring.type_checking import TypeChecker if __name__ == "__main__": with open(sys.argv[1]) as f: tree = boring_parser.parse(f.read()) # print(tree) result = TreeToBoring().transform(tree) + # pretty_print(result) + type_checker = TypeChecker() + while type_checker.with_module({}, result): + print('loop') + # type_checker.with_module({}, result) pretty_print(result) - tctb = TypeCheckTableBuilder() - table: List[TypeComparison] = [] - tctb.with_module({}, table, result) - for e in table: - print(e) - print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^') - check_types(table) - print('vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv') - for e in table: - print(e) + # tctb = TypeCheckTableBuilder() + # table: List[TypeComparison] = [] + # tctb.with_module({}, table, result) + # for e in table: + # print(e) + # print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^") + # check_types(table) + # print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv") + # for e in table: + # print(e) # None, Some # None skip, set diff --git a/boring/parse.py b/boring/parse.py index 5a29254..d555777 100644 --- a/boring/parse.py +++ b/boring/parse.py @@ -35,7 +35,9 @@ class Identifier: @dataclass class FunctionTypeUsage: - arguments: List["TypeUsage"] # Specified if it is a function, this is how you tell if it's a function + arguments: List[ + "TypeUsage" + ] # Specified if it is a function, this is how you tell if it's a function return_type: "TypeUsage" @@ -56,40 +58,39 @@ class Operator(enum.Enum): @dataclass class LiteralInt: - id: str value: int + type: Optional[TypeUsage] @dataclass class FunctionCall: - id: str - source: Expression + source: "Expression" arguments: List["Expression"] + type: Optional[TypeUsage] @dataclass class Operation: - id: str left: "Expression" op: Operator right: "Expression" + type: Optional[TypeUsage] @dataclass class VariableUsage: - id: str name: Identifier + type: Optional[TypeUsage] @dataclass class Expression: - id: str expression: Union[LiteralInt, FunctionCall, VariableUsage, Operation] + type: Optional[TypeUsage] @dataclass class LetStatement: - id: str variable_name: Identifier type: Optional[TypeUsage] expression: Expression @@ -100,29 +101,27 @@ Statement = Union[LetStatement, Expression] @dataclass class Block: - id: str statements: List[Statement] + type: Optional[TypeUsage] @dataclass class VariableDeclaration: - id: str name: Identifier type: TypeUsage @dataclass class Function: - id: str name: Identifier arguments: List[VariableDeclaration] block: Block return_type: TypeUsage + type: TypeUsage @dataclass class Module: - id: str functions: List[Function] @@ -196,7 +195,6 @@ boring_grammar = r""" class TreeToBoring(Transformer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.id = 0 def plus(self, p) -> Operator: return Operator.plus @@ -212,8 +210,7 @@ class TreeToBoring(Transformer): def literal_int(self, n) -> LiteralInt: (n,) = n - self.id += 1 - return LiteralInt(value=int(n), id=str(self.id)) + return LiteralInt(value=int(n), type=None) def identifier(self, i) -> Identifier: (i,) = i @@ -221,64 +218,52 @@ class TreeToBoring(Transformer): def variable_usage(self, variable) -> VariableUsage: (variable,) = variable - self.id += 1 - return VariableUsage(name=variable, id=str(self.id)) + return VariableUsage(name=variable, type=None) def function_call(self, call) -> FunctionCall: - self.id += 1 - return FunctionCall(source=call[0], arguments=call[1:], id=str(self.id)) + return FunctionCall(source=call[0], arguments=call[1:], type=None) def add_expression(self, ae) -> Operation: - self.id += 1 - return Operation(left=ae[0], op=ae[1], right=ae[2], id=str(self.id)) + return Operation(left=ae[0], op=ae[1], right=ae[2], type=None) def sub_expression(self, se) -> Operation: - self.id += 1 - return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id)) + return Operation(left=se[0], op=se[1], right=se[2], type=None) def mult_expression(self, se) -> Operation: - self.id += 1 - return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id)) + return Operation(left=se[0], op=se[1], right=se[2], type=None) def div_expression(self, se) -> Operation: - self.id += 1 - return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id)) + return Operation(left=se[0], op=se[1], right=se[2], type=None) def expression(self, exp) -> Expression: (exp,) = exp if isinstance(exp, Expression): return exp - self.id += 1 - return Expression(expression=exp, id=str(self.id)) + return Expression(expression=exp, type=None) def factor(self, factor) -> Expression: (factor,) = factor if isinstance(factor, Expression): return factor - self.id += 1 - return Expression(expression=factor, id=str(self.id)) + return Expression(expression=factor, type=None) def term(self, term) -> Expression: (term,) = term - self.id += 1 - return Expression(expression=term, id=str(self.id)) + return Expression(expression=term, type=None) def let_statement(self, let_statement) -> LetStatement: - self.id += 1 if len(let_statement) == 3: (variable_name, type_usage, expression) = let_statement return LetStatement( variable_name=variable_name, type=type_usage, expression=expression, - id=str(self.id), ) (variable_name, expression) = let_statement return LetStatement( variable_name=variable_name, type=None, expression=expression, - id=str(self.id), ) def statement(self, statement): @@ -286,23 +271,20 @@ class TreeToBoring(Transformer): return statement def block(self, block) -> Block: - self.id += 1 - return Block(statements=block, id=str(self.id)) + return Block(statements=block, type=None) def data_type(self, name) -> TypeUsage: (name,) = name - self.id += 1 return DataTypeUsage(name=name) def function_type(self, type_usage) -> TypeUsage: - self.id += 1 - return FunctionTypeUsage(arguments=type_usage, return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE))) + return FunctionTypeUsage( + arguments=type_usage, + return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), + ) def function_type_with_return(self, type_usage) -> TypeUsage: - self.id += 1 - return FunctionTypeUsage( - arguments=type_usage[0:-1], return_type=type_usage[-1] - ) + return FunctionTypeUsage(arguments=type_usage[0:-1], return_type=type_usage[-1]) def type_usage(self, type_usage): (type_usage,) = type_usage @@ -310,27 +292,29 @@ class TreeToBoring(Transformer): def variable_declaration(self, identifier) -> VariableDeclaration: (identifier, type_usage) = identifier - self.id += 1 - return VariableDeclaration(name=identifier, type=type_usage, id=str(self.id)) + return VariableDeclaration(name=identifier, type=type_usage) def function_without_return(self, function) -> Function: - self.id += 1 return Function( - id=str(self.id), name=function[0], arguments=function[1:-1], return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), block=function[-1], + type=FunctionTypeUsage( + arguments=[arg.type for arg in function[1:-1]], + return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), + ), ) def function_with_return(self, function) -> Function: - self.id += 1 return Function( - id=str(self.id), name=function[0], arguments=function[1:-2], return_type=function[-2], block=function[-1], + type=FunctionTypeUsage( + arguments=[arg.type for arg in function[1:-2]], return_type=function[-2] + ), ) def function(self, function): @@ -338,8 +322,7 @@ class TreeToBoring(Transformer): return function def module(self, functions) -> Module: - self.id += 1 - return Module(id=str(self.id), functions=functions) + return Module(functions=functions) boring_parser = Lark(boring_grammar, start="module", lexer="standard") diff --git a/boring/type_checker.py b/boring/type_checker.py index 40a5070..7853ca6 100644 --- a/boring/type_checker.py +++ b/boring/type_checker.py @@ -19,6 +19,7 @@ class FunctionCallTypeComparison: to_id: str type_usage: Optional[parse.TypeUsage] + @dataclass class FunctionArgumentTypeComparison: from_id: str @@ -27,7 +28,9 @@ class FunctionArgumentTypeComparison: type_usage: Optional[parse.TypeUsage] -TypeComparison = Union[EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison] +TypeComparison = Union[ + EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison +] Environment = Dict[str, str] @@ -42,7 +45,9 @@ class TypeCheckTableBuilder: return_type=function.return_type, ) table.append( - EqualityTypeComparison(from_id=None, to_id=function.id, type_usage=type_usage) + EqualityTypeComparison( + from_id=None, to_id=function.id, type_usage=type_usage + ) ) for function in module.functions: self.with_function(env, table, function) @@ -54,10 +59,14 @@ class TypeCheckTableBuilder: for argument in function.arguments: function_env[argument.name.name] = argument.id table.append( - EqualityTypeComparison(from_id=None, to_id=argument.id, type_usage=argument.type) + EqualityTypeComparison( + from_id=None, to_id=argument.id, type_usage=argument.type + ) ) table.append( - EqualityTypeComparison(from_id=None, to_id=function.block.id, type_usage=function.return_type) + EqualityTypeComparison( + from_id=None, to_id=function.block.id, type_usage=function.return_type + ) ) self.with_block(function_env, table, function.block) @@ -71,11 +80,19 @@ class TypeCheckTableBuilder: # if parent is type, must be expression if isinstance(block.statements[-1], parse.Expression): table.append( - EqualityTypeComparison(from_id=block.statements[-1].id, to_id=block.id, type_usage=None) + EqualityTypeComparison( + from_id=block.statements[-1].id, to_id=block.id, type_usage=None + ) ) else: table.append( - EqualityTypeComparison(from_id=None, to_id=block.id, type_usage=parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))) + EqualityTypeComparison( + from_id=None, + to_id=block.id, + type_usage=parse.DataTypeUsage( + name=parse.Identifier(name=parse.UNIT_TYPE) + ), + ) ) for statement in block.statements: print(statement) @@ -86,73 +103,115 @@ class TypeCheckTableBuilder: ): if isinstance(statement, parse.LetStatement): self.with_let_statement(env, table, statement) - elif isinstance(statement, parse.Expression): # expression + elif isinstance(statement, parse.Expression): # expression self.with_expression(env, table, statement) else: assert False def with_let_statement( - self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement + self, + env: Environment, + table: List[TypeComparison], + let_statement: parse.LetStatement, ): env[let_statement.variable_name.name] = let_statement.id table.append( - EqualityTypeComparison(from_id=let_statement.expression.id, to_id=let_statement.id, type_usage=let_statement.type) + EqualityTypeComparison( + from_id=let_statement.expression.id, + to_id=let_statement.id, + type_usage=let_statement.type, + ) ) self.with_expression(env, table, let_statement.expression) def with_expression( - self, env: Environment, table: List[TypeComparison], expression: parse.Expression + self, + env: Environment, + table: List[TypeComparison], + expression: parse.Expression, ): if isinstance(expression.expression, parse.LiteralInt): table.append( - EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + EqualityTypeComparison( + from_id=expression.expression.id, + to_id=expression.id, + type_usage=None, + ) ) self.with_literal_int(env, table, expression.expression) elif isinstance(expression.expression, parse.FunctionCall): table.append( - EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + EqualityTypeComparison( + from_id=expression.expression.id, + to_id=expression.id, + type_usage=None, + ) ) self.with_function_call(env, table, expression.expression) elif isinstance(expression.expression, parse.VariableUsage): table.append( - EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + EqualityTypeComparison( + from_id=expression.expression.id, + to_id=expression.id, + type_usage=None, + ) ) self.with_variable_usage(env, table, expression.expression) elif isinstance(expression.expression, parse.Operation): table.append( - EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + EqualityTypeComparison( + from_id=expression.expression.id, + to_id=expression.id, + type_usage=None, + ) ) self.with_operation(env, table, expression.expression) else: assert False def with_variable_usage( - self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage + self, + env: Environment, + table: List[TypeComparison], + variable_usage: parse.VariableUsage, ): - print('%%%%%%%%%%%%%%%%%%%%%') + print("%%%%%%%%%%%%%%%%%%%%%") print(env[variable_usage.name.name]) print(variable_usage.id) table.append( - EqualityTypeComparison(from_id=env[variable_usage.name.name], to_id=variable_usage.id, type_usage=None) + EqualityTypeComparison( + from_id=env[variable_usage.name.name], + to_id=variable_usage.id, + type_usage=None, + ) ) def with_operation( self, env: Environment, table: List[TypeComparison], operation: parse.Operation ): table.append( - EqualityTypeComparison(from_id=operation.left.id, to_id=operation.id, type_usage=None) + EqualityTypeComparison( + from_id=operation.left.id, to_id=operation.id, type_usage=None + ) ) table.append( - EqualityTypeComparison(from_id=operation.right.id, to_id=operation.id, type_usage=None) + EqualityTypeComparison( + from_id=operation.right.id, to_id=operation.id, type_usage=None + ) ) self.with_expression(env, table, operation.left) self.with_expression(env, table, operation.right) def with_function_call( - self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall + self, + env: Environment, + table: List[TypeComparison], + function_call: parse.FunctionCall, ): table.append( - EqualityTypeComparison(to_id=function_call.id, from_id=function_call.source.id, type_usage=None) + EqualityTypeComparison( + to_id=function_call.id, from_id=function_call.source.id, type_usage=None + ) ) self.with_expression(env, table, function_call.source) @@ -163,14 +222,20 @@ class TypeCheckTableBuilder: # FunctionArgumentTypeComparison self.with_expression(env, table, argument) - def with_literal_int( - self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt + self, + env: Environment, + table: List[TypeComparison], + literal_int: parse.LiteralInt, ): table.append( - EqualityTypeComparison(from_id=None, to_id=literal_int.id, type_usage=parse.DataTypeUsage( - name=parse.Identifier(name="u32"), - )) + EqualityTypeComparison( + from_id=None, + to_id=literal_int.id, + type_usage=parse.DataTypeUsage( + name=parse.Identifier(name="u32"), + ), + ) ) @@ -201,7 +266,9 @@ def check_types(table: List[TypeComparison]): other.type_usage = entry.type_usage found = True elif isinstance(other, FunctionCallTypeComparison): - assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called' + assert isinstance( + entry.type_usage, parse.FunctionTypeUsage + ), "non function called" other.type_usage = entry.type_usage.return_type found = True elif other.type_usage is not None and entry.type_usage is None: @@ -209,20 +276,16 @@ def check_types(table: List[TypeComparison]): entry.type_usage = other.type_usage found = True elif isinstance(other, FunctionCallTypeComparison): - pass # can't reverse a function + pass # can't reverse a function else: if isinstance(other, EqualityTypeComparison): assert other.type_usage == entry.type_usage elif isinstance(other, FunctionCallTypeComparison): - assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called' + assert isinstance( + entry.type_usage, parse.FunctionTypeUsage + ), "non function called" assert other.type_usage == entry.type_usage.return_type - - - - - - # if other.to_id == entry.from_id: # # let a = || {4} other # # let b = a() entry diff --git a/boring/type_checking.py b/boring/type_checking.py new file mode 100644 index 0000000..f139fe3 --- /dev/null +++ b/boring/type_checking.py @@ -0,0 +1,205 @@ +from dataclasses import dataclass +from typing import List, Dict, Optional, Union + + +from boring import parse + + +Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration] +Environment = Dict[str, Identified] + + +class TypeChecker: + def with_module(self, env: Environment, module: parse.Module) -> bool: + for function in module.functions: + env[function.name.name] = function + found = False + for function in module.functions: + if self.with_function(env, function): + found = True + return found + + def with_function(self, env: Environment, function: parse.Function) -> bool: + function_env = env.copy() + for argument in function.arguments: + function_env[argument.name.name] = argument + assert isinstance(function.type, parse.FunctionTypeUsage) + function.block.type = function.type.return_type + return self.with_block(function_env, function.block) + + # Skip variable VariableDeclaration + + def with_block(self, env: Environment, block: parse.Block) -> bool: + block_env = env.copy() + # if parent is void, must be statement + # if parent is type, must be expression + found = False + final = block.statements[-1] + if isinstance(final, parse.LetStatement): + if block.type is None: + found = True + block.type = parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE)) + else: + assert block.type == parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE)) + elif isinstance(final, parse.Expression): + if block.type is None: + if final.type is not None: + found = True + block.type = final.type + else: + if final.type is None: + found = True + final.type = block.type + else: + assert final.type == block.type + + for statement in block.statements: + if self.with_statement(block_env, statement): + found = True + return found + + def with_statement(self, env: Environment, statement: parse.Statement) -> bool: + if isinstance(statement, parse.LetStatement): + return self.with_let_statement(env, statement) + elif isinstance(statement, parse.Expression): # expression + return self.with_expression(env, statement) + else: + assert False + + def with_let_statement( + self, env: Environment, let_statement: parse.LetStatement + ) -> bool: + found = False + env[let_statement.variable_name.name] = let_statement + if let_statement.type is None: + if let_statement.expression.type is not None: + let_statement.type = let_statement.expression.type + found = True + else: + if let_statement.expression.type is None: + let_statement.expression.type = let_statement.type + found = True + else: + assert let_statement.expression.type == let_statement.type + if self.with_expression(env, let_statement.expression): + found = True + return found + + + def with_expression(self, env: Environment, expression: parse.Expression) -> bool: + subexpression = expression.expression + found = False + # generic to all types + if expression.type is None: + if subexpression.type is not None: + expression.type = subexpression.type + found = True + else: + if subexpression.type is None: + subexpression.type = expression.type + found = True + else: + assert subexpression.type == expression.type + + if isinstance(subexpression, parse.LiteralInt): + if self.with_literal_int(env, subexpression): + found = True + return found + if isinstance(subexpression, parse.FunctionCall): + if self.with_function_call(env, subexpression): + found = True + return found + if isinstance(subexpression, parse.VariableUsage): + if self.with_variable_usage(env, subexpression): + found = True + return found + if isinstance(subexpression, parse.Operation): + if self.with_operation(env, subexpression): + found = True + return found + assert False + + def with_variable_usage( + self, env: Environment, variable_usage: parse.VariableUsage + ) -> bool: + found = False + variable = env[variable_usage.name.name] + if variable_usage.type is None: + if variable.type is not None: + variable_usage.type = variable.type + found = True + else: + if variable.type is None: + # print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + # print(f"{variable.name} {variable.type}") + variable.type = variable_usage.type + found = True + else: + assert variable.type == variable_usage.type + return found + + def with_operation(self, env: Environment, operation: parse.Operation) -> bool: + found = False + if operation.type is None: + if operation.left.type is not None: + operation.type = operation.left.type + found = True + else: + if operation.left.type is None: + operation.left.type = operation.type + found = True + else: + assert operation.left.type == operation.type + if operation.type is None: + if operation.right.type is not None: + operation.type = operation.right.type + found = True + else: + if operation.right.type is None: + operation.right.type = operation.type + found = True + else: + assert operation.right.type == operation.type + if self.with_expression(env, operation.left): + found = True + if self.with_expression(env, operation.right): + found = True + return found + + def with_function_call( + self, env: Environment, function_call: parse.FunctionCall + ) -> bool: + found = False + if function_call.type is None: + if function_call.source.type is not None: + assert isinstance(function_call.source.type, parse.FunctionTypeUsage) + found = True + function_call.type = function_call.source.type.return_type + else: + if function_call.source.type is not None: + assert isinstance(function_call.source.type, parse.FunctionTypeUsage) + assert function_call.type == function_call.source.type.return_type + if self.with_expression(env, function_call.source): + found = True + + if function_call.source.type is not None: + assert isinstance(function_call.source.type, parse.FunctionTypeUsage) + assert len(function_call.arguments) == len(function_call.source.type.arguments) + for (argument, type_argument) in zip(function_call.arguments, function_call.source.type.arguments): + if argument.type is None: + argument.type = type_argument + found = True + else: + assert argument.type == type_argument + + for argument in function_call.arguments: + if self.with_expression(env, argument): + found = True + return found + + + def with_literal_int(self, env: Environment, literal_int: parse.LiteralInt) -> bool: + ints = [parse.DataTypeUsage(name=parse.Identifier(name=name)) for name in ["Int8", "Int16", "Int32", "Int64", "Int128", "u32"]] + if literal_int.type is not None: + assert literal_int.type in ints, f"{literal_int.type}" + return False diff --git a/boring/typing.py b/boring/typing.py deleted file mode 100644 index 2688789..0000000 --- a/boring/typing.py +++ /dev/null @@ -1,64 +0,0 @@ -from dataclasses import dataclass -from typing import List, Dict, Optional, Union - - -from boring import parse - - -Identified = Union[parse.LetStatement, parse.Function] -Environment = Dict[str, Identified] - - -class TypeChecker: - def with_module( - self, env: Environment, table: List[TypeComparison], module: parse.Module - ) -> bool: - pass - - def with_function( - self, env: Environment, table: List[TypeComparison], function: parse.Function - ) -> bool: - pass - - # Skip variable VariableDeclaration - - def with_block( - self, env: Environment, table: List[TypeComparison], block: parse.Block - ) -> bool: - pass - - def with_statement( - self, env: Environment, table: List[TypeComparison], statement: parse.Statement - ) -> bool: - pass - - def with_let_statement( - self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement - ) -> bool: - pass - - def with_expression( - self, env: Environment, table: List[TypeComparison], expression: parse.Expression - ) -> bool: - pass - - def with_variable_usage( - self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage - ) -> bool: - pass - - def with_operation( - self, env: Environment, table: List[TypeComparison], operation: parse.Operation - ) -> bool: - pass - - def with_function_call( - self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall - ) -> bool: - pass - - - def with_literal_int( - self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt - ) -> bool: - pass