diff --git a/boring/main.py b/boring/main.py index dae2295..c480fff 100644 --- a/boring/main.py +++ b/boring/main.py @@ -2,6 +2,7 @@ import sys from typing import List from boring.parse import boring_parser, TreeToBoring, pretty_print from boring.type_checking import TypeChecker +from boring import typedefs if __name__ == "__main__": with open(sys.argv[1]) as f: @@ -10,7 +11,7 @@ if __name__ == "__main__": result = TreeToBoring().transform(tree) # pretty_print(result) type_checker = TypeChecker() - while type_checker.with_module({}, result): + while type_checker.with_module({}, typedefs.builtins, result): print("loop") # type_checker.with_module({}, result) pretty_print(result) diff --git a/boring/type_checking.py b/boring/type_checking.py index ad8b554..b88b24f 100644 --- a/boring/type_checking.py +++ b/boring/type_checking.py @@ -2,21 +2,22 @@ from dataclasses import dataclass from typing import List, Dict, Optional, Union -from boring import parse +from boring import parse, typedefs Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration] Environment = Dict[str, Identified] +TypeEnvironment = Dict[str, typedefs.TypeDef] -def unify(first, second) -> bool: - result, changed = type_compare(first.type, second.type) +def unify(type_env: TypeEnvironment, first, second) -> bool: + result, changed = type_compare(type_env, first.type, second.type) first.type = result second.type = result return changed -def type_compare(first, second) -> (parse.TypeUsage, bool): +def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage, bool): print(first, second) if isinstance(first, parse.UnknownTypeUsage): if not isinstance(second, parse.UnknownTypeUsage): @@ -31,18 +32,20 @@ def type_compare(first, second) -> (parse.TypeUsage, bool): second, parse.DataTypeUsage ): assert second == first + assert first.name in type_env + assert second.name in type_env return first, False elif isinstance(first, parse.FunctionTypeUsage) and isinstance( second, parse.FunctionTypeUsage ): return_type, changed = type_compare( - first.return_type, second.return_type + type_env, first.return_type, second.return_type ) arguments = [] assert len(first.arguments) == len(second.arguments) for first_arg, second_arg in zip(first.arguments, second.arguments): argument_type, argument_changed = type_compare( - first_arg, second_arg + type_env, first_arg, second_arg ) arguments.append(argument_type) if argument_changed: @@ -53,31 +56,31 @@ def type_compare(first, second) -> (parse.TypeUsage, bool): class TypeChecker: - def with_module(self, env: Environment, module: parse.Module) -> bool: + def with_module(self, env: Environment, type_env: TypeEnvironment, module: parse.Module) -> bool: for function in module.functions: env[function.name] = function found = False for function in module.functions: - if self.with_function(env, function): + if self.with_function(env, type_env, function): found = True return found - def with_function(self, env: Environment, function: parse.Function) -> bool: + def with_function(self, env: Environment, type_env: TypeEnvironment, function: parse.Function) -> bool: function_env = env.copy() for argument in function.arguments: function_env[argument.name] = argument assert isinstance(function.type, parse.FunctionTypeUsage) - type, changed = type_compare(function.block.type, function.type.return_type) + type, changed = type_compare(type_env, function.block.type, function.type.return_type) function.block.type = type function.type.return_type = type - if self.with_block(function_env, function.block): + if self.with_block(function_env, type_env, function.block): changed = True return changed # Skip variable VariableDeclaration - def with_block(self, env: Environment, block: parse.Block) -> bool: + def with_block(self, env: Environment, type_env: TypeEnvironment, block: parse.Block) -> bool: block_env = env.copy() # if parent is void, must be statement # if parent is type, must be expression @@ -94,74 +97,73 @@ class TypeChecker: name=parse.Identifier(name=parse.UNIT_TYPE) ) elif isinstance(final, parse.Expression): - if unify(final, block): + if unify(type_env, final, block): changed = True for statement in block.statements: - if self.with_statement(block_env, statement): + if self.with_statement(block_env, type_env, statement): changed = True return changed - def with_statement(self, env: Environment, statement: parse.Statement) -> bool: + def with_statement(self, env: Environment, type_env: TypeEnvironment, statement: parse.Statement) -> bool: if isinstance(statement, parse.LetStatement): - return self.with_let_statement(env, statement) + return self.with_let_statement(env, type_env, statement) elif isinstance(statement, parse.Expression): # expression - return self.with_expression(env, statement) + return self.with_expression(env, type_env, statement) else: assert False def with_let_statement( - self, env: Environment, let_statement: parse.LetStatement + self, env: Environment, type_env: TypeEnvironment, let_statement: parse.LetStatement ) -> bool: found = False env[let_statement.variable_name] = let_statement - changed = unify(let_statement, let_statement.expression) - if self.with_expression(env, let_statement.expression): + changed = unify(type_env, let_statement, let_statement.expression) + if self.with_expression(env, type_env, let_statement.expression): changed = True return changed - def with_expression(self, env: Environment, expression: parse.Expression) -> bool: + def with_expression(self, env: Environment, type_env: TypeEnvironment, expression: parse.Expression) -> bool: subexpression = expression.expression - changed = unify(subexpression, expression) + changed = unify(type_env, subexpression, expression) if isinstance(subexpression, parse.LiteralInt): - print(f"fooooo {expression.type}, {subexpression.type}") - if self.with_literal_int(env, subexpression): + if self.with_literal_int(env, type_env, subexpression): changed = True return changed if isinstance(subexpression, parse.FunctionCall): - if self.with_function_call(env, subexpression): + if self.with_function_call(env, type_env, subexpression): changed = True return changed if isinstance(subexpression, parse.VariableUsage): - if self.with_variable_usage(env, subexpression): + if self.with_variable_usage(env, type_env, subexpression): changed = True return changed if isinstance(subexpression, parse.Operation): - if self.with_operation(env, subexpression): + if self.with_operation(env, type_env, subexpression): changed = True return changed assert False def with_variable_usage( - self, env: Environment, variable_usage: parse.VariableUsage + self, env: Environment, type_env: TypeEnvironment, variable_usage: parse.VariableUsage ) -> bool: - return unify(variable_usage, env[variable_usage.name]) + return unify(type_env, variable_usage, env[variable_usage.name]) - def with_operation(self, env: Environment, operation: parse.Operation) -> bool: + def with_operation(self, env: Environment, type_env: TypeEnvironment, operation: parse.Operation) -> bool: changed = False - if unify(operation, operation.left): + if unify(type_env, operation, operation.left): changed = True - if unify(operation, operation.right): + if unify(type_env, operation, operation.right): changed = True - if self.with_expression(env, operation.left): + if self.with_expression(env, type_env, operation.left): changed = True - if self.with_expression(env, operation.right): + if self.with_expression(env, type_env, operation.right): changed = True return changed def with_function_call( - self, env: Environment, function_call: parse.FunctionCall + self, env: Environment, type_env: TypeEnvironment, function_call: parse.FunctionCall ) -> bool: changed = False if isinstance(function_call.source.type, parse.UnknownTypeUsage): @@ -170,14 +172,14 @@ class TypeChecker: return_type=parse.UnknownTypeUsage(), ) changed = True - if self.with_expression(env, function_call.source): + if self.with_expression(env, type_env, function_call.source): changed = True for argument in function_call.arguments: - if self.with_expression(env, argument): + if self.with_expression(env, type_env, argument): changed = True return_type, return_changed = type_compare( - function_call.type, function_call.source.type.return_type + type_env, function_call.type, function_call.source.type.return_type ) function_call.type = return_type function_call.source.type.return_type = return_type @@ -188,7 +190,7 @@ class TypeChecker: function_call.arguments, function_call.source.type.arguments ): argument_out_type, argument_changed = type_compare( - argument.type, function_call.source.type.return_type + type_env, argument.type, function_call.source.type.return_type ) argument.type = argument_out_type function_call.source.type.return_type = argument_out_type @@ -196,7 +198,7 @@ class TypeChecker: changed = True return changed - def with_literal_int(self, env: Environment, literal_int: parse.LiteralInt) -> bool: + def with_literal_int(self, env: Environment, type_env: TypeEnvironment, literal_int: parse.LiteralInt) -> bool: ints = [ parse.DataTypeUsage(name=name) for name in ["I8", "I16", "I32", "I64", "I128"] diff --git a/boring/typedefs.py b/boring/typedefs.py new file mode 100644 index 0000000..49371a9 --- /dev/null +++ b/boring/typedefs.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass, field +import enum +from typing import List, Dict, Optional, Union + + +class IntBitness(enum.Enum): + X8 = 'X8' + X16 = 'X16' + X32 = 'X32' + X64 = 'X64' + X128 = 'X128' + + +class Signedness(enum.Enum): + Signed = 'Signed' + Unsigned = 'Unsigned' + + +class FloatBitness(enum.Enum): + X32 = 'X32' + X64 = 'X64' + X128 = 'X128' + + +@dataclass +class IntTypeDef: + signedness: Signedness + bitness: IntBitness + + +@dataclass +class FloatTypeDef: + bitness: FloatBitness + + +@dataclass +class FunctionTypeDef: + arguments: List["TypeDef"] + return_type: "TypeDef" + + +@dataclass +class UnitTypeDef: + pass + + +@dataclass +class NeverTypeDef: + pass + + +TypeDef = Union[IntTypeDef, FloatTypeDef, FunctionTypeDef, UnitTypeDef, NeverTypeDef] + + +builtins: Dict[str, TypeDef] = { + 'U8': IntTypeDef(Signedness.Unsigned, IntBitness.X8), + 'U16': IntTypeDef(Signedness.Unsigned, IntBitness.X16), + 'U32': IntTypeDef(Signedness.Unsigned, IntBitness.X32), + 'U64': IntTypeDef(Signedness.Unsigned, IntBitness.X64), + 'U128': IntTypeDef(Signedness.Unsigned, IntBitness.X128), + + 'I8': IntTypeDef(Signedness.Signed, IntBitness.X8), + 'I16': IntTypeDef(Signedness.Signed, IntBitness.X16), + 'I32': IntTypeDef(Signedness.Signed, IntBitness.X32), + 'I64': IntTypeDef(Signedness.Signed, IntBitness.X64), + 'I128': IntTypeDef(Signedness.Signed, IntBitness.X128), + + 'F32': FloatTypeDef(FloatBitness.X32), + 'F64': FloatTypeDef(FloatBitness.X64), + 'F128': FloatTypeDef(FloatBitness.X128), + + '()': UnitTypeDef(), + '!': NeverTypeDef(), +}