got type checking working for real
This commit is contained in:
@@ -11,7 +11,7 @@ if __name__ == "__main__":
|
||||
# pretty_print(result)
|
||||
type_checker = TypeChecker()
|
||||
while type_checker.with_module({}, result):
|
||||
print('loop')
|
||||
print("loop")
|
||||
# type_checker.with_module({}, result)
|
||||
pretty_print(result)
|
||||
# tctb = TypeCheckTableBuilder()
|
||||
|
||||
@@ -28,11 +28,6 @@ def pretty_print(clas, indent=0):
|
||||
UNIT_TYPE = "()"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Identifier:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTypeUsage:
|
||||
arguments: List[
|
||||
@@ -43,10 +38,15 @@ class FunctionTypeUsage:
|
||||
|
||||
@dataclass
|
||||
class DataTypeUsage:
|
||||
name: Identifier
|
||||
name: str
|
||||
|
||||
|
||||
TypeUsage = Union[FunctionTypeUsage, DataTypeUsage]
|
||||
@dataclass
|
||||
class UnknownTypeUsage:
|
||||
pass
|
||||
|
||||
|
||||
TypeUsage = Union[FunctionTypeUsage, DataTypeUsage, UnknownTypeUsage]
|
||||
|
||||
|
||||
class Operator(enum.Enum):
|
||||
@@ -59,14 +59,14 @@ class Operator(enum.Enum):
|
||||
@dataclass
|
||||
class LiteralInt:
|
||||
value: int
|
||||
type: Optional[TypeUsage]
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCall:
|
||||
source: "Expression"
|
||||
arguments: List["Expression"]
|
||||
type: Optional[TypeUsage]
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -74,25 +74,25 @@ class Operation:
|
||||
left: "Expression"
|
||||
op: Operator
|
||||
right: "Expression"
|
||||
type: Optional[TypeUsage]
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class VariableUsage:
|
||||
name: Identifier
|
||||
type: Optional[TypeUsage]
|
||||
name: str
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Expression:
|
||||
expression: Union[LiteralInt, FunctionCall, VariableUsage, Operation]
|
||||
type: Optional[TypeUsage]
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class LetStatement:
|
||||
variable_name: Identifier
|
||||
type: Optional[TypeUsage]
|
||||
variable_name: str
|
||||
type: TypeUsage
|
||||
expression: Expression
|
||||
|
||||
|
||||
@@ -102,18 +102,18 @@ Statement = Union[LetStatement, Expression]
|
||||
@dataclass
|
||||
class Block:
|
||||
statements: List[Statement]
|
||||
type: Optional[TypeUsage]
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class VariableDeclaration:
|
||||
name: Identifier
|
||||
name: str
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: Identifier
|
||||
name: str
|
||||
arguments: List[VariableDeclaration]
|
||||
block: Block
|
||||
return_type: TypeUsage
|
||||
@@ -191,6 +191,8 @@ boring_grammar = r"""
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
next_sub_id = 0
|
||||
|
||||
|
||||
class TreeToBoring(Transformer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -210,46 +212,46 @@ class TreeToBoring(Transformer):
|
||||
|
||||
def literal_int(self, n) -> LiteralInt:
|
||||
(n,) = n
|
||||
return LiteralInt(value=int(n), type=None)
|
||||
return LiteralInt(value=int(n), type=UnknownTypeUsage())
|
||||
|
||||
def identifier(self, i) -> Identifier:
|
||||
def identifier(self, i) -> str:
|
||||
(i,) = i
|
||||
return Identifier(name=str(i))
|
||||
return str(i)
|
||||
|
||||
def variable_usage(self, variable) -> VariableUsage:
|
||||
(variable,) = variable
|
||||
return VariableUsage(name=variable, type=None)
|
||||
return VariableUsage(name=variable, type=UnknownTypeUsage())
|
||||
|
||||
def function_call(self, call) -> FunctionCall:
|
||||
return FunctionCall(source=call[0], arguments=call[1:], type=None)
|
||||
return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage())
|
||||
|
||||
def add_expression(self, ae) -> Operation:
|
||||
return Operation(left=ae[0], op=ae[1], right=ae[2], type=None)
|
||||
return Operation(left=ae[0], op=ae[1], right=ae[2], type=UnknownTypeUsage())
|
||||
|
||||
def sub_expression(self, se) -> Operation:
|
||||
return Operation(left=se[0], op=se[1], right=se[2], type=None)
|
||||
return Operation(left=se[0], op=se[1], right=se[2], type=UnknownTypeUsage())
|
||||
|
||||
def mult_expression(self, se) -> Operation:
|
||||
return Operation(left=se[0], op=se[1], right=se[2], type=None)
|
||||
return Operation(left=se[0], op=se[1], right=se[2], type=UnknownTypeUsage())
|
||||
|
||||
def div_expression(self, se) -> Operation:
|
||||
return Operation(left=se[0], op=se[1], right=se[2], type=None)
|
||||
return Operation(left=se[0], op=se[1], right=se[2], type=UnknownTypeUsage())
|
||||
|
||||
def expression(self, exp) -> Expression:
|
||||
(exp,) = exp
|
||||
if isinstance(exp, Expression):
|
||||
return exp
|
||||
return Expression(expression=exp, type=None)
|
||||
return Expression(expression=exp, type=UnknownTypeUsage())
|
||||
|
||||
def factor(self, factor) -> Expression:
|
||||
(factor,) = factor
|
||||
if isinstance(factor, Expression):
|
||||
return factor
|
||||
return Expression(expression=factor, type=None)
|
||||
return Expression(expression=factor, type=UnknownTypeUsage())
|
||||
|
||||
def term(self, term) -> Expression:
|
||||
(term,) = term
|
||||
return Expression(expression=term, type=None)
|
||||
return Expression(expression=term, type=UnknownTypeUsage())
|
||||
|
||||
def let_statement(self, let_statement) -> LetStatement:
|
||||
if len(let_statement) == 3:
|
||||
@@ -262,7 +264,7 @@ class TreeToBoring(Transformer):
|
||||
(variable_name, expression) = let_statement
|
||||
return LetStatement(
|
||||
variable_name=variable_name,
|
||||
type=None,
|
||||
type=UnknownTypeUsage(),
|
||||
expression=expression,
|
||||
)
|
||||
|
||||
@@ -271,7 +273,7 @@ class TreeToBoring(Transformer):
|
||||
return statement
|
||||
|
||||
def block(self, block) -> Block:
|
||||
return Block(statements=block, type=None)
|
||||
return Block(statements=block, type=UnknownTypeUsage())
|
||||
|
||||
def data_type(self, name) -> TypeUsage:
|
||||
(name,) = name
|
||||
@@ -280,7 +282,7 @@ class TreeToBoring(Transformer):
|
||||
def function_type(self, type_usage) -> TypeUsage:
|
||||
return FunctionTypeUsage(
|
||||
arguments=type_usage,
|
||||
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
|
||||
return_type=DataTypeUsage(name=UNIT_TYPE),
|
||||
)
|
||||
|
||||
def function_type_with_return(self, type_usage) -> TypeUsage:
|
||||
@@ -298,11 +300,11 @@ class TreeToBoring(Transformer):
|
||||
return Function(
|
||||
name=function[0],
|
||||
arguments=function[1:-1],
|
||||
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
|
||||
return_type=DataTypeUsage(name=UNIT_TYPE),
|
||||
block=function[-1],
|
||||
type=FunctionTypeUsage(
|
||||
arguments=[arg.type for arg in function[1:-1]],
|
||||
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
|
||||
return_type=DataTypeUsage(name=UNIT_TYPE),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,317 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional, Union
|
||||
|
||||
|
||||
from boring import parse
|
||||
|
||||
|
||||
@dataclass
|
||||
class EqualityTypeComparison:
|
||||
from_id: Optional[str]
|
||||
to_id: str
|
||||
type_usage: Optional[parse.TypeUsage]
|
||||
# constraints: List[Constraint]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallTypeComparison:
|
||||
from_id: str
|
||||
to_id: str
|
||||
type_usage: Optional[parse.TypeUsage]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionArgumentTypeComparison:
|
||||
from_id: str
|
||||
to_id: str
|
||||
argument_id: int
|
||||
type_usage: Optional[parse.TypeUsage]
|
||||
|
||||
|
||||
TypeComparison = Union[
|
||||
EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison
|
||||
]
|
||||
Environment = Dict[str, str]
|
||||
|
||||
|
||||
class TypeCheckTableBuilder:
|
||||
def with_module(
|
||||
self, env: Environment, table: List[TypeComparison], module: parse.Module
|
||||
):
|
||||
for function in module.functions:
|
||||
env[function.name.name] = function.id
|
||||
type_usage = parse.FunctionTypeUsage(
|
||||
arguments=[arg.type for arg in function.arguments],
|
||||
return_type=function.return_type,
|
||||
)
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=None, to_id=function.id, type_usage=type_usage
|
||||
)
|
||||
)
|
||||
for function in module.functions:
|
||||
self.with_function(env, table, function)
|
||||
|
||||
def with_function(
|
||||
self, env: Environment, table: List[TypeComparison], function: parse.Function
|
||||
):
|
||||
function_env = env.copy()
|
||||
for argument in function.arguments:
|
||||
function_env[argument.name.name] = argument.id
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=None, to_id=argument.id, type_usage=argument.type
|
||||
)
|
||||
)
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=None, to_id=function.block.id, type_usage=function.return_type
|
||||
)
|
||||
)
|
||||
self.with_block(function_env, table, function.block)
|
||||
|
||||
# Skip variable VariableDeclaration
|
||||
|
||||
def with_block(
|
||||
self, env: Environment, table: List[TypeComparison], block: parse.Block
|
||||
):
|
||||
block_env = env.copy()
|
||||
# if parent is void, must be statement
|
||||
# if parent is type, must be expression
|
||||
if isinstance(block.statements[-1], parse.Expression):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=block.statements[-1].id, to_id=block.id, type_usage=None
|
||||
)
|
||||
)
|
||||
else:
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=None,
|
||||
to_id=block.id,
|
||||
type_usage=parse.DataTypeUsage(
|
||||
name=parse.Identifier(name=parse.UNIT_TYPE)
|
||||
),
|
||||
)
|
||||
)
|
||||
for statement in block.statements:
|
||||
print(statement)
|
||||
self.with_statement(block_env, table, statement)
|
||||
|
||||
def with_statement(
|
||||
self, env: Environment, table: List[TypeComparison], statement: parse.Statement
|
||||
):
|
||||
if isinstance(statement, parse.LetStatement):
|
||||
self.with_let_statement(env, table, statement)
|
||||
elif isinstance(statement, parse.Expression): # expression
|
||||
self.with_expression(env, table, statement)
|
||||
else:
|
||||
assert False
|
||||
|
||||
def with_let_statement(
|
||||
self,
|
||||
env: Environment,
|
||||
table: List[TypeComparison],
|
||||
let_statement: parse.LetStatement,
|
||||
):
|
||||
env[let_statement.variable_name.name] = let_statement.id
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=let_statement.expression.id,
|
||||
to_id=let_statement.id,
|
||||
type_usage=let_statement.type,
|
||||
)
|
||||
)
|
||||
self.with_expression(env, table, let_statement.expression)
|
||||
|
||||
def with_expression(
|
||||
self,
|
||||
env: Environment,
|
||||
table: List[TypeComparison],
|
||||
expression: parse.Expression,
|
||||
):
|
||||
if isinstance(expression.expression, parse.LiteralInt):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=expression.expression.id,
|
||||
to_id=expression.id,
|
||||
type_usage=None,
|
||||
)
|
||||
)
|
||||
self.with_literal_int(env, table, expression.expression)
|
||||
elif isinstance(expression.expression, parse.FunctionCall):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=expression.expression.id,
|
||||
to_id=expression.id,
|
||||
type_usage=None,
|
||||
)
|
||||
)
|
||||
self.with_function_call(env, table, expression.expression)
|
||||
elif isinstance(expression.expression, parse.VariableUsage):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=expression.expression.id,
|
||||
to_id=expression.id,
|
||||
type_usage=None,
|
||||
)
|
||||
)
|
||||
self.with_variable_usage(env, table, expression.expression)
|
||||
elif isinstance(expression.expression, parse.Operation):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=expression.expression.id,
|
||||
to_id=expression.id,
|
||||
type_usage=None,
|
||||
)
|
||||
)
|
||||
self.with_operation(env, table, expression.expression)
|
||||
else:
|
||||
assert False
|
||||
|
||||
def with_variable_usage(
|
||||
self,
|
||||
env: Environment,
|
||||
table: List[TypeComparison],
|
||||
variable_usage: parse.VariableUsage,
|
||||
):
|
||||
print("%%%%%%%%%%%%%%%%%%%%%")
|
||||
print(env[variable_usage.name.name])
|
||||
print(variable_usage.id)
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=env[variable_usage.name.name],
|
||||
to_id=variable_usage.id,
|
||||
type_usage=None,
|
||||
)
|
||||
)
|
||||
|
||||
def with_operation(
|
||||
self, env: Environment, table: List[TypeComparison], operation: parse.Operation
|
||||
):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=operation.left.id, to_id=operation.id, type_usage=None
|
||||
)
|
||||
)
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=operation.right.id, to_id=operation.id, type_usage=None
|
||||
)
|
||||
)
|
||||
self.with_expression(env, table, operation.left)
|
||||
self.with_expression(env, table, operation.right)
|
||||
|
||||
def with_function_call(
|
||||
self,
|
||||
env: Environment,
|
||||
table: List[TypeComparison],
|
||||
function_call: parse.FunctionCall,
|
||||
):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
to_id=function_call.id, from_id=function_call.source.id, type_usage=None
|
||||
)
|
||||
)
|
||||
self.with_expression(env, table, function_call.source)
|
||||
|
||||
for i, argument in enumerate(function_call.arguments):
|
||||
# table.append(
|
||||
# FunctionArgumentTypeComparison(from_id=env[function_call.name.name], to_id=function_call.id, type_usage=None)
|
||||
# )
|
||||
# FunctionArgumentTypeComparison
|
||||
self.with_expression(env, table, argument)
|
||||
|
||||
def with_literal_int(
|
||||
self,
|
||||
env: Environment,
|
||||
table: List[TypeComparison],
|
||||
literal_int: parse.LiteralInt,
|
||||
):
|
||||
table.append(
|
||||
EqualityTypeComparison(
|
||||
from_id=None,
|
||||
to_id=literal_int.id,
|
||||
type_usage=parse.DataTypeUsage(
|
||||
name=parse.Identifier(name="u32"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def check_types(table: List[TypeComparison]):
|
||||
found = True
|
||||
while found:
|
||||
found = False
|
||||
for entry in table:
|
||||
for other in table:
|
||||
if other.to_id == entry.to_id:
|
||||
if other.type_usage is None and entry.type_usage is None:
|
||||
pass
|
||||
elif other.type_usage is None and entry.type_usage is not None:
|
||||
other.type_usage = entry.type_usage
|
||||
found = True
|
||||
elif other.type_usage is not None and entry.type_usage is None:
|
||||
entry.type_usage = other.type_usage
|
||||
found = True
|
||||
else:
|
||||
assert entry.type_usage == other.type_usage
|
||||
if other.from_id == entry.to_id:
|
||||
# let a = || {4} entry
|
||||
# let b = a() other
|
||||
if other.type_usage is None and entry.type_usage is None:
|
||||
pass
|
||||
elif other.type_usage is None and entry.type_usage is not None:
|
||||
if isinstance(other, EqualityTypeComparison):
|
||||
other.type_usage = entry.type_usage
|
||||
found = True
|
||||
elif isinstance(other, FunctionCallTypeComparison):
|
||||
assert isinstance(
|
||||
entry.type_usage, parse.FunctionTypeUsage
|
||||
), "non function called"
|
||||
other.type_usage = entry.type_usage.return_type
|
||||
found = True
|
||||
elif other.type_usage is not None and entry.type_usage is None:
|
||||
if isinstance(other, EqualityTypeComparison):
|
||||
entry.type_usage = other.type_usage
|
||||
found = True
|
||||
elif isinstance(other, FunctionCallTypeComparison):
|
||||
pass # can't reverse a function
|
||||
else:
|
||||
if isinstance(other, EqualityTypeComparison):
|
||||
assert other.type_usage == entry.type_usage
|
||||
elif isinstance(other, FunctionCallTypeComparison):
|
||||
assert isinstance(
|
||||
entry.type_usage, parse.FunctionTypeUsage
|
||||
), "non function called"
|
||||
assert other.type_usage == entry.type_usage.return_type
|
||||
|
||||
# if other.to_id == entry.from_id:
|
||||
# # let a = || {4} other
|
||||
# # let b = a() entry
|
||||
# if other.type_usage is None and entry.type_usage is None:
|
||||
# pass
|
||||
# elif other.type_usage is None and entry.type_usage is not None:
|
||||
# if isinstance(entry, EqualityTypeComparison):
|
||||
# other.type_usage = entry.type_usage
|
||||
# found = True
|
||||
# elif isinstance(entry, FunctionCallTypeComparison):
|
||||
# pass # can't reverse a function
|
||||
# elif other.type_usage is not None and entry.type_usage is None:
|
||||
# if isinstance(entry, EqualityTypeComparison):
|
||||
# entry.type_usage = other.type_usage
|
||||
# found = True
|
||||
# elif isinstance(entry, FunctionCallTypeComparison):
|
||||
# entry.type_usage = other.type_usage.return_type
|
||||
# found = True
|
||||
# if other.from_id == entry.from_id and entry.from_id is not None:
|
||||
# if other.type_usage is None and entry.type_usage is None:
|
||||
# pass
|
||||
# elif other.type_usage is None and entry.type_usage is not None:
|
||||
# other.type_usage = entry.type_usage
|
||||
# found = True
|
||||
# elif other.type_usage is not None and entry.type_usage is None:
|
||||
# entry.type_usage = other.type_usage
|
||||
# found = True
|
||||
# else:
|
||||
# assert entry.type_usage == other.type_usage, f"{entry.from_id} {other.from_id} {entry.type_usage} == {other.type_usage}"
|
||||
@@ -9,10 +9,53 @@ Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration
|
||||
Environment = Dict[str, Identified]
|
||||
|
||||
|
||||
def unify(first, second) -> bool:
|
||||
result, changed = type_compare(first.type, second.type)
|
||||
first.type = result
|
||||
second.type = result
|
||||
return changed
|
||||
|
||||
|
||||
def type_compare(first, second) -> (parse.TypeUsage, bool):
|
||||
print(first, second)
|
||||
if isinstance(first, parse.UnknownTypeUsage):
|
||||
if not isinstance(second, parse.UnknownTypeUsage):
|
||||
return second, True
|
||||
else:
|
||||
return parse.UnknownTypeUsage(), False
|
||||
else:
|
||||
if isinstance(second, parse.UnknownTypeUsage):
|
||||
return first, True
|
||||
else:
|
||||
if isinstance(first, parse.DataTypeUsage) and isinstance(
|
||||
second, parse.DataTypeUsage
|
||||
):
|
||||
assert second == first
|
||||
return first, False
|
||||
elif isinstance(first, parse.FunctionTypeUsage) and isinstance(
|
||||
second, parse.FunctionTypeUsage
|
||||
):
|
||||
return_type, changed = type_compare(
|
||||
first.return_type, second.return_type
|
||||
)
|
||||
arguments = []
|
||||
assert len(first.arguments) == len(second.arguments)
|
||||
for first_arg, second_arg in zip(first.arguments, second.arguments):
|
||||
argument_type, argument_changed = type_compare(
|
||||
first_arg, second_arg
|
||||
)
|
||||
arguments.append(argument_type)
|
||||
if argument_changed:
|
||||
changed = True
|
||||
return parse.FunctionTypeUsage(arguments, return_type), changed
|
||||
else:
|
||||
assert False, f"mismatched types {first}, {second}"
|
||||
|
||||
|
||||
class TypeChecker:
|
||||
def with_module(self, env: Environment, module: parse.Module) -> bool:
|
||||
for function in module.functions:
|
||||
env[function.name.name] = function
|
||||
env[function.name] = function
|
||||
found = False
|
||||
for function in module.functions:
|
||||
if self.with_function(env, function):
|
||||
@@ -22,10 +65,15 @@ class TypeChecker:
|
||||
def with_function(self, env: Environment, function: parse.Function) -> bool:
|
||||
function_env = env.copy()
|
||||
for argument in function.arguments:
|
||||
function_env[argument.name.name] = argument
|
||||
function_env[argument.name] = argument
|
||||
assert isinstance(function.type, parse.FunctionTypeUsage)
|
||||
function.block.type = function.type.return_type
|
||||
return self.with_block(function_env, function.block)
|
||||
|
||||
type, changed = type_compare(function.block.type, function.type.return_type)
|
||||
function.block.type = type
|
||||
function.type.return_type = type
|
||||
if self.with_block(function_env, function.block):
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
# Skip variable VariableDeclaration
|
||||
|
||||
@@ -33,30 +81,26 @@ class TypeChecker:
|
||||
block_env = env.copy()
|
||||
# if parent is void, must be statement
|
||||
# if parent is type, must be expression
|
||||
found = False
|
||||
changed = False
|
||||
final = block.statements[-1]
|
||||
if isinstance(final, parse.LetStatement):
|
||||
if block.type is None:
|
||||
if isinstance(block.type, parse.UnknownTypeUsage):
|
||||
found = True
|
||||
block.type = parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))
|
||||
block.type = parse.DataTypeUsage(
|
||||
name=parse.Identifier(name=parse.UNIT_TYPE)
|
||||
)
|
||||
else:
|
||||
assert block.type == parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))
|
||||
assert block.type == parse.DataTypeUsage(
|
||||
name=parse.Identifier(name=parse.UNIT_TYPE)
|
||||
)
|
||||
elif isinstance(final, parse.Expression):
|
||||
if block.type is None:
|
||||
if final.type is not None:
|
||||
found = True
|
||||
block.type = final.type
|
||||
else:
|
||||
if final.type is None:
|
||||
found = True
|
||||
final.type = block.type
|
||||
else:
|
||||
assert final.type == block.type
|
||||
if unify(final, block):
|
||||
changed = True
|
||||
|
||||
for statement in block.statements:
|
||||
if self.with_statement(block_env, statement):
|
||||
found = True
|
||||
return found
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def with_statement(self, env: Environment, statement: parse.Statement) -> bool:
|
||||
if isinstance(statement, parse.LetStatement):
|
||||
@@ -70,136 +114,93 @@ class TypeChecker:
|
||||
self, env: Environment, let_statement: parse.LetStatement
|
||||
) -> bool:
|
||||
found = False
|
||||
env[let_statement.variable_name.name] = let_statement
|
||||
if let_statement.type is None:
|
||||
if let_statement.expression.type is not None:
|
||||
let_statement.type = let_statement.expression.type
|
||||
found = True
|
||||
else:
|
||||
if let_statement.expression.type is None:
|
||||
let_statement.expression.type = let_statement.type
|
||||
found = True
|
||||
else:
|
||||
assert let_statement.expression.type == let_statement.type
|
||||
env[let_statement.variable_name] = let_statement
|
||||
changed = unify(let_statement, let_statement.expression)
|
||||
if self.with_expression(env, let_statement.expression):
|
||||
found = True
|
||||
return found
|
||||
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def with_expression(self, env: Environment, expression: parse.Expression) -> bool:
|
||||
subexpression = expression.expression
|
||||
found = False
|
||||
# generic to all types
|
||||
if expression.type is None:
|
||||
if subexpression.type is not None:
|
||||
expression.type = subexpression.type
|
||||
found = True
|
||||
else:
|
||||
if subexpression.type is None:
|
||||
subexpression.type = expression.type
|
||||
found = True
|
||||
else:
|
||||
assert subexpression.type == expression.type
|
||||
changed = unify(subexpression, expression)
|
||||
|
||||
if isinstance(subexpression, parse.LiteralInt):
|
||||
print(f"fooooo {expression.type}, {subexpression.type}")
|
||||
if self.with_literal_int(env, subexpression):
|
||||
found = True
|
||||
return found
|
||||
changed = True
|
||||
return changed
|
||||
if isinstance(subexpression, parse.FunctionCall):
|
||||
if self.with_function_call(env, subexpression):
|
||||
found = True
|
||||
return found
|
||||
changed = True
|
||||
return changed
|
||||
if isinstance(subexpression, parse.VariableUsage):
|
||||
if self.with_variable_usage(env, subexpression):
|
||||
found = True
|
||||
return found
|
||||
changed = True
|
||||
return changed
|
||||
if isinstance(subexpression, parse.Operation):
|
||||
if self.with_operation(env, subexpression):
|
||||
found = True
|
||||
return found
|
||||
changed = True
|
||||
return changed
|
||||
assert False
|
||||
|
||||
def with_variable_usage(
|
||||
self, env: Environment, variable_usage: parse.VariableUsage
|
||||
) -> bool:
|
||||
found = False
|
||||
variable = env[variable_usage.name.name]
|
||||
if variable_usage.type is None:
|
||||
if variable.type is not None:
|
||||
variable_usage.type = variable.type
|
||||
found = True
|
||||
else:
|
||||
if variable.type is None:
|
||||
# print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
|
||||
# print(f"{variable.name} {variable.type}")
|
||||
variable.type = variable_usage.type
|
||||
found = True
|
||||
else:
|
||||
assert variable.type == variable_usage.type
|
||||
return found
|
||||
return unify(variable_usage, env[variable_usage.name])
|
||||
|
||||
def with_operation(self, env: Environment, operation: parse.Operation) -> bool:
|
||||
found = False
|
||||
if operation.type is None:
|
||||
if operation.left.type is not None:
|
||||
operation.type = operation.left.type
|
||||
found = True
|
||||
else:
|
||||
if operation.left.type is None:
|
||||
operation.left.type = operation.type
|
||||
found = True
|
||||
else:
|
||||
assert operation.left.type == operation.type
|
||||
if operation.type is None:
|
||||
if operation.right.type is not None:
|
||||
operation.type = operation.right.type
|
||||
found = True
|
||||
else:
|
||||
if operation.right.type is None:
|
||||
operation.right.type = operation.type
|
||||
found = True
|
||||
else:
|
||||
assert operation.right.type == operation.type
|
||||
changed = False
|
||||
if unify(operation, operation.left):
|
||||
changed = True
|
||||
if unify(operation, operation.right):
|
||||
changed = True
|
||||
if self.with_expression(env, operation.left):
|
||||
found = True
|
||||
changed = True
|
||||
if self.with_expression(env, operation.right):
|
||||
found = True
|
||||
return found
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def with_function_call(
|
||||
self, env: Environment, function_call: parse.FunctionCall
|
||||
) -> bool:
|
||||
found = False
|
||||
if function_call.type is None:
|
||||
if function_call.source.type is not None:
|
||||
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
|
||||
found = True
|
||||
function_call.type = function_call.source.type.return_type
|
||||
else:
|
||||
if function_call.source.type is not None:
|
||||
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
|
||||
assert function_call.type == function_call.source.type.return_type
|
||||
changed = False
|
||||
if isinstance(function_call.source.type, parse.UnknownTypeUsage):
|
||||
function_call.source.type = parse.FunctionTypeUsage(
|
||||
arguments=[parse.UnknownTypeUsage()] * len(function_call.arguments),
|
||||
return_type=parse.UnknownTypeUsage(),
|
||||
)
|
||||
changed = True
|
||||
if self.with_expression(env, function_call.source):
|
||||
found = True
|
||||
|
||||
if function_call.source.type is not None:
|
||||
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
|
||||
assert len(function_call.arguments) == len(function_call.source.type.arguments)
|
||||
for (argument, type_argument) in zip(function_call.arguments, function_call.source.type.arguments):
|
||||
if argument.type is None:
|
||||
argument.type = type_argument
|
||||
found = True
|
||||
else:
|
||||
assert argument.type == type_argument
|
||||
|
||||
changed = True
|
||||
for argument in function_call.arguments:
|
||||
if self.with_expression(env, argument):
|
||||
found = True
|
||||
return found
|
||||
changed = True
|
||||
|
||||
return_type, return_changed = type_compare(
|
||||
function_call.type, function_call.source.type.return_type
|
||||
)
|
||||
function_call.type = return_type
|
||||
function_call.source.type.return_type = return_type
|
||||
if return_changed:
|
||||
changed = True
|
||||
|
||||
for argument, argument_type in zip(
|
||||
function_call.arguments, function_call.source.type.arguments
|
||||
):
|
||||
argument_out_type, argument_changed = type_compare(
|
||||
argument.type, function_call.source.type.return_type
|
||||
)
|
||||
argument.type = argument_out_type
|
||||
function_call.source.type.return_type = argument_out_type
|
||||
if argument_changed:
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def with_literal_int(self, env: Environment, literal_int: parse.LiteralInt) -> bool:
|
||||
ints = [parse.DataTypeUsage(name=parse.Identifier(name=name)) for name in ["Int8", "Int16", "Int32", "Int64", "Int128", "u32"]]
|
||||
if literal_int.type is not None:
|
||||
ints = [
|
||||
parse.DataTypeUsage(name=name)
|
||||
for name in ["I8", "I16", "I32", "I64", "I128"]
|
||||
]
|
||||
if not isinstance(literal_int.type, parse.UnknownTypeUsage):
|
||||
assert literal_int.type in ints, f"{literal_int.type}"
|
||||
return False
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
fn add(a: u32, b: u32): u32 {
|
||||
fn add(a: I32, b: I32): I32 {
|
||||
let foo = 4;
|
||||
a + b + foo
|
||||
}
|
||||
|
||||
fn subtract(a: u32, b: u32): u32 {
|
||||
fn subtract(a: I32, b: I32): I32 {
|
||||
a - b
|
||||
}
|
||||
|
||||
fn main(): u32 {
|
||||
fn main(): I32 {
|
||||
add(4, subtract(5, 2))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user