From f5fc6643fb5f68c3cb608db598a51d66d17f4646 Mon Sep 17 00:00:00 2001 From: Andrew Segavac Date: Fri, 28 May 2021 23:57:07 -0600 Subject: [PATCH] got type checking working for real --- boring/main.py | 2 +- boring/parse.py | 72 ++++----- boring/type_checker.py | 317 ---------------------------------------- boring/type_checking.py | 239 +++++++++++++++--------------- examples/math/main.bl | 6 +- 5 files changed, 161 insertions(+), 475 deletions(-) delete mode 100644 boring/type_checker.py diff --git a/boring/main.py b/boring/main.py index c5db87f..dae2295 100644 --- a/boring/main.py +++ b/boring/main.py @@ -11,7 +11,7 @@ if __name__ == "__main__": # pretty_print(result) type_checker = TypeChecker() while type_checker.with_module({}, result): - print('loop') + print("loop") # type_checker.with_module({}, result) pretty_print(result) # tctb = TypeCheckTableBuilder() diff --git a/boring/parse.py b/boring/parse.py index d555777..a7bb7cc 100644 --- a/boring/parse.py +++ b/boring/parse.py @@ -28,11 +28,6 @@ def pretty_print(clas, indent=0): UNIT_TYPE = "()" -@dataclass -class Identifier: - name: str - - @dataclass class FunctionTypeUsage: arguments: List[ @@ -43,10 +38,15 @@ class FunctionTypeUsage: @dataclass class DataTypeUsage: - name: Identifier + name: str -TypeUsage = Union[FunctionTypeUsage, DataTypeUsage] +@dataclass +class UnknownTypeUsage: + pass + + +TypeUsage = Union[FunctionTypeUsage, DataTypeUsage, UnknownTypeUsage] class Operator(enum.Enum): @@ -59,14 +59,14 @@ class Operator(enum.Enum): @dataclass class LiteralInt: value: int - type: Optional[TypeUsage] + type: TypeUsage @dataclass class FunctionCall: source: "Expression" arguments: List["Expression"] - type: Optional[TypeUsage] + type: TypeUsage @dataclass @@ -74,25 +74,25 @@ class Operation: left: "Expression" op: Operator right: "Expression" - type: Optional[TypeUsage] + type: TypeUsage @dataclass class VariableUsage: - name: Identifier - type: Optional[TypeUsage] + name: str + type: TypeUsage @dataclass class Expression: expression: Union[LiteralInt, FunctionCall, VariableUsage, Operation] - type: Optional[TypeUsage] + type: TypeUsage @dataclass class LetStatement: - variable_name: Identifier - type: Optional[TypeUsage] + variable_name: str + type: TypeUsage expression: Expression @@ -102,18 +102,18 @@ Statement = Union[LetStatement, Expression] @dataclass class Block: statements: List[Statement] - type: Optional[TypeUsage] + type: TypeUsage @dataclass class VariableDeclaration: - name: Identifier + name: str type: TypeUsage @dataclass class Function: - name: Identifier + name: str arguments: List[VariableDeclaration] block: Block return_type: TypeUsage @@ -191,6 +191,8 @@ boring_grammar = r""" %ignore WS """ +next_sub_id = 0 + class TreeToBoring(Transformer): def __init__(self, *args, **kwargs): @@ -210,46 +212,46 @@ class TreeToBoring(Transformer): def literal_int(self, n) -> LiteralInt: (n,) = n - return LiteralInt(value=int(n), type=None) + return LiteralInt(value=int(n), type=UnknownTypeUsage()) - def identifier(self, i) -> Identifier: + def identifier(self, i) -> str: (i,) = i - return Identifier(name=str(i)) + return str(i) def variable_usage(self, variable) -> VariableUsage: (variable,) = variable - return VariableUsage(name=variable, type=None) + return VariableUsage(name=variable, type=UnknownTypeUsage()) def function_call(self, call) -> FunctionCall: - return FunctionCall(source=call[0], arguments=call[1:], type=None) + return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage()) def add_expression(self, ae) -> Operation: - return Operation(left=ae[0], op=ae[1], right=ae[2], type=None) + return Operation(left=ae[0], op=ae[1], right=ae[2], type=UnknownTypeUsage()) def sub_expression(self, se) -> Operation: - return Operation(left=se[0], op=se[1], right=se[2], type=None) + return Operation(left=se[0], op=se[1], right=se[2], type=UnknownTypeUsage()) def mult_expression(self, se) -> Operation: - return Operation(left=se[0], op=se[1], right=se[2], type=None) + return Operation(left=se[0], op=se[1], right=se[2], type=UnknownTypeUsage()) def div_expression(self, se) -> Operation: - return Operation(left=se[0], op=se[1], right=se[2], type=None) + return Operation(left=se[0], op=se[1], right=se[2], type=UnknownTypeUsage()) def expression(self, exp) -> Expression: (exp,) = exp if isinstance(exp, Expression): return exp - return Expression(expression=exp, type=None) + return Expression(expression=exp, type=UnknownTypeUsage()) def factor(self, factor) -> Expression: (factor,) = factor if isinstance(factor, Expression): return factor - return Expression(expression=factor, type=None) + return Expression(expression=factor, type=UnknownTypeUsage()) def term(self, term) -> Expression: (term,) = term - return Expression(expression=term, type=None) + return Expression(expression=term, type=UnknownTypeUsage()) def let_statement(self, let_statement) -> LetStatement: if len(let_statement) == 3: @@ -262,7 +264,7 @@ class TreeToBoring(Transformer): (variable_name, expression) = let_statement return LetStatement( variable_name=variable_name, - type=None, + type=UnknownTypeUsage(), expression=expression, ) @@ -271,7 +273,7 @@ class TreeToBoring(Transformer): return statement def block(self, block) -> Block: - return Block(statements=block, type=None) + return Block(statements=block, type=UnknownTypeUsage()) def data_type(self, name) -> TypeUsage: (name,) = name @@ -280,7 +282,7 @@ class TreeToBoring(Transformer): def function_type(self, type_usage) -> TypeUsage: return FunctionTypeUsage( arguments=type_usage, - return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), + return_type=DataTypeUsage(name=UNIT_TYPE), ) def function_type_with_return(self, type_usage) -> TypeUsage: @@ -298,11 +300,11 @@ class TreeToBoring(Transformer): return Function( name=function[0], arguments=function[1:-1], - return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), + return_type=DataTypeUsage(name=UNIT_TYPE), block=function[-1], type=FunctionTypeUsage( arguments=[arg.type for arg in function[1:-1]], - return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), + return_type=DataTypeUsage(name=UNIT_TYPE), ), ) diff --git a/boring/type_checker.py b/boring/type_checker.py deleted file mode 100644 index 7853ca6..0000000 --- a/boring/type_checker.py +++ /dev/null @@ -1,317 +0,0 @@ -from dataclasses import dataclass -from typing import List, Dict, Optional, Union - - -from boring import parse - - -@dataclass -class EqualityTypeComparison: - from_id: Optional[str] - to_id: str - type_usage: Optional[parse.TypeUsage] - # constraints: List[Constraint] - - -@dataclass -class FunctionCallTypeComparison: - from_id: str - to_id: str - type_usage: Optional[parse.TypeUsage] - - -@dataclass -class FunctionArgumentTypeComparison: - from_id: str - to_id: str - argument_id: int - type_usage: Optional[parse.TypeUsage] - - -TypeComparison = Union[ - EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison -] -Environment = Dict[str, str] - - -class TypeCheckTableBuilder: - def with_module( - self, env: Environment, table: List[TypeComparison], module: parse.Module - ): - for function in module.functions: - env[function.name.name] = function.id - type_usage = parse.FunctionTypeUsage( - arguments=[arg.type for arg in function.arguments], - return_type=function.return_type, - ) - table.append( - EqualityTypeComparison( - from_id=None, to_id=function.id, type_usage=type_usage - ) - ) - for function in module.functions: - self.with_function(env, table, function) - - def with_function( - self, env: Environment, table: List[TypeComparison], function: parse.Function - ): - function_env = env.copy() - for argument in function.arguments: - function_env[argument.name.name] = argument.id - table.append( - EqualityTypeComparison( - from_id=None, to_id=argument.id, type_usage=argument.type - ) - ) - table.append( - EqualityTypeComparison( - from_id=None, to_id=function.block.id, type_usage=function.return_type - ) - ) - self.with_block(function_env, table, function.block) - - # Skip variable VariableDeclaration - - def with_block( - self, env: Environment, table: List[TypeComparison], block: parse.Block - ): - block_env = env.copy() - # if parent is void, must be statement - # if parent is type, must be expression - if isinstance(block.statements[-1], parse.Expression): - table.append( - EqualityTypeComparison( - from_id=block.statements[-1].id, to_id=block.id, type_usage=None - ) - ) - else: - table.append( - EqualityTypeComparison( - from_id=None, - to_id=block.id, - type_usage=parse.DataTypeUsage( - name=parse.Identifier(name=parse.UNIT_TYPE) - ), - ) - ) - for statement in block.statements: - print(statement) - self.with_statement(block_env, table, statement) - - def with_statement( - self, env: Environment, table: List[TypeComparison], statement: parse.Statement - ): - if isinstance(statement, parse.LetStatement): - self.with_let_statement(env, table, statement) - elif isinstance(statement, parse.Expression): # expression - self.with_expression(env, table, statement) - else: - assert False - - def with_let_statement( - self, - env: Environment, - table: List[TypeComparison], - let_statement: parse.LetStatement, - ): - env[let_statement.variable_name.name] = let_statement.id - table.append( - EqualityTypeComparison( - from_id=let_statement.expression.id, - to_id=let_statement.id, - type_usage=let_statement.type, - ) - ) - self.with_expression(env, table, let_statement.expression) - - def with_expression( - self, - env: Environment, - table: List[TypeComparison], - expression: parse.Expression, - ): - if isinstance(expression.expression, parse.LiteralInt): - table.append( - EqualityTypeComparison( - from_id=expression.expression.id, - to_id=expression.id, - type_usage=None, - ) - ) - self.with_literal_int(env, table, expression.expression) - elif isinstance(expression.expression, parse.FunctionCall): - table.append( - EqualityTypeComparison( - from_id=expression.expression.id, - to_id=expression.id, - type_usage=None, - ) - ) - self.with_function_call(env, table, expression.expression) - elif isinstance(expression.expression, parse.VariableUsage): - table.append( - EqualityTypeComparison( - from_id=expression.expression.id, - to_id=expression.id, - type_usage=None, - ) - ) - self.with_variable_usage(env, table, expression.expression) - elif isinstance(expression.expression, parse.Operation): - table.append( - EqualityTypeComparison( - from_id=expression.expression.id, - to_id=expression.id, - type_usage=None, - ) - ) - self.with_operation(env, table, expression.expression) - else: - assert False - - def with_variable_usage( - self, - env: Environment, - table: List[TypeComparison], - variable_usage: parse.VariableUsage, - ): - print("%%%%%%%%%%%%%%%%%%%%%") - print(env[variable_usage.name.name]) - print(variable_usage.id) - table.append( - EqualityTypeComparison( - from_id=env[variable_usage.name.name], - to_id=variable_usage.id, - type_usage=None, - ) - ) - - def with_operation( - self, env: Environment, table: List[TypeComparison], operation: parse.Operation - ): - table.append( - EqualityTypeComparison( - from_id=operation.left.id, to_id=operation.id, type_usage=None - ) - ) - table.append( - EqualityTypeComparison( - from_id=operation.right.id, to_id=operation.id, type_usage=None - ) - ) - self.with_expression(env, table, operation.left) - self.with_expression(env, table, operation.right) - - def with_function_call( - self, - env: Environment, - table: List[TypeComparison], - function_call: parse.FunctionCall, - ): - table.append( - EqualityTypeComparison( - to_id=function_call.id, from_id=function_call.source.id, type_usage=None - ) - ) - self.with_expression(env, table, function_call.source) - - for i, argument in enumerate(function_call.arguments): - # table.append( - # FunctionArgumentTypeComparison(from_id=env[function_call.name.name], to_id=function_call.id, type_usage=None) - # ) - # FunctionArgumentTypeComparison - self.with_expression(env, table, argument) - - def with_literal_int( - self, - env: Environment, - table: List[TypeComparison], - literal_int: parse.LiteralInt, - ): - table.append( - EqualityTypeComparison( - from_id=None, - to_id=literal_int.id, - type_usage=parse.DataTypeUsage( - name=parse.Identifier(name="u32"), - ), - ) - ) - - -def check_types(table: List[TypeComparison]): - found = True - while found: - found = False - for entry in table: - for other in table: - if other.to_id == entry.to_id: - if other.type_usage is None and entry.type_usage is None: - pass - elif other.type_usage is None and entry.type_usage is not None: - other.type_usage = entry.type_usage - found = True - elif other.type_usage is not None and entry.type_usage is None: - entry.type_usage = other.type_usage - found = True - else: - assert entry.type_usage == other.type_usage - if other.from_id == entry.to_id: - # let a = || {4} entry - # let b = a() other - if other.type_usage is None and entry.type_usage is None: - pass - elif other.type_usage is None and entry.type_usage is not None: - if isinstance(other, EqualityTypeComparison): - other.type_usage = entry.type_usage - found = True - elif isinstance(other, FunctionCallTypeComparison): - assert isinstance( - entry.type_usage, parse.FunctionTypeUsage - ), "non function called" - other.type_usage = entry.type_usage.return_type - found = True - elif other.type_usage is not None and entry.type_usage is None: - if isinstance(other, EqualityTypeComparison): - entry.type_usage = other.type_usage - found = True - elif isinstance(other, FunctionCallTypeComparison): - pass # can't reverse a function - else: - if isinstance(other, EqualityTypeComparison): - assert other.type_usage == entry.type_usage - elif isinstance(other, FunctionCallTypeComparison): - assert isinstance( - entry.type_usage, parse.FunctionTypeUsage - ), "non function called" - assert other.type_usage == entry.type_usage.return_type - - # if other.to_id == entry.from_id: - # # let a = || {4} other - # # let b = a() entry - # if other.type_usage is None and entry.type_usage is None: - # pass - # elif other.type_usage is None and entry.type_usage is not None: - # if isinstance(entry, EqualityTypeComparison): - # other.type_usage = entry.type_usage - # found = True - # elif isinstance(entry, FunctionCallTypeComparison): - # pass # can't reverse a function - # elif other.type_usage is not None and entry.type_usage is None: - # if isinstance(entry, EqualityTypeComparison): - # entry.type_usage = other.type_usage - # found = True - # elif isinstance(entry, FunctionCallTypeComparison): - # entry.type_usage = other.type_usage.return_type - # found = True - # if other.from_id == entry.from_id and entry.from_id is not None: - # if other.type_usage is None and entry.type_usage is None: - # pass - # elif other.type_usage is None and entry.type_usage is not None: - # other.type_usage = entry.type_usage - # found = True - # elif other.type_usage is not None and entry.type_usage is None: - # entry.type_usage = other.type_usage - # found = True - # else: - # assert entry.type_usage == other.type_usage, f"{entry.from_id} {other.from_id} {entry.type_usage} == {other.type_usage}" diff --git a/boring/type_checking.py b/boring/type_checking.py index f139fe3..ad8b554 100644 --- a/boring/type_checking.py +++ b/boring/type_checking.py @@ -9,10 +9,53 @@ Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration Environment = Dict[str, Identified] +def unify(first, second) -> bool: + result, changed = type_compare(first.type, second.type) + first.type = result + second.type = result + return changed + + +def type_compare(first, second) -> (parse.TypeUsage, bool): + print(first, second) + if isinstance(first, parse.UnknownTypeUsage): + if not isinstance(second, parse.UnknownTypeUsage): + return second, True + else: + return parse.UnknownTypeUsage(), False + else: + if isinstance(second, parse.UnknownTypeUsage): + return first, True + else: + if isinstance(first, parse.DataTypeUsage) and isinstance( + second, parse.DataTypeUsage + ): + assert second == first + return first, False + elif isinstance(first, parse.FunctionTypeUsage) and isinstance( + second, parse.FunctionTypeUsage + ): + return_type, changed = type_compare( + first.return_type, second.return_type + ) + arguments = [] + assert len(first.arguments) == len(second.arguments) + for first_arg, second_arg in zip(first.arguments, second.arguments): + argument_type, argument_changed = type_compare( + first_arg, second_arg + ) + arguments.append(argument_type) + if argument_changed: + changed = True + return parse.FunctionTypeUsage(arguments, return_type), changed + else: + assert False, f"mismatched types {first}, {second}" + + class TypeChecker: def with_module(self, env: Environment, module: parse.Module) -> bool: for function in module.functions: - env[function.name.name] = function + env[function.name] = function found = False for function in module.functions: if self.with_function(env, function): @@ -22,10 +65,15 @@ class TypeChecker: def with_function(self, env: Environment, function: parse.Function) -> bool: function_env = env.copy() for argument in function.arguments: - function_env[argument.name.name] = argument + function_env[argument.name] = argument assert isinstance(function.type, parse.FunctionTypeUsage) - function.block.type = function.type.return_type - return self.with_block(function_env, function.block) + + type, changed = type_compare(function.block.type, function.type.return_type) + function.block.type = type + function.type.return_type = type + if self.with_block(function_env, function.block): + changed = True + return changed # Skip variable VariableDeclaration @@ -33,30 +81,26 @@ class TypeChecker: block_env = env.copy() # if parent is void, must be statement # if parent is type, must be expression - found = False + changed = False final = block.statements[-1] if isinstance(final, parse.LetStatement): - if block.type is None: + if isinstance(block.type, parse.UnknownTypeUsage): found = True - block.type = parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE)) + block.type = parse.DataTypeUsage( + name=parse.Identifier(name=parse.UNIT_TYPE) + ) else: - assert block.type == parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE)) + assert block.type == parse.DataTypeUsage( + name=parse.Identifier(name=parse.UNIT_TYPE) + ) elif isinstance(final, parse.Expression): - if block.type is None: - if final.type is not None: - found = True - block.type = final.type - else: - if final.type is None: - found = True - final.type = block.type - else: - assert final.type == block.type + if unify(final, block): + changed = True for statement in block.statements: if self.with_statement(block_env, statement): - found = True - return found + changed = True + return changed def with_statement(self, env: Environment, statement: parse.Statement) -> bool: if isinstance(statement, parse.LetStatement): @@ -70,136 +114,93 @@ class TypeChecker: self, env: Environment, let_statement: parse.LetStatement ) -> bool: found = False - env[let_statement.variable_name.name] = let_statement - if let_statement.type is None: - if let_statement.expression.type is not None: - let_statement.type = let_statement.expression.type - found = True - else: - if let_statement.expression.type is None: - let_statement.expression.type = let_statement.type - found = True - else: - assert let_statement.expression.type == let_statement.type + env[let_statement.variable_name] = let_statement + changed = unify(let_statement, let_statement.expression) if self.with_expression(env, let_statement.expression): - found = True - return found - + changed = True + return changed def with_expression(self, env: Environment, expression: parse.Expression) -> bool: subexpression = expression.expression - found = False - # generic to all types - if expression.type is None: - if subexpression.type is not None: - expression.type = subexpression.type - found = True - else: - if subexpression.type is None: - subexpression.type = expression.type - found = True - else: - assert subexpression.type == expression.type + changed = unify(subexpression, expression) if isinstance(subexpression, parse.LiteralInt): + print(f"fooooo {expression.type}, {subexpression.type}") if self.with_literal_int(env, subexpression): - found = True - return found + changed = True + return changed if isinstance(subexpression, parse.FunctionCall): if self.with_function_call(env, subexpression): - found = True - return found + changed = True + return changed if isinstance(subexpression, parse.VariableUsage): if self.with_variable_usage(env, subexpression): - found = True - return found + changed = True + return changed if isinstance(subexpression, parse.Operation): if self.with_operation(env, subexpression): - found = True - return found + changed = True + return changed assert False def with_variable_usage( self, env: Environment, variable_usage: parse.VariableUsage ) -> bool: - found = False - variable = env[variable_usage.name.name] - if variable_usage.type is None: - if variable.type is not None: - variable_usage.type = variable.type - found = True - else: - if variable.type is None: - # print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') - # print(f"{variable.name} {variable.type}") - variable.type = variable_usage.type - found = True - else: - assert variable.type == variable_usage.type - return found + return unify(variable_usage, env[variable_usage.name]) def with_operation(self, env: Environment, operation: parse.Operation) -> bool: - found = False - if operation.type is None: - if operation.left.type is not None: - operation.type = operation.left.type - found = True - else: - if operation.left.type is None: - operation.left.type = operation.type - found = True - else: - assert operation.left.type == operation.type - if operation.type is None: - if operation.right.type is not None: - operation.type = operation.right.type - found = True - else: - if operation.right.type is None: - operation.right.type = operation.type - found = True - else: - assert operation.right.type == operation.type + changed = False + if unify(operation, operation.left): + changed = True + if unify(operation, operation.right): + changed = True if self.with_expression(env, operation.left): - found = True + changed = True if self.with_expression(env, operation.right): - found = True - return found + changed = True + return changed def with_function_call( self, env: Environment, function_call: parse.FunctionCall ) -> bool: - found = False - if function_call.type is None: - if function_call.source.type is not None: - assert isinstance(function_call.source.type, parse.FunctionTypeUsage) - found = True - function_call.type = function_call.source.type.return_type - else: - if function_call.source.type is not None: - assert isinstance(function_call.source.type, parse.FunctionTypeUsage) - assert function_call.type == function_call.source.type.return_type + changed = False + if isinstance(function_call.source.type, parse.UnknownTypeUsage): + function_call.source.type = parse.FunctionTypeUsage( + arguments=[parse.UnknownTypeUsage()] * len(function_call.arguments), + return_type=parse.UnknownTypeUsage(), + ) + changed = True if self.with_expression(env, function_call.source): - found = True - - if function_call.source.type is not None: - assert isinstance(function_call.source.type, parse.FunctionTypeUsage) - assert len(function_call.arguments) == len(function_call.source.type.arguments) - for (argument, type_argument) in zip(function_call.arguments, function_call.source.type.arguments): - if argument.type is None: - argument.type = type_argument - found = True - else: - assert argument.type == type_argument - + changed = True for argument in function_call.arguments: if self.with_expression(env, argument): - found = True - return found + changed = True + return_type, return_changed = type_compare( + function_call.type, function_call.source.type.return_type + ) + function_call.type = return_type + function_call.source.type.return_type = return_type + if return_changed: + changed = True + + for argument, argument_type in zip( + function_call.arguments, function_call.source.type.arguments + ): + argument_out_type, argument_changed = type_compare( + argument.type, function_call.source.type.return_type + ) + argument.type = argument_out_type + function_call.source.type.return_type = argument_out_type + if argument_changed: + changed = True + return changed def with_literal_int(self, env: Environment, literal_int: parse.LiteralInt) -> bool: - ints = [parse.DataTypeUsage(name=parse.Identifier(name=name)) for name in ["Int8", "Int16", "Int32", "Int64", "Int128", "u32"]] - if literal_int.type is not None: + ints = [ + parse.DataTypeUsage(name=name) + for name in ["I8", "I16", "I32", "I64", "I128"] + ] + if not isinstance(literal_int.type, parse.UnknownTypeUsage): assert literal_int.type in ints, f"{literal_int.type}" return False diff --git a/examples/math/main.bl b/examples/math/main.bl index fe6edc6..56c7b22 100644 --- a/examples/math/main.bl +++ b/examples/math/main.bl @@ -1,12 +1,12 @@ -fn add(a: u32, b: u32): u32 { +fn add(a: I32, b: I32): I32 { let foo = 4; a + b + foo } -fn subtract(a: u32, b: u32): u32 { +fn subtract(a: I32, b: I32): I32 { a - b } -fn main(): u32 { +fn main(): I32 { add(4, subtract(5, 2)) }