added type checking

This commit is contained in:
Andrew Segavac
2021-05-12 06:40:11 -06:00
parent 9d9d42ebd5
commit cb30ad7040
5 changed files with 355 additions and 163 deletions

View File

@@ -1,24 +1,29 @@
import sys
from typing import List
from boring.parse import boring_parser, TreeToBoring, pretty_print
from boring.type_checker import TypeCheckTableBuilder, TypeComparison, check_types
from boring.type_checking import TypeChecker
if __name__ == "__main__":
with open(sys.argv[1]) as f:
tree = boring_parser.parse(f.read())
# print(tree)
result = TreeToBoring().transform(tree)
# pretty_print(result)
type_checker = TypeChecker()
while type_checker.with_module({}, result):
print('loop')
# type_checker.with_module({}, result)
pretty_print(result)
tctb = TypeCheckTableBuilder()
table: List[TypeComparison] = []
tctb.with_module({}, table, result)
for e in table:
print(e)
print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^')
check_types(table)
print('vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv')
for e in table:
print(e)
# tctb = TypeCheckTableBuilder()
# table: List[TypeComparison] = []
# tctb.with_module({}, table, result)
# for e in table:
# print(e)
# print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
# check_types(table)
# print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
# for e in table:
# print(e)
# None, Some
# None skip, set

View File

@@ -35,7 +35,9 @@ class Identifier:
@dataclass
class FunctionTypeUsage:
arguments: List["TypeUsage"] # Specified if it is a function, this is how you tell if it's a function
arguments: List[
"TypeUsage"
] # Specified if it is a function, this is how you tell if it's a function
return_type: "TypeUsage"
@@ -56,40 +58,39 @@ class Operator(enum.Enum):
@dataclass
class LiteralInt:
id: str
value: int
type: Optional[TypeUsage]
@dataclass
class FunctionCall:
id: str
source: Expression
source: "Expression"
arguments: List["Expression"]
type: Optional[TypeUsage]
@dataclass
class Operation:
id: str
left: "Expression"
op: Operator
right: "Expression"
type: Optional[TypeUsage]
@dataclass
class VariableUsage:
id: str
name: Identifier
type: Optional[TypeUsage]
@dataclass
class Expression:
id: str
expression: Union[LiteralInt, FunctionCall, VariableUsage, Operation]
type: Optional[TypeUsage]
@dataclass
class LetStatement:
id: str
variable_name: Identifier
type: Optional[TypeUsage]
expression: Expression
@@ -100,29 +101,27 @@ Statement = Union[LetStatement, Expression]
@dataclass
class Block:
id: str
statements: List[Statement]
type: Optional[TypeUsage]
@dataclass
class VariableDeclaration:
id: str
name: Identifier
type: TypeUsage
@dataclass
class Function:
id: str
name: Identifier
arguments: List[VariableDeclaration]
block: Block
return_type: TypeUsage
type: TypeUsage
@dataclass
class Module:
id: str
functions: List[Function]
@@ -196,7 +195,6 @@ boring_grammar = r"""
class TreeToBoring(Transformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.id = 0
def plus(self, p) -> Operator:
return Operator.plus
@@ -212,8 +210,7 @@ class TreeToBoring(Transformer):
def literal_int(self, n) -> LiteralInt:
(n,) = n
self.id += 1
return LiteralInt(value=int(n), id=str(self.id))
return LiteralInt(value=int(n), type=None)
def identifier(self, i) -> Identifier:
(i,) = i
@@ -221,64 +218,52 @@ class TreeToBoring(Transformer):
def variable_usage(self, variable) -> VariableUsage:
(variable,) = variable
self.id += 1
return VariableUsage(name=variable, id=str(self.id))
return VariableUsage(name=variable, type=None)
def function_call(self, call) -> FunctionCall:
self.id += 1
return FunctionCall(source=call[0], arguments=call[1:], id=str(self.id))
return FunctionCall(source=call[0], arguments=call[1:], type=None)
def add_expression(self, ae) -> Operation:
self.id += 1
return Operation(left=ae[0], op=ae[1], right=ae[2], id=str(self.id))
return Operation(left=ae[0], op=ae[1], right=ae[2], type=None)
def sub_expression(self, se) -> Operation:
self.id += 1
return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id))
return Operation(left=se[0], op=se[1], right=se[2], type=None)
def mult_expression(self, se) -> Operation:
self.id += 1
return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id))
return Operation(left=se[0], op=se[1], right=se[2], type=None)
def div_expression(self, se) -> Operation:
self.id += 1
return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id))
return Operation(left=se[0], op=se[1], right=se[2], type=None)
def expression(self, exp) -> Expression:
(exp,) = exp
if isinstance(exp, Expression):
return exp
self.id += 1
return Expression(expression=exp, id=str(self.id))
return Expression(expression=exp, type=None)
def factor(self, factor) -> Expression:
(factor,) = factor
if isinstance(factor, Expression):
return factor
self.id += 1
return Expression(expression=factor, id=str(self.id))
return Expression(expression=factor, type=None)
def term(self, term) -> Expression:
(term,) = term
self.id += 1
return Expression(expression=term, id=str(self.id))
return Expression(expression=term, type=None)
def let_statement(self, let_statement) -> LetStatement:
self.id += 1
if len(let_statement) == 3:
(variable_name, type_usage, expression) = let_statement
return LetStatement(
variable_name=variable_name,
type=type_usage,
expression=expression,
id=str(self.id),
)
(variable_name, expression) = let_statement
return LetStatement(
variable_name=variable_name,
type=None,
expression=expression,
id=str(self.id),
)
def statement(self, statement):
@@ -286,23 +271,20 @@ class TreeToBoring(Transformer):
return statement
def block(self, block) -> Block:
self.id += 1
return Block(statements=block, id=str(self.id))
return Block(statements=block, type=None)
def data_type(self, name) -> TypeUsage:
(name,) = name
self.id += 1
return DataTypeUsage(name=name)
def function_type(self, type_usage) -> TypeUsage:
self.id += 1
return FunctionTypeUsage(arguments=type_usage, return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)))
return FunctionTypeUsage(
arguments=type_usage,
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
)
def function_type_with_return(self, type_usage) -> TypeUsage:
self.id += 1
return FunctionTypeUsage(
arguments=type_usage[0:-1], return_type=type_usage[-1]
)
return FunctionTypeUsage(arguments=type_usage[0:-1], return_type=type_usage[-1])
def type_usage(self, type_usage):
(type_usage,) = type_usage
@@ -310,27 +292,29 @@ class TreeToBoring(Transformer):
def variable_declaration(self, identifier) -> VariableDeclaration:
(identifier, type_usage) = identifier
self.id += 1
return VariableDeclaration(name=identifier, type=type_usage, id=str(self.id))
return VariableDeclaration(name=identifier, type=type_usage)
def function_without_return(self, function) -> Function:
self.id += 1
return Function(
id=str(self.id),
name=function[0],
arguments=function[1:-1],
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
block=function[-1],
type=FunctionTypeUsage(
arguments=[arg.type for arg in function[1:-1]],
return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)),
),
)
def function_with_return(self, function) -> Function:
self.id += 1
return Function(
id=str(self.id),
name=function[0],
arguments=function[1:-2],
return_type=function[-2],
block=function[-1],
type=FunctionTypeUsage(
arguments=[arg.type for arg in function[1:-2]], return_type=function[-2]
),
)
def function(self, function):
@@ -338,8 +322,7 @@ class TreeToBoring(Transformer):
return function
def module(self, functions) -> Module:
self.id += 1
return Module(id=str(self.id), functions=functions)
return Module(functions=functions)
boring_parser = Lark(boring_grammar, start="module", lexer="standard")

View File

@@ -19,6 +19,7 @@ class FunctionCallTypeComparison:
to_id: str
type_usage: Optional[parse.TypeUsage]
@dataclass
class FunctionArgumentTypeComparison:
from_id: str
@@ -27,7 +28,9 @@ class FunctionArgumentTypeComparison:
type_usage: Optional[parse.TypeUsage]
TypeComparison = Union[EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison]
TypeComparison = Union[
EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison
]
Environment = Dict[str, str]
@@ -42,7 +45,9 @@ class TypeCheckTableBuilder:
return_type=function.return_type,
)
table.append(
EqualityTypeComparison(from_id=None, to_id=function.id, type_usage=type_usage)
EqualityTypeComparison(
from_id=None, to_id=function.id, type_usage=type_usage
)
)
for function in module.functions:
self.with_function(env, table, function)
@@ -54,10 +59,14 @@ class TypeCheckTableBuilder:
for argument in function.arguments:
function_env[argument.name.name] = argument.id
table.append(
EqualityTypeComparison(from_id=None, to_id=argument.id, type_usage=argument.type)
EqualityTypeComparison(
from_id=None, to_id=argument.id, type_usage=argument.type
)
)
table.append(
EqualityTypeComparison(from_id=None, to_id=function.block.id, type_usage=function.return_type)
EqualityTypeComparison(
from_id=None, to_id=function.block.id, type_usage=function.return_type
)
)
self.with_block(function_env, table, function.block)
@@ -71,11 +80,19 @@ class TypeCheckTableBuilder:
# if parent is type, must be expression
if isinstance(block.statements[-1], parse.Expression):
table.append(
EqualityTypeComparison(from_id=block.statements[-1].id, to_id=block.id, type_usage=None)
EqualityTypeComparison(
from_id=block.statements[-1].id, to_id=block.id, type_usage=None
)
)
else:
table.append(
EqualityTypeComparison(from_id=None, to_id=block.id, type_usage=parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE)))
EqualityTypeComparison(
from_id=None,
to_id=block.id,
type_usage=parse.DataTypeUsage(
name=parse.Identifier(name=parse.UNIT_TYPE)
),
)
)
for statement in block.statements:
print(statement)
@@ -86,73 +103,115 @@ class TypeCheckTableBuilder:
):
if isinstance(statement, parse.LetStatement):
self.with_let_statement(env, table, statement)
elif isinstance(statement, parse.Expression): # expression
elif isinstance(statement, parse.Expression): # expression
self.with_expression(env, table, statement)
else:
assert False
def with_let_statement(
self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement
self,
env: Environment,
table: List[TypeComparison],
let_statement: parse.LetStatement,
):
env[let_statement.variable_name.name] = let_statement.id
table.append(
EqualityTypeComparison(from_id=let_statement.expression.id, to_id=let_statement.id, type_usage=let_statement.type)
EqualityTypeComparison(
from_id=let_statement.expression.id,
to_id=let_statement.id,
type_usage=let_statement.type,
)
)
self.with_expression(env, table, let_statement.expression)
def with_expression(
self, env: Environment, table: List[TypeComparison], expression: parse.Expression
self,
env: Environment,
table: List[TypeComparison],
expression: parse.Expression,
):
if isinstance(expression.expression, parse.LiteralInt):
table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None)
EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
)
self.with_literal_int(env, table, expression.expression)
elif isinstance(expression.expression, parse.FunctionCall):
table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None)
EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
)
self.with_function_call(env, table, expression.expression)
elif isinstance(expression.expression, parse.VariableUsage):
table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None)
EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
)
self.with_variable_usage(env, table, expression.expression)
elif isinstance(expression.expression, parse.Operation):
table.append(
EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None)
EqualityTypeComparison(
from_id=expression.expression.id,
to_id=expression.id,
type_usage=None,
)
)
self.with_operation(env, table, expression.expression)
else:
assert False
def with_variable_usage(
self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage
self,
env: Environment,
table: List[TypeComparison],
variable_usage: parse.VariableUsage,
):
print('%%%%%%%%%%%%%%%%%%%%%')
print("%%%%%%%%%%%%%%%%%%%%%")
print(env[variable_usage.name.name])
print(variable_usage.id)
table.append(
EqualityTypeComparison(from_id=env[variable_usage.name.name], to_id=variable_usage.id, type_usage=None)
EqualityTypeComparison(
from_id=env[variable_usage.name.name],
to_id=variable_usage.id,
type_usage=None,
)
)
def with_operation(
self, env: Environment, table: List[TypeComparison], operation: parse.Operation
):
table.append(
EqualityTypeComparison(from_id=operation.left.id, to_id=operation.id, type_usage=None)
EqualityTypeComparison(
from_id=operation.left.id, to_id=operation.id, type_usage=None
)
)
table.append(
EqualityTypeComparison(from_id=operation.right.id, to_id=operation.id, type_usage=None)
EqualityTypeComparison(
from_id=operation.right.id, to_id=operation.id, type_usage=None
)
)
self.with_expression(env, table, operation.left)
self.with_expression(env, table, operation.right)
def with_function_call(
self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall
self,
env: Environment,
table: List[TypeComparison],
function_call: parse.FunctionCall,
):
table.append(
EqualityTypeComparison(to_id=function_call.id, from_id=function_call.source.id, type_usage=None)
EqualityTypeComparison(
to_id=function_call.id, from_id=function_call.source.id, type_usage=None
)
)
self.with_expression(env, table, function_call.source)
@@ -163,14 +222,20 @@ class TypeCheckTableBuilder:
# FunctionArgumentTypeComparison
self.with_expression(env, table, argument)
def with_literal_int(
self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt
self,
env: Environment,
table: List[TypeComparison],
literal_int: parse.LiteralInt,
):
table.append(
EqualityTypeComparison(from_id=None, to_id=literal_int.id, type_usage=parse.DataTypeUsage(
name=parse.Identifier(name="u32"),
))
EqualityTypeComparison(
from_id=None,
to_id=literal_int.id,
type_usage=parse.DataTypeUsage(
name=parse.Identifier(name="u32"),
),
)
)
@@ -201,7 +266,9 @@ def check_types(table: List[TypeComparison]):
other.type_usage = entry.type_usage
found = True
elif isinstance(other, FunctionCallTypeComparison):
assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called'
assert isinstance(
entry.type_usage, parse.FunctionTypeUsage
), "non function called"
other.type_usage = entry.type_usage.return_type
found = True
elif other.type_usage is not None and entry.type_usage is None:
@@ -209,20 +276,16 @@ def check_types(table: List[TypeComparison]):
entry.type_usage = other.type_usage
found = True
elif isinstance(other, FunctionCallTypeComparison):
pass # can't reverse a function
pass # can't reverse a function
else:
if isinstance(other, EqualityTypeComparison):
assert other.type_usage == entry.type_usage
elif isinstance(other, FunctionCallTypeComparison):
assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called'
assert isinstance(
entry.type_usage, parse.FunctionTypeUsage
), "non function called"
assert other.type_usage == entry.type_usage.return_type
# if other.to_id == entry.from_id:
# # let a = || {4} other
# # let b = a() entry

205
boring/type_checking.py Normal file
View File

@@ -0,0 +1,205 @@
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from boring import parse
Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration]
Environment = Dict[str, Identified]
class TypeChecker:
def with_module(self, env: Environment, module: parse.Module) -> bool:
for function in module.functions:
env[function.name.name] = function
found = False
for function in module.functions:
if self.with_function(env, function):
found = True
return found
def with_function(self, env: Environment, function: parse.Function) -> bool:
function_env = env.copy()
for argument in function.arguments:
function_env[argument.name.name] = argument
assert isinstance(function.type, parse.FunctionTypeUsage)
function.block.type = function.type.return_type
return self.with_block(function_env, function.block)
# Skip variable VariableDeclaration
def with_block(self, env: Environment, block: parse.Block) -> bool:
block_env = env.copy()
# if parent is void, must be statement
# if parent is type, must be expression
found = False
final = block.statements[-1]
if isinstance(final, parse.LetStatement):
if block.type is None:
found = True
block.type = parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))
else:
assert block.type == parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))
elif isinstance(final, parse.Expression):
if block.type is None:
if final.type is not None:
found = True
block.type = final.type
else:
if final.type is None:
found = True
final.type = block.type
else:
assert final.type == block.type
for statement in block.statements:
if self.with_statement(block_env, statement):
found = True
return found
def with_statement(self, env: Environment, statement: parse.Statement) -> bool:
if isinstance(statement, parse.LetStatement):
return self.with_let_statement(env, statement)
elif isinstance(statement, parse.Expression): # expression
return self.with_expression(env, statement)
else:
assert False
def with_let_statement(
self, env: Environment, let_statement: parse.LetStatement
) -> bool:
found = False
env[let_statement.variable_name.name] = let_statement
if let_statement.type is None:
if let_statement.expression.type is not None:
let_statement.type = let_statement.expression.type
found = True
else:
if let_statement.expression.type is None:
let_statement.expression.type = let_statement.type
found = True
else:
assert let_statement.expression.type == let_statement.type
if self.with_expression(env, let_statement.expression):
found = True
return found
def with_expression(self, env: Environment, expression: parse.Expression) -> bool:
subexpression = expression.expression
found = False
# generic to all types
if expression.type is None:
if subexpression.type is not None:
expression.type = subexpression.type
found = True
else:
if subexpression.type is None:
subexpression.type = expression.type
found = True
else:
assert subexpression.type == expression.type
if isinstance(subexpression, parse.LiteralInt):
if self.with_literal_int(env, subexpression):
found = True
return found
if isinstance(subexpression, parse.FunctionCall):
if self.with_function_call(env, subexpression):
found = True
return found
if isinstance(subexpression, parse.VariableUsage):
if self.with_variable_usage(env, subexpression):
found = True
return found
if isinstance(subexpression, parse.Operation):
if self.with_operation(env, subexpression):
found = True
return found
assert False
def with_variable_usage(
self, env: Environment, variable_usage: parse.VariableUsage
) -> bool:
found = False
variable = env[variable_usage.name.name]
if variable_usage.type is None:
if variable.type is not None:
variable_usage.type = variable.type
found = True
else:
if variable.type is None:
# print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
# print(f"{variable.name} {variable.type}")
variable.type = variable_usage.type
found = True
else:
assert variable.type == variable_usage.type
return found
def with_operation(self, env: Environment, operation: parse.Operation) -> bool:
found = False
if operation.type is None:
if operation.left.type is not None:
operation.type = operation.left.type
found = True
else:
if operation.left.type is None:
operation.left.type = operation.type
found = True
else:
assert operation.left.type == operation.type
if operation.type is None:
if operation.right.type is not None:
operation.type = operation.right.type
found = True
else:
if operation.right.type is None:
operation.right.type = operation.type
found = True
else:
assert operation.right.type == operation.type
if self.with_expression(env, operation.left):
found = True
if self.with_expression(env, operation.right):
found = True
return found
def with_function_call(
self, env: Environment, function_call: parse.FunctionCall
) -> bool:
found = False
if function_call.type is None:
if function_call.source.type is not None:
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
found = True
function_call.type = function_call.source.type.return_type
else:
if function_call.source.type is not None:
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
assert function_call.type == function_call.source.type.return_type
if self.with_expression(env, function_call.source):
found = True
if function_call.source.type is not None:
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
assert len(function_call.arguments) == len(function_call.source.type.arguments)
for (argument, type_argument) in zip(function_call.arguments, function_call.source.type.arguments):
if argument.type is None:
argument.type = type_argument
found = True
else:
assert argument.type == type_argument
for argument in function_call.arguments:
if self.with_expression(env, argument):
found = True
return found
def with_literal_int(self, env: Environment, literal_int: parse.LiteralInt) -> bool:
ints = [parse.DataTypeUsage(name=parse.Identifier(name=name)) for name in ["Int8", "Int16", "Int32", "Int64", "Int128", "u32"]]
if literal_int.type is not None:
assert literal_int.type in ints, f"{literal_int.type}"
return False

View File

@@ -1,64 +0,0 @@
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from boring import parse
Identified = Union[parse.LetStatement, parse.Function]
Environment = Dict[str, Identified]
class TypeChecker:
def with_module(
self, env: Environment, table: List[TypeComparison], module: parse.Module
) -> bool:
pass
def with_function(
self, env: Environment, table: List[TypeComparison], function: parse.Function
) -> bool:
pass
# Skip variable VariableDeclaration
def with_block(
self, env: Environment, table: List[TypeComparison], block: parse.Block
) -> bool:
pass
def with_statement(
self, env: Environment, table: List[TypeComparison], statement: parse.Statement
) -> bool:
pass
def with_let_statement(
self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement
) -> bool:
pass
def with_expression(
self, env: Environment, table: List[TypeComparison], expression: parse.Expression
) -> bool:
pass
def with_variable_usage(
self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage
) -> bool:
pass
def with_operation(
self, env: Environment, table: List[TypeComparison], operation: parse.Operation
) -> bool:
pass
def with_function_call(
self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall
) -> bool:
pass
def with_literal_int(
self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt
) -> bool:
pass