added return statement
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import sys
|
import sys
|
||||||
from typing import List
|
from typing import List
|
||||||
from boring.parse import boring_parser, TreeToBoring, pretty_print
|
from boring.parse import boring_parser, TreeToBoring, pretty_print
|
||||||
from boring.type_checking import TypeChecker
|
from boring.type_checking import TypeChecker, Context
|
||||||
from boring import typedefs
|
from boring import typedefs
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -11,7 +11,7 @@ if __name__ == "__main__":
|
|||||||
result = TreeToBoring().transform(tree)
|
result = TreeToBoring().transform(tree)
|
||||||
# pretty_print(result)
|
# pretty_print(result)
|
||||||
type_checker = TypeChecker()
|
type_checker = TypeChecker()
|
||||||
while type_checker.with_module({}, typedefs.builtins, result):
|
while type_checker.with_module(Context({}, typedefs.builtins, None), result):
|
||||||
print("loop")
|
print("loop")
|
||||||
# type_checker.with_module({}, result)
|
# type_checker.with_module({}, result)
|
||||||
pretty_print(result)
|
pretty_print(result)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ def pretty_print(clas, indent=0):
|
|||||||
|
|
||||||
|
|
||||||
UNIT_TYPE = "()"
|
UNIT_TYPE = "()"
|
||||||
|
NEVER_TYPE = "!"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -89,9 +90,15 @@ class VariableUsage:
|
|||||||
type: TypeUsage
|
type: TypeUsage
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReturnStatement:
|
||||||
|
source: "Expression"
|
||||||
|
type: TypeUsage
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Expression:
|
class Expression:
|
||||||
expression: Union[LiteralInt, LiteralFloat, FunctionCall, "Block", VariableUsage, Operation]
|
expression: Union[LiteralInt, LiteralFloat, FunctionCall, "Block", ReturnStatement, VariableUsage, Operation]
|
||||||
type: TypeUsage
|
type: TypeUsage
|
||||||
|
|
||||||
|
|
||||||
@@ -150,6 +157,8 @@ boring_grammar = r"""
|
|||||||
|
|
||||||
variable_usage : identifier
|
variable_usage : identifier
|
||||||
|
|
||||||
|
return_statement : "return" expression ";"
|
||||||
|
|
||||||
expression : add_expression
|
expression : add_expression
|
||||||
| sub_expression
|
| sub_expression
|
||||||
| factor
|
| factor
|
||||||
@@ -169,6 +178,7 @@ boring_grammar = r"""
|
|||||||
| "let" identifier ":" type_usage "=" expression ";"
|
| "let" identifier ":" type_usage "=" expression ";"
|
||||||
|
|
||||||
statement : let_statement
|
statement : let_statement
|
||||||
|
| return_statement
|
||||||
| expression
|
| expression
|
||||||
|
|
||||||
block : "{" (statement)* "}"
|
block : "{" (statement)* "}"
|
||||||
@@ -238,6 +248,10 @@ class TreeToBoring(Transformer):
|
|||||||
(variable,) = variable
|
(variable,) = variable
|
||||||
return VariableUsage(name=variable, type=UnknownTypeUsage())
|
return VariableUsage(name=variable, type=UnknownTypeUsage())
|
||||||
|
|
||||||
|
def return_statement(self, return_expression) -> ReturnStatement:
|
||||||
|
(return_expression,) = return_expression
|
||||||
|
return ReturnStatement(source=return_expression, type=DataTypeUsage(name=NEVER_TYPE))
|
||||||
|
|
||||||
def function_call(self, call) -> FunctionCall:
|
def function_call(self, call) -> FunctionCall:
|
||||||
return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage())
|
return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage())
|
||||||
|
|
||||||
|
|||||||
@@ -9,15 +9,24 @@ Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration
|
|||||||
Environment = Dict[str, Identified]
|
Environment = Dict[str, Identified]
|
||||||
TypeEnvironment = Dict[str, typedefs.TypeDef]
|
TypeEnvironment = Dict[str, typedefs.TypeDef]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Context:
|
||||||
|
environment: Environment
|
||||||
|
type_environment: TypeEnvironment
|
||||||
|
current_function: Optional[parse.Function]
|
||||||
|
|
||||||
def unify(type_env: TypeEnvironment, first, second) -> bool:
|
def copy(self):
|
||||||
result, changed = type_compare(type_env, first.type, second.type)
|
return Context(self.environment.copy(), self.type_environment.copy(), self.current_function)
|
||||||
|
|
||||||
|
|
||||||
|
def unify(ctx: Context, first, second) -> bool:
|
||||||
|
result, changed = type_compare(ctx, first.type, second.type)
|
||||||
first.type = result
|
first.type = result
|
||||||
second.type = result
|
second.type = result
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
|
|
||||||
def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage, bool):
|
def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
|
||||||
print(first, second)
|
print(first, second)
|
||||||
if isinstance(first, parse.UnknownTypeUsage):
|
if isinstance(first, parse.UnknownTypeUsage):
|
||||||
if not isinstance(second, parse.UnknownTypeUsage):
|
if not isinstance(second, parse.UnknownTypeUsage):
|
||||||
@@ -32,20 +41,20 @@ def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage,
|
|||||||
second, parse.DataTypeUsage
|
second, parse.DataTypeUsage
|
||||||
):
|
):
|
||||||
assert second == first
|
assert second == first
|
||||||
assert first.name in type_env
|
assert first.name in ctx.type_environment
|
||||||
assert second.name in type_env
|
assert second.name in ctx.type_environment
|
||||||
return first, False
|
return first, False
|
||||||
elif isinstance(first, parse.FunctionTypeUsage) and isinstance(
|
elif isinstance(first, parse.FunctionTypeUsage) and isinstance(
|
||||||
second, parse.FunctionTypeUsage
|
second, parse.FunctionTypeUsage
|
||||||
):
|
):
|
||||||
return_type, changed = type_compare(
|
return_type, changed = type_compare(
|
||||||
type_env, first.return_type, second.return_type
|
ctx, first.return_type, second.return_type
|
||||||
)
|
)
|
||||||
arguments = []
|
arguments = []
|
||||||
assert len(first.arguments) == len(second.arguments)
|
assert len(first.arguments) == len(second.arguments)
|
||||||
for first_arg, second_arg in zip(first.arguments, second.arguments):
|
for first_arg, second_arg in zip(first.arguments, second.arguments):
|
||||||
argument_type, argument_changed = type_compare(
|
argument_type, argument_changed = type_compare(
|
||||||
type_env, first_arg, second_arg
|
ctx, first_arg, second_arg
|
||||||
)
|
)
|
||||||
arguments.append(argument_type)
|
arguments.append(argument_type)
|
||||||
if argument_changed:
|
if argument_changed:
|
||||||
@@ -56,122 +65,157 @@ def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage,
|
|||||||
|
|
||||||
|
|
||||||
class TypeChecker:
|
class TypeChecker:
|
||||||
def with_module(self, env: Environment, type_env: TypeEnvironment, module: parse.Module) -> bool:
|
def with_module(self, ctx: Context, module: parse.Module) -> bool:
|
||||||
for function in module.functions:
|
for function in module.functions:
|
||||||
env[function.name] = function
|
ctx.environment[function.name] = function
|
||||||
found = False
|
changed = False
|
||||||
for function in module.functions:
|
for function in module.functions:
|
||||||
if self.with_function(env, type_env, function):
|
if self.with_function(ctx, function):
|
||||||
found = True
|
changed = True
|
||||||
return found
|
return changed
|
||||||
|
|
||||||
def with_function(self, env: Environment, type_env: TypeEnvironment, function: parse.Function) -> bool:
|
def with_function(self, ctx: Context, function: parse.Function) -> bool:
|
||||||
function_env = env.copy()
|
function_ctx = ctx.copy()
|
||||||
|
function_ctx.current_function = function
|
||||||
for argument in function.arguments:
|
for argument in function.arguments:
|
||||||
function_env[argument.name] = argument
|
function_ctx.environment[argument.name] = argument
|
||||||
assert isinstance(function.type, parse.FunctionTypeUsage)
|
assert isinstance(function.type, parse.FunctionTypeUsage)
|
||||||
|
|
||||||
type, changed = type_compare(type_env, function.block.type, function.type.return_type)
|
changed = self.with_block(function_ctx, function.block)
|
||||||
|
|
||||||
|
if not (isinstance(function.block.type, parse.DataTypeUsage) and function.block.type.name == parse.NEVER_TYPE):
|
||||||
|
type, compare_changed = type_compare(function_ctx, function.block.type, function.type.return_type)
|
||||||
function.block.type = type
|
function.block.type = type
|
||||||
function.type.return_type = type
|
function.type.return_type = type
|
||||||
if self.with_block(function_env, type_env, function.block):
|
if compare_changed is True:
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
# Skip variable VariableDeclaration
|
# Skip variable VariableDeclaration
|
||||||
|
|
||||||
def with_block(self, env: Environment, type_env: TypeEnvironment, block: parse.Block) -> bool:
|
def with_block(self, ctx: Context, block: parse.Block) -> bool:
|
||||||
block_env = env.copy()
|
block_ctx = ctx.copy()
|
||||||
# if parent is void, must be statement
|
# if parent is void, must be statement
|
||||||
# if parent is type, must be expression
|
# if parent is type, must be expression
|
||||||
changed = False
|
changed = False
|
||||||
|
for statement in block.statements:
|
||||||
|
if self.with_statement(block_ctx, statement):
|
||||||
|
changed = True
|
||||||
final = block.statements[-1]
|
final = block.statements[-1]
|
||||||
if isinstance(final, parse.LetStatement):
|
if isinstance(final, parse.LetStatement):
|
||||||
if isinstance(block.type, parse.UnknownTypeUsage):
|
if isinstance(block.type, parse.UnknownTypeUsage):
|
||||||
found = True
|
changed = True
|
||||||
block.type = parse.DataTypeUsage(
|
block.type = parse.DataTypeUsage(
|
||||||
name=parse.Identifier(name=parse.UNIT_TYPE)
|
name=parse.UNIT_TYPE
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert block.type == parse.DataTypeUsage(
|
assert block.type == parse.DataTypeUsage(
|
||||||
name=parse.Identifier(name=parse.UNIT_TYPE)
|
name=parse.UNIT_TYPE
|
||||||
|
)
|
||||||
|
elif isinstance(final, parse.ReturnStatement):
|
||||||
|
if isinstance(block.type, parse.UnknownTypeUsage):
|
||||||
|
changed = True
|
||||||
|
block.type = parse.DataTypeUsage(
|
||||||
|
name=parse.NEVER_TYPE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert block.type == parse.DataTypeUsage(
|
||||||
|
name=parse.NEVER_TYPE
|
||||||
)
|
)
|
||||||
elif isinstance(final, parse.Expression):
|
elif isinstance(final, parse.Expression):
|
||||||
if unify(type_env, final, block):
|
if unify(block_ctx, final, block):
|
||||||
changed = True
|
|
||||||
|
|
||||||
for statement in block.statements:
|
|
||||||
if self.with_statement(block_env, type_env, statement):
|
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
def with_statement(self, env: Environment, type_env: TypeEnvironment, statement: parse.Statement) -> bool:
|
def with_statement(self, ctx: Context, statement: parse.Statement) -> bool:
|
||||||
|
if isinstance(statement, parse.ReturnStatement):
|
||||||
|
return self.with_return_statement(ctx, statement)
|
||||||
if isinstance(statement, parse.LetStatement):
|
if isinstance(statement, parse.LetStatement):
|
||||||
return self.with_let_statement(env, type_env, statement)
|
return self.with_let_statement(ctx, statement)
|
||||||
elif isinstance(statement, parse.Expression): # expression
|
elif isinstance(statement, parse.Expression): # expression
|
||||||
return self.with_expression(env, type_env, statement)
|
return self.with_expression(ctx, statement)
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
def with_let_statement(
|
def with_let_statement(
|
||||||
self, env: Environment, type_env: TypeEnvironment, let_statement: parse.LetStatement
|
self, ctx: Context, let_statement: parse.LetStatement
|
||||||
) -> bool:
|
) -> bool:
|
||||||
found = False
|
changed = False
|
||||||
env[let_statement.variable_name] = let_statement
|
ctx.environment[let_statement.variable_name] = let_statement
|
||||||
changed = unify(type_env, let_statement, let_statement.expression)
|
if self.with_expression(ctx, let_statement.expression):
|
||||||
if self.with_expression(env, type_env, let_statement.expression):
|
changed = True
|
||||||
|
changed = unify(ctx, let_statement, let_statement.expression)
|
||||||
|
return changed
|
||||||
|
|
||||||
|
def with_return_statement(
|
||||||
|
self, ctx: Context, return_statement: parse.ReturnStatement
|
||||||
|
) -> bool:
|
||||||
|
changed = self.with_expression(ctx, return_statement.source)
|
||||||
|
|
||||||
|
# Doesn't match on an unreachable return
|
||||||
|
if not (isinstance(return_statement.source.type, parse.DataTypeUsage) and return_statement.source.type.name == parse.NEVER_TYPE):
|
||||||
|
type, compare_changed = type_compare(ctx, return_statement.source.type, ctx.current_function.type.return_type)
|
||||||
|
return_statement.source.type = type
|
||||||
|
ctx.current_function.type.return_type = type
|
||||||
|
if compare_changed is True:
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
def with_expression(self, env: Environment, type_env: TypeEnvironment, expression: parse.Expression) -> bool:
|
def with_expression(self, ctx: Context, expression: parse.Expression) -> bool:
|
||||||
subexpression = expression.expression
|
subexpression = expression.expression
|
||||||
changed = unify(type_env, subexpression, expression)
|
changed = False
|
||||||
|
|
||||||
if isinstance(subexpression, parse.LiteralInt):
|
if isinstance(subexpression, parse.LiteralInt):
|
||||||
if self.with_literal_int(env, type_env, subexpression):
|
changed = self.with_literal_int(ctx, subexpression)
|
||||||
|
if unify(ctx, subexpression, expression):
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
if isinstance(subexpression, parse.LiteralFloat):
|
if isinstance(subexpression, parse.LiteralFloat):
|
||||||
if self.with_literal_float(env, type_env, subexpression):
|
changed = self.with_literal_float(ctx, subexpression)
|
||||||
|
if unify(ctx, subexpression, expression):
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
if isinstance(subexpression, parse.FunctionCall):
|
if isinstance(subexpression, parse.FunctionCall):
|
||||||
if self.with_function_call(env, type_env, subexpression):
|
changed = self.with_function_call(ctx, subexpression)
|
||||||
|
if unify(ctx, subexpression, expression):
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
if isinstance(subexpression, parse.Block):
|
if isinstance(subexpression, parse.Block):
|
||||||
if self.with_block(env, type_env, subexpression):
|
changed = self.with_block(ctx, subexpression)
|
||||||
|
if unify(ctx, subexpression, expression):
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
if isinstance(subexpression, parse.VariableUsage):
|
if isinstance(subexpression, parse.VariableUsage):
|
||||||
if self.with_variable_usage(env, type_env, subexpression):
|
changed = self.with_variable_usage(ctx, subexpression)
|
||||||
|
if unify(ctx, subexpression, expression):
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
if isinstance(subexpression, parse.Operation):
|
if isinstance(subexpression, parse.Operation):
|
||||||
if self.with_operation(env, type_env, subexpression):
|
changed = self.with_operation(ctx, subexpression)
|
||||||
|
if unify(ctx, subexpression, expression):
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
def with_variable_usage(
|
def with_variable_usage(
|
||||||
self, env: Environment, type_env: TypeEnvironment, variable_usage: parse.VariableUsage
|
self, ctx: Context, variable_usage: parse.VariableUsage
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return unify(type_env, variable_usage, env[variable_usage.name])
|
return unify(ctx, variable_usage, ctx.environment[variable_usage.name])
|
||||||
|
|
||||||
def with_operation(self, env: Environment, type_env: TypeEnvironment, operation: parse.Operation) -> bool:
|
def with_operation(self, ctx: Context, operation: parse.Operation) -> bool:
|
||||||
changed = False
|
changed = False
|
||||||
if unify(type_env, operation, operation.left):
|
if self.with_expression(ctx, operation.left):
|
||||||
changed = True
|
changed = True
|
||||||
if unify(type_env, operation, operation.right):
|
if self.with_expression(ctx, operation.right):
|
||||||
changed = True
|
changed = True
|
||||||
if self.with_expression(env, type_env, operation.left):
|
if unify(ctx, operation, operation.left):
|
||||||
changed = True
|
changed = True
|
||||||
if self.with_expression(env, type_env, operation.right):
|
if unify(ctx, operation, operation.right):
|
||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
def with_function_call(
|
def with_function_call(
|
||||||
self, env: Environment, type_env: TypeEnvironment, function_call: parse.FunctionCall
|
self, ctx: Context, function_call: parse.FunctionCall
|
||||||
) -> bool:
|
) -> bool:
|
||||||
changed = False
|
changed = False
|
||||||
if isinstance(function_call.source.type, parse.UnknownTypeUsage):
|
if isinstance(function_call.source.type, parse.UnknownTypeUsage):
|
||||||
@@ -180,14 +224,14 @@ class TypeChecker:
|
|||||||
return_type=parse.UnknownTypeUsage(),
|
return_type=parse.UnknownTypeUsage(),
|
||||||
)
|
)
|
||||||
changed = True
|
changed = True
|
||||||
if self.with_expression(env, type_env, function_call.source):
|
if self.with_expression(ctx, function_call.source):
|
||||||
changed = True
|
changed = True
|
||||||
for argument in function_call.arguments:
|
for argument in function_call.arguments:
|
||||||
if self.with_expression(env, type_env, argument):
|
if self.with_expression(ctx, argument):
|
||||||
changed = True
|
changed = True
|
||||||
|
|
||||||
return_type, return_changed = type_compare(
|
return_type, return_changed = type_compare(
|
||||||
type_env, function_call.type, function_call.source.type.return_type
|
ctx, function_call.type, function_call.source.type.return_type
|
||||||
)
|
)
|
||||||
function_call.type = return_type
|
function_call.type = return_type
|
||||||
function_call.source.type.return_type = return_type
|
function_call.source.type.return_type = return_type
|
||||||
@@ -198,7 +242,7 @@ class TypeChecker:
|
|||||||
function_call.arguments, function_call.source.type.arguments
|
function_call.arguments, function_call.source.type.arguments
|
||||||
):
|
):
|
||||||
argument_out_type, argument_changed = type_compare(
|
argument_out_type, argument_changed = type_compare(
|
||||||
type_env, argument.type, function_call.source.type.return_type
|
ctx, argument.type, function_call.source.type.return_type
|
||||||
)
|
)
|
||||||
argument.type = argument_out_type
|
argument.type = argument_out_type
|
||||||
function_call.source.type.return_type = argument_out_type
|
function_call.source.type.return_type = argument_out_type
|
||||||
@@ -206,13 +250,13 @@ class TypeChecker:
|
|||||||
changed = True
|
changed = True
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
def with_literal_float(self, env: Environment, type_env: TypeEnvironment, literal_float: parse.LiteralFloat) -> bool:
|
def with_literal_float(self, ctx: Context, literal_float: parse.LiteralFloat) -> bool:
|
||||||
floats = ["F32", "F64", "F128"]
|
floats = ["F32", "F64", "F128"]
|
||||||
if not isinstance(literal_float.type, parse.UnknownTypeUsage):
|
if not isinstance(literal_float.type, parse.UnknownTypeUsage):
|
||||||
assert literal_float.type.name in floats, f"{literal_float.type}"
|
assert literal_float.type.name in floats, f"{literal_float.type}"
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def with_literal_int(self, env: Environment, type_env: TypeEnvironment, literal_int: parse.LiteralInt) -> bool:
|
def with_literal_int(self, ctx: Context, literal_int: parse.LiteralInt) -> bool:
|
||||||
ints = ["I8", "I16", "I32", "I64", "I128", "U8", "U16", "U32", "U64", "U128"]
|
ints = ["I8", "I16", "I32", "I64", "I128", "U8", "U16", "U32", "U64", "U128"]
|
||||||
if not isinstance(literal_int.type, parse.UnknownTypeUsage):
|
if not isinstance(literal_int.type, parse.UnknownTypeUsage):
|
||||||
assert literal_int.type.name in ints, f"{literal_int.type}"
|
assert literal_int.type.name in ints, f"{literal_int.type}"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// adds a and b, but also 4 for some reason
|
// adds a and b, but also 4 for some reason
|
||||||
fn add(a: I32, b: I32): I32 {
|
fn add(a: I32, b: I32): I32 {
|
||||||
let foo = 4;
|
let foo = 4; // because I feel like it
|
||||||
let test_float: F32 = {
|
let test_float: F32 = {
|
||||||
10.2
|
10.2
|
||||||
};
|
};
|
||||||
@@ -11,6 +11,18 @@ fn subtract(a: I32, b: I32): I32 {
|
|||||||
a - b
|
a - b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn return_type_test(a: F64): F64 {
|
||||||
|
return a * 2.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn i_hate_this(a: F64): F64 {
|
||||||
|
return {
|
||||||
|
return {
|
||||||
|
return a;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
fn main(): I32 {
|
fn main(): I32 {
|
||||||
add(4, subtract(5, 2))
|
add(4, subtract(5, 2))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user