From b82566f7101f0048368fc014d345164a090ffa7d Mon Sep 17 00:00:00 2001 From: Andrew Segavac Date: Sun, 30 May 2021 09:57:41 -0600 Subject: [PATCH] added return statement --- boring/main.py | 4 +- boring/parse.py | 16 +++- boring/type_checking.py | 166 +++++++++++++++++++++++++--------------- examples/math/main.bl | 14 +++- 4 files changed, 135 insertions(+), 65 deletions(-) diff --git a/boring/main.py b/boring/main.py index c480fff..cbeba8e 100644 --- a/boring/main.py +++ b/boring/main.py @@ -1,7 +1,7 @@ import sys from typing import List from boring.parse import boring_parser, TreeToBoring, pretty_print -from boring.type_checking import TypeChecker +from boring.type_checking import TypeChecker, Context from boring import typedefs if __name__ == "__main__": @@ -11,7 +11,7 @@ if __name__ == "__main__": result = TreeToBoring().transform(tree) # pretty_print(result) type_checker = TypeChecker() - while type_checker.with_module({}, typedefs.builtins, result): + while type_checker.with_module(Context({}, typedefs.builtins, None), result): print("loop") # type_checker.with_module({}, result) pretty_print(result) diff --git a/boring/parse.py b/boring/parse.py index ac20bf2..7c9aa50 100644 --- a/boring/parse.py +++ b/boring/parse.py @@ -26,6 +26,7 @@ def pretty_print(clas, indent=0): UNIT_TYPE = "()" +NEVER_TYPE = "!" @dataclass @@ -89,9 +90,15 @@ class VariableUsage: type: TypeUsage +@dataclass +class ReturnStatement: + source: "Expression" + type: TypeUsage + + @dataclass class Expression: - expression: Union[LiteralInt, LiteralFloat, FunctionCall, "Block", VariableUsage, Operation] + expression: Union[LiteralInt, LiteralFloat, FunctionCall, "Block", ReturnStatement, VariableUsage, Operation] type: TypeUsage @@ -150,6 +157,8 @@ boring_grammar = r""" variable_usage : identifier + return_statement : "return" expression ";" + expression : add_expression | sub_expression | factor @@ -169,6 +178,7 @@ boring_grammar = r""" | "let" identifier ":" type_usage "=" expression ";" statement : let_statement + | return_statement | expression block : "{" (statement)* "}" @@ -238,6 +248,10 @@ class TreeToBoring(Transformer): (variable,) = variable return VariableUsage(name=variable, type=UnknownTypeUsage()) + def return_statement(self, return_expression) -> ReturnStatement: + (return_expression,) = return_expression + return ReturnStatement(source=return_expression, type=DataTypeUsage(name=NEVER_TYPE)) + def function_call(self, call) -> FunctionCall: return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage()) diff --git a/boring/type_checking.py b/boring/type_checking.py index 98c319f..6ed5237 100644 --- a/boring/type_checking.py +++ b/boring/type_checking.py @@ -9,15 +9,24 @@ Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration Environment = Dict[str, Identified] TypeEnvironment = Dict[str, typedefs.TypeDef] +@dataclass +class Context: + environment: Environment + type_environment: TypeEnvironment + current_function: Optional[parse.Function] -def unify(type_env: TypeEnvironment, first, second) -> bool: - result, changed = type_compare(type_env, first.type, second.type) + def copy(self): + return Context(self.environment.copy(), self.type_environment.copy(), self.current_function) + + +def unify(ctx: Context, first, second) -> bool: + result, changed = type_compare(ctx, first.type, second.type) first.type = result second.type = result return changed -def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage, bool): +def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool): print(first, second) if isinstance(first, parse.UnknownTypeUsage): if not isinstance(second, parse.UnknownTypeUsage): @@ -32,20 +41,20 @@ def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage, second, parse.DataTypeUsage ): assert second == first - assert first.name in type_env - assert second.name in type_env + assert first.name in ctx.type_environment + assert second.name in ctx.type_environment return first, False elif isinstance(first, parse.FunctionTypeUsage) and isinstance( second, parse.FunctionTypeUsage ): return_type, changed = type_compare( - type_env, first.return_type, second.return_type + ctx, first.return_type, second.return_type ) arguments = [] assert len(first.arguments) == len(second.arguments) for first_arg, second_arg in zip(first.arguments, second.arguments): argument_type, argument_changed = type_compare( - type_env, first_arg, second_arg + ctx, first_arg, second_arg ) arguments.append(argument_type) if argument_changed: @@ -56,122 +65,157 @@ def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage, class TypeChecker: - def with_module(self, env: Environment, type_env: TypeEnvironment, module: parse.Module) -> bool: + def with_module(self, ctx: Context, module: parse.Module) -> bool: for function in module.functions: - env[function.name] = function - found = False + ctx.environment[function.name] = function + changed = False for function in module.functions: - if self.with_function(env, type_env, function): - found = True - return found + if self.with_function(ctx, function): + changed = True + return changed - def with_function(self, env: Environment, type_env: TypeEnvironment, function: parse.Function) -> bool: - function_env = env.copy() + def with_function(self, ctx: Context, function: parse.Function) -> bool: + function_ctx = ctx.copy() + function_ctx.current_function = function for argument in function.arguments: - function_env[argument.name] = argument + function_ctx.environment[argument.name] = argument assert isinstance(function.type, parse.FunctionTypeUsage) - type, changed = type_compare(type_env, function.block.type, function.type.return_type) - function.block.type = type - function.type.return_type = type - if self.with_block(function_env, type_env, function.block): - changed = True + changed = self.with_block(function_ctx, function.block) + + if not (isinstance(function.block.type, parse.DataTypeUsage) and function.block.type.name == parse.NEVER_TYPE): + type, compare_changed = type_compare(function_ctx, function.block.type, function.type.return_type) + function.block.type = type + function.type.return_type = type + if compare_changed is True: + changed = True return changed # Skip variable VariableDeclaration - def with_block(self, env: Environment, type_env: TypeEnvironment, block: parse.Block) -> bool: - block_env = env.copy() + def with_block(self, ctx: Context, block: parse.Block) -> bool: + block_ctx = ctx.copy() # if parent is void, must be statement # if parent is type, must be expression changed = False + for statement in block.statements: + if self.with_statement(block_ctx, statement): + changed = True final = block.statements[-1] if isinstance(final, parse.LetStatement): if isinstance(block.type, parse.UnknownTypeUsage): - found = True + changed = True block.type = parse.DataTypeUsage( - name=parse.Identifier(name=parse.UNIT_TYPE) + name=parse.UNIT_TYPE ) else: assert block.type == parse.DataTypeUsage( - name=parse.Identifier(name=parse.UNIT_TYPE) + name=parse.UNIT_TYPE + ) + elif isinstance(final, parse.ReturnStatement): + if isinstance(block.type, parse.UnknownTypeUsage): + changed = True + block.type = parse.DataTypeUsage( + name=parse.NEVER_TYPE + ) + else: + assert block.type == parse.DataTypeUsage( + name=parse.NEVER_TYPE ) elif isinstance(final, parse.Expression): - if unify(type_env, final, block): - changed = True - - for statement in block.statements: - if self.with_statement(block_env, type_env, statement): + if unify(block_ctx, final, block): changed = True return changed - def with_statement(self, env: Environment, type_env: TypeEnvironment, statement: parse.Statement) -> bool: + def with_statement(self, ctx: Context, statement: parse.Statement) -> bool: + if isinstance(statement, parse.ReturnStatement): + return self.with_return_statement(ctx, statement) if isinstance(statement, parse.LetStatement): - return self.with_let_statement(env, type_env, statement) + return self.with_let_statement(ctx, statement) elif isinstance(statement, parse.Expression): # expression - return self.with_expression(env, type_env, statement) + return self.with_expression(ctx, statement) else: assert False def with_let_statement( - self, env: Environment, type_env: TypeEnvironment, let_statement: parse.LetStatement + self, ctx: Context, let_statement: parse.LetStatement ) -> bool: - found = False - env[let_statement.variable_name] = let_statement - changed = unify(type_env, let_statement, let_statement.expression) - if self.with_expression(env, type_env, let_statement.expression): + changed = False + ctx.environment[let_statement.variable_name] = let_statement + if self.with_expression(ctx, let_statement.expression): changed = True + changed = unify(ctx, let_statement, let_statement.expression) return changed - def with_expression(self, env: Environment, type_env: TypeEnvironment, expression: parse.Expression) -> bool: + def with_return_statement( + self, ctx: Context, return_statement: parse.ReturnStatement + ) -> bool: + changed = self.with_expression(ctx, return_statement.source) + + # Doesn't match on an unreachable return + if not (isinstance(return_statement.source.type, parse.DataTypeUsage) and return_statement.source.type.name == parse.NEVER_TYPE): + type, compare_changed = type_compare(ctx, return_statement.source.type, ctx.current_function.type.return_type) + return_statement.source.type = type + ctx.current_function.type.return_type = type + if compare_changed is True: + changed = True + return changed + + def with_expression(self, ctx: Context, expression: parse.Expression) -> bool: subexpression = expression.expression - changed = unify(type_env, subexpression, expression) + changed = False if isinstance(subexpression, parse.LiteralInt): - if self.with_literal_int(env, type_env, subexpression): + changed = self.with_literal_int(ctx, subexpression) + if unify(ctx, subexpression, expression): changed = True return changed if isinstance(subexpression, parse.LiteralFloat): - if self.with_literal_float(env, type_env, subexpression): + changed = self.with_literal_float(ctx, subexpression) + if unify(ctx, subexpression, expression): changed = True return changed if isinstance(subexpression, parse.FunctionCall): - if self.with_function_call(env, type_env, subexpression): + changed = self.with_function_call(ctx, subexpression) + if unify(ctx, subexpression, expression): changed = True return changed if isinstance(subexpression, parse.Block): - if self.with_block(env, type_env, subexpression): + changed = self.with_block(ctx, subexpression) + if unify(ctx, subexpression, expression): changed = True return changed if isinstance(subexpression, parse.VariableUsage): - if self.with_variable_usage(env, type_env, subexpression): + changed = self.with_variable_usage(ctx, subexpression) + if unify(ctx, subexpression, expression): changed = True return changed if isinstance(subexpression, parse.Operation): - if self.with_operation(env, type_env, subexpression): + changed = self.with_operation(ctx, subexpression) + if unify(ctx, subexpression, expression): changed = True return changed assert False def with_variable_usage( - self, env: Environment, type_env: TypeEnvironment, variable_usage: parse.VariableUsage + self, ctx: Context, variable_usage: parse.VariableUsage ) -> bool: - return unify(type_env, variable_usage, env[variable_usage.name]) + return unify(ctx, variable_usage, ctx.environment[variable_usage.name]) - def with_operation(self, env: Environment, type_env: TypeEnvironment, operation: parse.Operation) -> bool: + def with_operation(self, ctx: Context, operation: parse.Operation) -> bool: changed = False - if unify(type_env, operation, operation.left): + if self.with_expression(ctx, operation.left): changed = True - if unify(type_env, operation, operation.right): + if self.with_expression(ctx, operation.right): changed = True - if self.with_expression(env, type_env, operation.left): + if unify(ctx, operation, operation.left): changed = True - if self.with_expression(env, type_env, operation.right): + if unify(ctx, operation, operation.right): changed = True return changed def with_function_call( - self, env: Environment, type_env: TypeEnvironment, function_call: parse.FunctionCall + self, ctx: Context, function_call: parse.FunctionCall ) -> bool: changed = False if isinstance(function_call.source.type, parse.UnknownTypeUsage): @@ -180,14 +224,14 @@ class TypeChecker: return_type=parse.UnknownTypeUsage(), ) changed = True - if self.with_expression(env, type_env, function_call.source): + if self.with_expression(ctx, function_call.source): changed = True for argument in function_call.arguments: - if self.with_expression(env, type_env, argument): + if self.with_expression(ctx, argument): changed = True return_type, return_changed = type_compare( - type_env, function_call.type, function_call.source.type.return_type + ctx, function_call.type, function_call.source.type.return_type ) function_call.type = return_type function_call.source.type.return_type = return_type @@ -198,7 +242,7 @@ class TypeChecker: function_call.arguments, function_call.source.type.arguments ): argument_out_type, argument_changed = type_compare( - type_env, argument.type, function_call.source.type.return_type + ctx, argument.type, function_call.source.type.return_type ) argument.type = argument_out_type function_call.source.type.return_type = argument_out_type @@ -206,13 +250,13 @@ class TypeChecker: changed = True return changed - def with_literal_float(self, env: Environment, type_env: TypeEnvironment, literal_float: parse.LiteralFloat) -> bool: + def with_literal_float(self, ctx: Context, literal_float: parse.LiteralFloat) -> bool: floats = ["F32", "F64", "F128"] if not isinstance(literal_float.type, parse.UnknownTypeUsage): assert literal_float.type.name in floats, f"{literal_float.type}" return False - def with_literal_int(self, env: Environment, type_env: TypeEnvironment, literal_int: parse.LiteralInt) -> bool: + def with_literal_int(self, ctx: Context, literal_int: parse.LiteralInt) -> bool: ints = ["I8", "I16", "I32", "I64", "I128", "U8", "U16", "U32", "U64", "U128"] if not isinstance(literal_int.type, parse.UnknownTypeUsage): assert literal_int.type.name in ints, f"{literal_int.type}" diff --git a/examples/math/main.bl b/examples/math/main.bl index 731a848..4e39b96 100644 --- a/examples/math/main.bl +++ b/examples/math/main.bl @@ -1,6 +1,6 @@ // adds a and b, but also 4 for some reason fn add(a: I32, b: I32): I32 { - let foo = 4; + let foo = 4; // because I feel like it let test_float: F32 = { 10.2 }; @@ -11,6 +11,18 @@ fn subtract(a: I32, b: I32): I32 { a - b } +fn return_type_test(a: F64): F64 { + return a * 2.0; +} + +fn i_hate_this(a: F64): F64 { + return { + return { + return a; + }; + }; +} + fn main(): I32 { add(4, subtract(5, 2)) }