added type checking

This commit is contained in:
Andrew Segavac
2021-05-12 06:40:11 -06:00
parent 9d9d42ebd5
commit cb30ad7040
5 changed files with 355 additions and 163 deletions

View File

@@ -1,24 +1,29 @@
import sys import sys
from typing import List from typing import List
from boring.parse import boring_parser, TreeToBoring, pretty_print from boring.parse import boring_parser, TreeToBoring, pretty_print
from boring.type_checker import TypeCheckTableBuilder, TypeComparison, check_types from boring.type_checking import TypeChecker
if __name__ == "__main__": if __name__ == "__main__":
with open(sys.argv[1]) as f: with open(sys.argv[1]) as f:
tree = boring_parser.parse(f.read()) tree = boring_parser.parse(f.read())
# print(tree) # print(tree)
result = TreeToBoring().transform(tree) result = TreeToBoring().transform(tree)
# pretty_print(result)
type_checker = TypeChecker()
while type_checker.with_module({}, result):
print('loop')
# type_checker.with_module({}, result)
pretty_print(result) pretty_print(result)
tctb = TypeCheckTableBuilder() # tctb = TypeCheckTableBuilder()
table: List[TypeComparison] = [] # table: List[TypeComparison] = []
tctb.with_module({}, table, result) # tctb.with_module({}, table, result)
for e in table: # for e in table:
print(e) # print(e)
print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^') # print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
check_types(table) # check_types(table)
print('vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv') # print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
for e in table: # for e in table:
print(e) # print(e)
# None, Some # None, Some
# None skip, set # None skip, set

View File

@@ -35,7 +35,9 @@ class Identifier:
@dataclass @dataclass
class FunctionTypeUsage: class FunctionTypeUsage:
arguments: List["TypeUsage"] # Specified if it is a function, this is how you tell if it's a function arguments: List[
"TypeUsage"
] # Specified if it is a function, this is how you tell if it's a function
return_type: "TypeUsage" return_type: "TypeUsage"
@@ -56,40 +58,39 @@ class Operator(enum.Enum):
@dataclass @dataclass
class LiteralInt: class LiteralInt:
id: str
value: int value: int
type: Optional[TypeUsage]
@dataclass @dataclass
class FunctionCall: class FunctionCall:
id: str source: "Expression"
source: Expression
arguments: List["Expression"] arguments: List["Expression"]
type: Optional[TypeUsage]
@dataclass @dataclass
class Operation: class Operation:
id: str
left: "Expression" left: "Expression"
op: Operator op: Operator
right: "Expression" right: "Expression"
type: Optional[TypeUsage]
@dataclass @dataclass
class VariableUsage: class VariableUsage:
id: str
name: Identifier name: Identifier
type: Optional[TypeUsage]
@dataclass @dataclass
class Expression: class Expression:
id: str
expression: Union[LiteralInt, FunctionCall, VariableUsage, Operation] expression: Union[LiteralInt, FunctionCall, VariableUsage, Operation]
type: Optional[TypeUsage]
@dataclass @dataclass
class LetStatement: class LetStatement:
id: str
variable_name: Identifier variable_name: Identifier
type: Optional[TypeUsage] type: Optional[TypeUsage]
expression: Expression expression: Expression
@@ -100,29 +101,27 @@ Statement = Union[LetStatement, Expression]
@dataclass @dataclass
class Block: class Block:
id: str
statements: List[Statement] statements: List[Statement]
type: Optional[TypeUsage]
@dataclass @dataclass
class VariableDeclaration: class VariableDeclaration:
id: str
name: Identifier name: Identifier
type: TypeUsage type: TypeUsage
@dataclass @dataclass
class Function: class Function:
id: str
name: Identifier name: Identifier
arguments: List[VariableDeclaration] arguments: List[VariableDeclaration]
block: Block block: Block
return_type: TypeUsage return_type: TypeUsage
type: TypeUsage
@dataclass @dataclass
class Module: class Module:
id: str
functions: List[Function] functions: List[Function]
@@ -196,7 +195,6 @@ boring_grammar = r"""
class TreeToBoring(Transformer): class TreeToBoring(Transformer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.id = 0
def plus(self, p) -> Operator: def plus(self, p) -> Operator:
return Operator.plus return Operator.plus
@@ -212,8 +210,7 @@ class TreeToBoring(Transformer):
def literal_int(self, n) -> LiteralInt: def literal_int(self, n) -> LiteralInt:
(n,) = n (n,) = n
self.id += 1 return LiteralInt(value=int(n), type=None)
return LiteralInt(value=int(n), id=str(self.id))
def identifier(self, i) -> Identifier: def identifier(self, i) -> Identifier:
(i,) = i (i,) = i
@@ -221,64 +218,52 @@ class TreeToBoring(Transformer):
def variable_usage(self, variable) -> VariableUsage: def variable_usage(self, variable) -> VariableUsage:
(variable,) = variable (variable,) = variable
self.id += 1 return VariableUsage(name=variable, type=None)
return VariableUsage(name=variable, id=str(self.id))
def function_call(self, call) -> FunctionCall: def function_call(self, call) -> FunctionCall:
self.id += 1 return FunctionCall(source=call[0], arguments=call[1:], type=None)
return FunctionCall(source=call[0], arguments=call[1:], id=str(self.id))
def add_expression(self, ae) -> Operation: def add_expression(self, ae) -> Operation:
self.id += 1 return Operation(left=ae[0], op=ae[1], right=ae[2], type=None)
return Operation(left=ae[0], op=ae[1], right=ae[2], id=str(self.id))
def sub_expression(self, se) -> Operation: def sub_expression(self, se) -> Operation:
self.id += 1 return Operation(left=se[0], op=se[1], right=se[2], type=None)
return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id))
def mult_expression(self, se) -> Operation: def mult_expression(self, se) -> Operation:
self.id += 1 return Operation(left=se[0], op=se[1], right=se[2], type=None)
return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id))
def div_expression(self, se) -> Operation: def div_expression(self, se) -> Operation:
self.id += 1 return Operation(left=se[0], op=se[1], right=se[2], type=None)
return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id))
def expression(self, exp) -> Expression: def expression(self, exp) -> Expression:
(exp,) = exp (exp,) = exp
if isinstance(exp, Expression): if isinstance(exp, Expression):
return exp return exp
self.id += 1 return Expression(expression=exp, type=None)
return Expression(expression=exp, id=str(self.id))
def factor(self, factor) -> Expression: def factor(self, factor) -> Expression:
(factor,) = factor (factor,) = factor
if isinstance(factor, Expression): if isinstance(factor, Expression):
return factor return factor
self.id += 1 return Expression(expression=factor, type=None)
return Expression(expression=factor, id=str(self.id))
def term(self, term) -> Expression: def term(self, term) -> Expression:
(term,) = term (term,) = term
self.id += 1 return Expression(expression=term, type=None)
return Expression(expression=term, id=str(self.id))
def let_statement(self, let_statement) -> LetStatement: def let_statement(self, let_statement) -> LetStatement:
self.id += 1
if len(let_statement) == 3: if len(let_statement) == 3:
(variable_name, type_usage, expression) = let_statement (variable_name, type_usage, expression) = let_statement
return LetStatement( return LetStatement(
variable_name=variable_name, variable_name=variable_name,
type=type_usage, type=type_usage,
expression=expression, expression=expression,
id=str(self.id),
) )
(variable_name, expression) = let_statement (variable_name, expression) = let_statement
return LetStatement( return LetStatement(
variable_name=variable_name, variable_name=variable_name,
type=None, type=None,
expression=expression, expression=expression,
id=str(self.id),
) )
def statement(self, statement): def statement(self, statement):
@@ -286,23 +271,20 @@ class TreeToBoring(Transformer):
return statement return statement
def block(self, block) -> Block: def block(self, block) -> Block:
self.id += 1 return Block(statements=block, type=None)
return Block(statements=block, id=str(self.id))
def data_type(self, name) -> TypeUsage: def data_type(self, name) -> TypeUsage:
(name,) = name (name,) = name
self.id += 1
return DataTypeUsage(name=name) return DataTypeUsage(name=name)
def function_type(self, type_usage) -> TypeUsage: def function_type(self, type_usage) -> TypeUsage:
self.id += 1 return FunctionTypeUsage(
return FunctionTypeUsage(arguments=type_usage, return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE))) arguments=type_usage,
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
)
def function_type_with_return(self, type_usage) -> TypeUsage: def function_type_with_return(self, type_usage) -> TypeUsage:
self.id += 1 return FunctionTypeUsage(arguments=type_usage[0:-1], return_type=type_usage[-1])
return FunctionTypeUsage(
arguments=type_usage[0:-1], return_type=type_usage[-1]
)
def type_usage(self, type_usage): def type_usage(self, type_usage):
(type_usage,) = type_usage (type_usage,) = type_usage
@@ -310,27 +292,29 @@ class TreeToBoring(Transformer):
def variable_declaration(self, identifier) -> VariableDeclaration: def variable_declaration(self, identifier) -> VariableDeclaration:
(identifier, type_usage) = identifier (identifier, type_usage) = identifier
self.id += 1 return VariableDeclaration(name=identifier, type=type_usage)
return VariableDeclaration(name=identifier, type=type_usage, id=str(self.id))
def function_without_return(self, function) -> Function: def function_without_return(self, function) -> Function:
self.id += 1
return Function( return Function(
id=str(self.id),
name=function[0], name=function[0],
arguments=function[1:-1], arguments=function[1:-1],
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
block=function[-1], block=function[-1],
type=FunctionTypeUsage(
arguments=[arg.type for arg in function[1:-1]],
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
),
) )
def function_with_return(self, function) -> Function: def function_with_return(self, function) -> Function:
self.id += 1
return Function( return Function(
id=str(self.id),
name=function[0], name=function[0],
arguments=function[1:-2], arguments=function[1:-2],
return_type=function[-2], return_type=function[-2],
block=function[-1], block=function[-1],
type=FunctionTypeUsage(
arguments=[arg.type for arg in function[1:-2]], return_type=function[-2]
),
) )
def function(self, function): def function(self, function):
@@ -338,8 +322,7 @@ class TreeToBoring(Transformer):
return function return function
def module(self, functions) -> Module: def module(self, functions) -> Module:
self.id += 1 return Module(functions=functions)
return Module(id=str(self.id), functions=functions)
boring_parser = Lark(boring_grammar, start="module", lexer="standard") boring_parser = Lark(boring_grammar, start="module", lexer="standard")

View File

@@ -19,6 +19,7 @@ class FunctionCallTypeComparison:
to_id: str to_id: str
type_usage: Optional[parse.TypeUsage] type_usage: Optional[parse.TypeUsage]
@dataclass @dataclass
class FunctionArgumentTypeComparison: class FunctionArgumentTypeComparison:
from_id: str from_id: str
@@ -27,7 +28,9 @@ class FunctionArgumentTypeComparison:
type_usage: Optional[parse.TypeUsage] type_usage: Optional[parse.TypeUsage]
TypeComparison = Union[EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison] TypeComparison = Union[
EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison
]
Environment = Dict[str, str] Environment = Dict[str, str]
@@ -42,7 +45,9 @@ class TypeCheckTableBuilder:
return_type=function.return_type, return_type=function.return_type,
) )
table.append( table.append(
EqualityTypeComparison(from_id=None, to_id=function.id, type_usage=type_usage) EqualityTypeComparison(
from_id=None, to_id=function.id, type_usage=type_usage
)
) )
for function in module.functions: for function in module.functions:
self.with_function(env, table, function) self.with_function(env, table, function)
@@ -54,10 +59,14 @@ class TypeCheckTableBuilder:
for argument in function.arguments: for argument in function.arguments:
function_env[argument.name.name] = argument.id function_env[argument.name.name] = argument.id
table.append( table.append(
EqualityTypeComparison(from_id=None, to_id=argument.id, type_usage=argument.type) EqualityTypeComparison(
from_id=None, to_id=argument.id, type_usage=argument.type
)
) )
table.append( table.append(
EqualityTypeComparison(from_id=None, to_id=function.block.id, type_usage=function.return_type) EqualityTypeComparison(
from_id=None, to_id=function.block.id, type_usage=function.return_type
)
) )
self.with_block(function_env, table, function.block) self.with_block(function_env, table, function.block)
@@ -71,11 +80,19 @@ class TypeCheckTableBuilder:
# if parent is type, must be expression # if parent is type, must be expression
if isinstance(block.statements[-1], parse.Expression): if isinstance(block.statements[-1], parse.Expression):
table.append( table.append(
EqualityTypeComparison(from_id=block.statements[-1].id, to_id=block.id, type_usage=None) EqualityTypeComparison(
from_id=block.statements[-1].id, to_id=block.id, type_usage=None
)
) )
else: else:
table.append( table.append(
EqualityTypeComparison(from_id=None, to_id=block.id, type_usage=parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))) EqualityTypeComparison(
from_id=None,
to_id=block.id,
type_usage=parse.DataTypeUsage(
name=parse.Identifier(name=parse.UNIT_TYPE)
),
)
) )
for statement in block.statements: for statement in block.statements:
print(statement) print(statement)
@@ -86,73 +103,115 @@ class TypeCheckTableBuilder:
): ):
if isinstance(statement, parse.LetStatement): if isinstance(statement, parse.LetStatement):
self.with_let_statement(env, table, statement) self.with_let_statement(env, table, statement)
elif isinstance(statement, parse.Expression): # expression elif isinstance(statement, parse.Expression): # expression
self.with_expression(env, table, statement) self.with_expression(env, table, statement)
else: else:
assert False assert False
def with_let_statement( def with_let_statement(
self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement self,
env: Environment,
table: List[TypeComparison],
let_statement: parse.LetStatement,
): ):
env[let_statement.variable_name.name] = let_statement.id env[let_statement.variable_name.name] = let_statement.id
table.append( table.append(
EqualityTypeComparison(from_id=let_statement.expression.id, to_id=let_statement.id, type_usage=let_statement.type) EqualityTypeComparison(
from_id=let_statement.expression.id,
to_id=let_statement.id,
type_usage=let_statement.type,
)
) )
self.with_expression(env, table, let_statement.expression) self.with_expression(env, table, let_statement.expression)
def with_expression( def with_expression(
self, env: Environment, table: List[TypeComparison], expression: parse.Expression self,
env: Environment,
table: List[TypeComparison],
expression: parse.Expression,
): ):
if isinstance(expression.expression, parse.LiteralInt): if isinstance(expression.expression, parse.LiteralInt):
table.append( table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
) )
self.with_literal_int(env, table, expression.expression) self.with_literal_int(env, table, expression.expression)
elif isinstance(expression.expression, parse.FunctionCall): elif isinstance(expression.expression, parse.FunctionCall):
table.append( table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
) )
self.with_function_call(env, table, expression.expression) self.with_function_call(env, table, expression.expression)
elif isinstance(expression.expression, parse.VariableUsage): elif isinstance(expression.expression, parse.VariableUsage):
table.append( table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
) )
self.with_variable_usage(env, table, expression.expression) self.with_variable_usage(env, table, expression.expression)
elif isinstance(expression.expression, parse.Operation): elif isinstance(expression.expression, parse.Operation):
table.append( table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
) )
self.with_operation(env, table, expression.expression) self.with_operation(env, table, expression.expression)
else: else:
assert False assert False
def with_variable_usage( def with_variable_usage(
self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage self,
env: Environment,
table: List[TypeComparison],
variable_usage: parse.VariableUsage,
): ):
print('%%%%%%%%%%%%%%%%%%%%%') print("%%%%%%%%%%%%%%%%%%%%%")
print(env[variable_usage.name.name]) print(env[variable_usage.name.name])
print(variable_usage.id) print(variable_usage.id)
table.append( table.append(
EqualityTypeComparison(from_id=env[variable_usage.name.name], to_id=variable_usage.id, type_usage=None) EqualityTypeComparison(
from_id=env[variable_usage.name.name],
to_id=variable_usage.id,
type_usage=None,
)
) )
def with_operation( def with_operation(
self, env: Environment, table: List[TypeComparison], operation: parse.Operation self, env: Environment, table: List[TypeComparison], operation: parse.Operation
): ):
table.append( table.append(
EqualityTypeComparison(from_id=operation.left.id, to_id=operation.id, type_usage=None) EqualityTypeComparison(
from_id=operation.left.id, to_id=operation.id, type_usage=None
)
) )
table.append( table.append(
EqualityTypeComparison(from_id=operation.right.id, to_id=operation.id, type_usage=None) EqualityTypeComparison(
from_id=operation.right.id, to_id=operation.id, type_usage=None
)
) )
self.with_expression(env, table, operation.left) self.with_expression(env, table, operation.left)
self.with_expression(env, table, operation.right) self.with_expression(env, table, operation.right)
def with_function_call( def with_function_call(
self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall self,
env: Environment,
table: List[TypeComparison],
function_call: parse.FunctionCall,
): ):
table.append( table.append(
EqualityTypeComparison(to_id=function_call.id, from_id=function_call.source.id, type_usage=None) EqualityTypeComparison(
to_id=function_call.id, from_id=function_call.source.id, type_usage=None
)
) )
self.with_expression(env, table, function_call.source) self.with_expression(env, table, function_call.source)
@@ -163,14 +222,20 @@ class TypeCheckTableBuilder:
# FunctionArgumentTypeComparison # FunctionArgumentTypeComparison
self.with_expression(env, table, argument) self.with_expression(env, table, argument)
def with_literal_int( def with_literal_int(
self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt self,
env: Environment,
table: List[TypeComparison],
literal_int: parse.LiteralInt,
): ):
table.append( table.append(
EqualityTypeComparison(from_id=None, to_id=literal_int.id, type_usage=parse.DataTypeUsage( EqualityTypeComparison(
name=parse.Identifier(name="u32"), from_id=None,
)) to_id=literal_int.id,
type_usage=parse.DataTypeUsage(
name=parse.Identifier(name="u32"),
),
)
) )
@@ -201,7 +266,9 @@ def check_types(table: List[TypeComparison]):
other.type_usage = entry.type_usage other.type_usage = entry.type_usage
found = True found = True
elif isinstance(other, FunctionCallTypeComparison): elif isinstance(other, FunctionCallTypeComparison):
assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called' assert isinstance(
entry.type_usage, parse.FunctionTypeUsage
), "non function called"
other.type_usage = entry.type_usage.return_type other.type_usage = entry.type_usage.return_type
found = True found = True
elif other.type_usage is not None and entry.type_usage is None: elif other.type_usage is not None and entry.type_usage is None:
@@ -209,20 +276,16 @@ def check_types(table: List[TypeComparison]):
entry.type_usage = other.type_usage entry.type_usage = other.type_usage
found = True found = True
elif isinstance(other, FunctionCallTypeComparison): elif isinstance(other, FunctionCallTypeComparison):
pass # can't reverse a function pass # can't reverse a function
else: else:
if isinstance(other, EqualityTypeComparison): if isinstance(other, EqualityTypeComparison):
assert other.type_usage == entry.type_usage assert other.type_usage == entry.type_usage
elif isinstance(other, FunctionCallTypeComparison): elif isinstance(other, FunctionCallTypeComparison):
assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called' assert isinstance(
entry.type_usage, parse.FunctionTypeUsage
), "non function called"
assert other.type_usage == entry.type_usage.return_type assert other.type_usage == entry.type_usage.return_type
# if other.to_id == entry.from_id: # if other.to_id == entry.from_id:
# # let a = || {4} other # # let a = || {4} other
# # let b = a() entry # # let b = a() entry

205
boring/type_checking.py Normal file
View File

@@ -0,0 +1,205 @@
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from boring import parse
Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration]
Environment = Dict[str, Identified]
class TypeChecker:
def with_module(self, env: Environment, module: parse.Module) -> bool:
for function in module.functions:
env[function.name.name] = function
found = False
for function in module.functions:
if self.with_function(env, function):
found = True
return found
def with_function(self, env: Environment, function: parse.Function) -> bool:
function_env = env.copy()
for argument in function.arguments:
function_env[argument.name.name] = argument
assert isinstance(function.type, parse.FunctionTypeUsage)
function.block.type = function.type.return_type
return self.with_block(function_env, function.block)
# Skip variable VariableDeclaration
def with_block(self, env: Environment, block: parse.Block) -> bool:
block_env = env.copy()
# if parent is void, must be statement
# if parent is type, must be expression
found = False
final = block.statements[-1]
if isinstance(final, parse.LetStatement):
if block.type is None:
found = True
block.type = parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))
else:
assert block.type == parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))
elif isinstance(final, parse.Expression):
if block.type is None:
if final.type is not None:
found = True
block.type = final.type
else:
if final.type is None:
found = True
final.type = block.type
else:
assert final.type == block.type
for statement in block.statements:
if self.with_statement(block_env, statement):
found = True
return found
def with_statement(self, env: Environment, statement: parse.Statement) -> bool:
if isinstance(statement, parse.LetStatement):
return self.with_let_statement(env, statement)
elif isinstance(statement, parse.Expression): # expression
return self.with_expression(env, statement)
else:
assert False
def with_let_statement(
self, env: Environment, let_statement: parse.LetStatement
) -> bool:
found = False
env[let_statement.variable_name.name] = let_statement
if let_statement.type is None:
if let_statement.expression.type is not None:
let_statement.type = let_statement.expression.type
found = True
else:
if let_statement.expression.type is None:
let_statement.expression.type = let_statement.type
found = True
else:
assert let_statement.expression.type == let_statement.type
if self.with_expression(env, let_statement.expression):
found = True
return found
def with_expression(self, env: Environment, expression: parse.Expression) -> bool:
subexpression = expression.expression
found = False
# generic to all types
if expression.type is None:
if subexpression.type is not None:
expression.type = subexpression.type
found = True
else:
if subexpression.type is None:
subexpression.type = expression.type
found = True
else:
assert subexpression.type == expression.type
if isinstance(subexpression, parse.LiteralInt):
if self.with_literal_int(env, subexpression):
found = True
return found
if isinstance(subexpression, parse.FunctionCall):
if self.with_function_call(env, subexpression):
found = True
return found
if isinstance(subexpression, parse.VariableUsage):
if self.with_variable_usage(env, subexpression):
found = True
return found
if isinstance(subexpression, parse.Operation):
if self.with_operation(env, subexpression):
found = True
return found
assert False
def with_variable_usage(
self, env: Environment, variable_usage: parse.VariableUsage
) -> bool:
found = False
variable = env[variable_usage.name.name]
if variable_usage.type is None:
if variable.type is not None:
variable_usage.type = variable.type
found = True
else:
if variable.type is None:
# print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
# print(f"{variable.name} {variable.type}")
variable.type = variable_usage.type
found = True
else:
assert variable.type == variable_usage.type
return found
def with_operation(self, env: Environment, operation: parse.Operation) -> bool:
found = False
if operation.type is None:
if operation.left.type is not None:
operation.type = operation.left.type
found = True
else:
if operation.left.type is None:
operation.left.type = operation.type
found = True
else:
assert operation.left.type == operation.type
if operation.type is None:
if operation.right.type is not None:
operation.type = operation.right.type
found = True
else:
if operation.right.type is None:
operation.right.type = operation.type
found = True
else:
assert operation.right.type == operation.type
if self.with_expression(env, operation.left):
found = True
if self.with_expression(env, operation.right):
found = True
return found
def with_function_call(
self, env: Environment, function_call: parse.FunctionCall
) -> bool:
found = False
if function_call.type is None:
if function_call.source.type is not None:
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
found = True
function_call.type = function_call.source.type.return_type
else:
if function_call.source.type is not None:
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
assert function_call.type == function_call.source.type.return_type
if self.with_expression(env, function_call.source):
found = True
if function_call.source.type is not None:
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
assert len(function_call.arguments) == len(function_call.source.type.arguments)
for (argument, type_argument) in zip(function_call.arguments, function_call.source.type.arguments):
if argument.type is None:
argument.type = type_argument
found = True
else:
assert argument.type == type_argument
for argument in function_call.arguments:
if self.with_expression(env, argument):
found = True
return found
def with_literal_int(self, env: Environment, literal_int: parse.LiteralInt) -> bool:
ints = [parse.DataTypeUsage(name=parse.Identifier(name=name)) for name in ["Int8", "Int16", "Int32", "Int64", "Int128", "u32"]]
if literal_int.type is not None:
assert literal_int.type in ints, f"{literal_int.type}"
return False

View File

@@ -1,64 +0,0 @@
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from boring import parse
Identified = Union[parse.LetStatement, parse.Function]
Environment = Dict[str, Identified]
class TypeChecker:
def with_module(
self, env: Environment, table: List[TypeComparison], module: parse.Module
) -> bool:
pass
def with_function(
self, env: Environment, table: List[TypeComparison], function: parse.Function
) -> bool:
pass
# Skip variable VariableDeclaration
def with_block(
self, env: Environment, table: List[TypeComparison], block: parse.Block
) -> bool:
pass
def with_statement(
self, env: Environment, table: List[TypeComparison], statement: parse.Statement
) -> bool:
pass
def with_let_statement(
self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement
) -> bool:
pass
def with_expression(
self, env: Environment, table: List[TypeComparison], expression: parse.Expression
) -> bool:
pass
def with_variable_usage(
self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage
) -> bool:
pass
def with_operation(
self, env: Environment, table: List[TypeComparison], operation: parse.Operation
) -> bool:
pass
def with_function_call(
self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall
) -> bool:
pass
def with_literal_int(
self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt
) -> bool:
pass