added typedefs and type_env

This commit is contained in:
Andrew Segavac
2021-05-29 10:33:14 -06:00
parent f5fc6643fb
commit b8769f43e3
3 changed files with 118 additions and 41 deletions

View File

@@ -2,6 +2,7 @@ import sys
from typing import List
from boring.parse import boring_parser, TreeToBoring, pretty_print
from boring.type_checking import TypeChecker
from boring import typedefs
if __name__ == "__main__":
with open(sys.argv[1]) as f:
@@ -10,7 +11,7 @@ if __name__ == "__main__":
result = TreeToBoring().transform(tree)
# pretty_print(result)
type_checker = TypeChecker()
while type_checker.with_module({}, result):
while type_checker.with_module({}, typedefs.builtins, result):
print("loop")
# type_checker.with_module({}, result)
pretty_print(result)

View File

@@ -2,21 +2,22 @@ from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from boring import parse
from boring import parse, typedefs
Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration]
Environment = Dict[str, Identified]
TypeEnvironment = Dict[str, typedefs.TypeDef]
def unify(first, second) -> bool:
result, changed = type_compare(first.type, second.type)
def unify(type_env: TypeEnvironment, first, second) -> bool:
result, changed = type_compare(type_env, first.type, second.type)
first.type = result
second.type = result
return changed
def type_compare(first, second) -> (parse.TypeUsage, bool):
def type_compare(type_env: TypeEnvironment, first, second) -> (parse.TypeUsage, bool):
print(first, second)
if isinstance(first, parse.UnknownTypeUsage):
if not isinstance(second, parse.UnknownTypeUsage):
@@ -31,18 +32,20 @@ def type_compare(first, second) -> (parse.TypeUsage, bool):
second, parse.DataTypeUsage
):
assert second == first
assert first.name in type_env
assert second.name in type_env
return first, False
elif isinstance(first, parse.FunctionTypeUsage) and isinstance(
second, parse.FunctionTypeUsage
):
return_type, changed = type_compare(
first.return_type, second.return_type
type_env, first.return_type, second.return_type
)
arguments = []
assert len(first.arguments) == len(second.arguments)
for first_arg, second_arg in zip(first.arguments, second.arguments):
argument_type, argument_changed = type_compare(
first_arg, second_arg
type_env, first_arg, second_arg
)
arguments.append(argument_type)
if argument_changed:
@@ -53,31 +56,31 @@ def type_compare(first, second) -> (parse.TypeUsage, bool):
class TypeChecker:
def with_module(self, env: Environment, module: parse.Module) -> bool:
def with_module(self, env: Environment, type_env: TypeEnvironment, module: parse.Module) -> bool:
for function in module.functions:
env[function.name] = function
found = False
for function in module.functions:
if self.with_function(env, function):
if self.with_function(env, type_env, function):
found = True
return found
def with_function(self, env: Environment, function: parse.Function) -> bool:
def with_function(self, env: Environment, type_env: TypeEnvironment, function: parse.Function) -> bool:
function_env = env.copy()
for argument in function.arguments:
function_env[argument.name] = argument
assert isinstance(function.type, parse.FunctionTypeUsage)
type, changed = type_compare(function.block.type, function.type.return_type)
type, changed = type_compare(type_env, function.block.type, function.type.return_type)
function.block.type = type
function.type.return_type = type
if self.with_block(function_env, function.block):
if self.with_block(function_env, type_env, function.block):
changed = True
return changed
# Skip variable VariableDeclaration
def with_block(self, env: Environment, block: parse.Block) -> bool:
def with_block(self, env: Environment, type_env: TypeEnvironment, block: parse.Block) -> bool:
block_env = env.copy()
# if parent is void, must be statement
# if parent is type, must be expression
@@ -94,74 +97,73 @@ class TypeChecker:
name=parse.Identifier(name=parse.UNIT_TYPE)
)
elif isinstance(final, parse.Expression):
if unify(final, block):
if unify(type_env, final, block):
changed = True
for statement in block.statements:
if self.with_statement(block_env, statement):
if self.with_statement(block_env, type_env, statement):
changed = True
return changed
def with_statement(self, env: Environment, statement: parse.Statement) -> bool:
def with_statement(self, env: Environment, type_env: TypeEnvironment, statement: parse.Statement) -> bool:
if isinstance(statement, parse.LetStatement):
return self.with_let_statement(env, statement)
return self.with_let_statement(env, type_env, statement)
elif isinstance(statement, parse.Expression): # expression
return self.with_expression(env, statement)
return self.with_expression(env, type_env, statement)
else:
assert False
def with_let_statement(
self, env: Environment, let_statement: parse.LetStatement
self, env: Environment, type_env: TypeEnvironment, let_statement: parse.LetStatement
) -> bool:
found = False
env[let_statement.variable_name] = let_statement
changed = unify(let_statement, let_statement.expression)
if self.with_expression(env, let_statement.expression):
changed = unify(type_env, let_statement, let_statement.expression)
if self.with_expression(env, type_env, let_statement.expression):
changed = True
return changed
def with_expression(self, env: Environment, expression: parse.Expression) -> bool:
def with_expression(self, env: Environment, type_env: TypeEnvironment, expression: parse.Expression) -> bool:
subexpression = expression.expression
changed = unify(subexpression, expression)
changed = unify(type_env, subexpression, expression)
if isinstance(subexpression, parse.LiteralInt):
print(f"fooooo {expression.type}, {subexpression.type}")
if self.with_literal_int(env, subexpression):
if self.with_literal_int(env, type_env, subexpression):
changed = True
return changed
if isinstance(subexpression, parse.FunctionCall):
if self.with_function_call(env, subexpression):
if self.with_function_call(env, type_env, subexpression):
changed = True
return changed
if isinstance(subexpression, parse.VariableUsage):
if self.with_variable_usage(env, subexpression):
if self.with_variable_usage(env, type_env, subexpression):
changed = True
return changed
if isinstance(subexpression, parse.Operation):
if self.with_operation(env, subexpression):
if self.with_operation(env, type_env, subexpression):
changed = True
return changed
assert False
def with_variable_usage(
self, env: Environment, variable_usage: parse.VariableUsage
self, env: Environment, type_env: TypeEnvironment, variable_usage: parse.VariableUsage
) -> bool:
return unify(variable_usage, env[variable_usage.name])
return unify(type_env, variable_usage, env[variable_usage.name])
def with_operation(self, env: Environment, operation: parse.Operation) -> bool:
def with_operation(self, env: Environment, type_env: TypeEnvironment, operation: parse.Operation) -> bool:
changed = False
if unify(operation, operation.left):
if unify(type_env, operation, operation.left):
changed = True
if unify(operation, operation.right):
if unify(type_env, operation, operation.right):
changed = True
if self.with_expression(env, operation.left):
if self.with_expression(env, type_env, operation.left):
changed = True
if self.with_expression(env, operation.right):
if self.with_expression(env, type_env, operation.right):
changed = True
return changed
def with_function_call(
self, env: Environment, function_call: parse.FunctionCall
self, env: Environment, type_env: TypeEnvironment, function_call: parse.FunctionCall
) -> bool:
changed = False
if isinstance(function_call.source.type, parse.UnknownTypeUsage):
@@ -170,14 +172,14 @@ class TypeChecker:
return_type=parse.UnknownTypeUsage(),
)
changed = True
if self.with_expression(env, function_call.source):
if self.with_expression(env, type_env, function_call.source):
changed = True
for argument in function_call.arguments:
if self.with_expression(env, argument):
if self.with_expression(env, type_env, argument):
changed = True
return_type, return_changed = type_compare(
function_call.type, function_call.source.type.return_type
type_env, function_call.type, function_call.source.type.return_type
)
function_call.type = return_type
function_call.source.type.return_type = return_type
@@ -188,7 +190,7 @@ class TypeChecker:
function_call.arguments, function_call.source.type.arguments
):
argument_out_type, argument_changed = type_compare(
argument.type, function_call.source.type.return_type
type_env, argument.type, function_call.source.type.return_type
)
argument.type = argument_out_type
function_call.source.type.return_type = argument_out_type
@@ -196,7 +198,7 @@ class TypeChecker:
changed = True
return changed
def with_literal_int(self, env: Environment, literal_int: parse.LiteralInt) -> bool:
def with_literal_int(self, env: Environment, type_env: TypeEnvironment, literal_int: parse.LiteralInt) -> bool:
ints = [
parse.DataTypeUsage(name=name)
for name in ["I8", "I16", "I32", "I64", "I128"]

74
boring/typedefs.py Normal file
View File

@@ -0,0 +1,74 @@
from dataclasses import dataclass, field
import enum
from typing import List, Dict, Optional, Union
class IntBitness(enum.Enum):
X8 = 'X8'
X16 = 'X16'
X32 = 'X32'
X64 = 'X64'
X128 = 'X128'
class Signedness(enum.Enum):
Signed = 'Signed'
Unsigned = 'Unsigned'
class FloatBitness(enum.Enum):
X32 = 'X32'
X64 = 'X64'
X128 = 'X128'
@dataclass
class IntTypeDef:
signedness: Signedness
bitness: IntBitness
@dataclass
class FloatTypeDef:
bitness: FloatBitness
@dataclass
class FunctionTypeDef:
arguments: List["TypeDef"]
return_type: "TypeDef"
@dataclass
class UnitTypeDef:
pass
@dataclass
class NeverTypeDef:
pass
TypeDef = Union[IntTypeDef, FloatTypeDef, FunctionTypeDef, UnitTypeDef, NeverTypeDef]
builtins: Dict[str, TypeDef] = {
'U8': IntTypeDef(Signedness.Unsigned, IntBitness.X8),
'U16': IntTypeDef(Signedness.Unsigned, IntBitness.X16),
'U32': IntTypeDef(Signedness.Unsigned, IntBitness.X32),
'U64': IntTypeDef(Signedness.Unsigned, IntBitness.X64),
'U128': IntTypeDef(Signedness.Unsigned, IntBitness.X128),
'I8': IntTypeDef(Signedness.Signed, IntBitness.X8),
'I16': IntTypeDef(Signedness.Signed, IntBitness.X16),
'I32': IntTypeDef(Signedness.Signed, IntBitness.X32),
'I64': IntTypeDef(Signedness.Signed, IntBitness.X64),
'I128': IntTypeDef(Signedness.Signed, IntBitness.X128),
'F32': FloatTypeDef(FloatBitness.X32),
'F64': FloatTypeDef(FloatBitness.X64),
'F128': FloatTypeDef(FloatBitness.X128),
'()': UnitTypeDef(),
'!': NeverTypeDef(),
}