diff --git a/Dockerfile-python b/Dockerfile-python index 5b43914..2be5cb5 100644 --- a/Dockerfile-python +++ b/Dockerfile-python @@ -1,3 +1,3 @@ FROM python:3.9 -RUN pip install lark mypy +RUN pip install lark mypy black diff --git a/boring/__init__.py b/boring/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/boring/interpret.py b/boring/interpret.py index 8204571..f552764 100644 --- a/boring/interpret.py +++ b/boring/interpret.py @@ -22,19 +22,28 @@ class Interpreter: assert len(function_definition.arguments) == len(function_call.arguments) for i, argument in enumerate(function_definition.arguments): - new_env.identifiers[argument.name.name] = self.handle_expression(env, function_call.arguments[i]) + new_env.identifiers[argument.name.name] = self.handle_expression( + env, function_call.arguments[i] + ) return self.handle_block(new_env, function_definition.block) - def handle_operation(self, env, operation): if operation.op == parse.Operator.plus: - return self.handle_expression(env, operation.left) + self.handle_expression(env, operation.right) + return self.handle_expression(env, operation.left) + self.handle_expression( + env, operation.right + ) elif operation.op == parse.Operator.minus: - return self.handle_expression(env, operation.left) - self.handle_expression(env, operation.right) + return self.handle_expression(env, operation.left) - self.handle_expression( + env, operation.right + ) elif operation.op == parse.Operator.mult: - return self.handle_expression(env, operation.left) * self.handle_expression(env, operation.right) + return self.handle_expression(env, operation.left) * self.handle_expression( + env, operation.right + ) elif operation.op == parse.Operator.div: - return self.handle_expression(env, operation.left) / self.handle_expression(env, operation.right) + return self.handle_expression(env, operation.left) / self.handle_expression( + env, operation.right + ) def handle_expression(self, env, expression): if type(expression.expression) == parse.LiteralInt: @@ -58,13 +67,15 @@ class Interpreter: for function in module.functions: env.identifiers[function.name.name] = function - if 'main' not in env.identifiers: + if "main" not in env.identifiers: raise Exception("must have main function") - return self.handle_function_call(env, parse.FunctionCall(name=parse.Identifier("main"), arguments=[])) + return self.handle_function_call( + env, parse.FunctionCall(name=parse.Identifier("main"), arguments=[]) + ) -if __name__ == '__main__': +if __name__ == "__main__": with open(sys.argv[1]) as f: tree = parse.boring_parser.parse(f.read()) # print(tree) diff --git a/boring/main.py b/boring/main.py new file mode 100644 index 0000000..0579543 --- /dev/null +++ b/boring/main.py @@ -0,0 +1,25 @@ +import sys +from typing import List +from boring.parse import boring_parser, TreeToBoring, pretty_print +from boring.type_checker import TypeCheckTableBuilder, TypeComparison, check_types + +if __name__ == "__main__": + with open(sys.argv[1]) as f: + tree = boring_parser.parse(f.read()) + # print(tree) + result = TreeToBoring().transform(tree) + pretty_print(result) + tctb = TypeCheckTableBuilder() + table: List[TypeComparison] = [] + tctb.with_module({}, table, result) + for e in table: + print(e) + print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^') + check_types(table) + print('vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv') + for e in table: + print(e) + +# None, Some +# None skip, set +# Some set, check diff --git a/boring/parse.py b/boring/parse.py index b310f99..5a29254 100644 --- a/boring/parse.py +++ b/boring/parse.py @@ -1,23 +1,31 @@ import sys import enum -from typing import Union, List +from typing import Union, List, Optional from dataclasses import dataclass, field from lark import Lark, Transformer + def pretty_print(clas, indent=0): - print(' ' * indent + type(clas).__name__ + ':') + print(" " * indent + type(clas).__name__ + ":") + if type(clas) == list: + for e in clas: + pretty_print(e) + return indent += 2 - for k,v in clas.__dict__.items(): - if '__dict__' in dir(v): - print(' ' * indent + k + ': ') - pretty_print(v,indent+2) + for k, v in clas.__dict__.items(): + if "__dict__" in dir(v): + print(" " * indent + k + ": ") + pretty_print(v, indent + 2) elif type(v) == list: - print(' ' * indent + k + ': ' "[") + print(" " * indent + k + ": " "[") for e in v: - pretty_print(e, indent+2) - print(' ' * indent + "]") + pretty_print(e, indent + 2) + print(" " * indent + "]") else: - print(' ' * indent + k + ': ' + str(v)) + print(" " * indent + k + ": " + str(v)) + + +UNIT_TYPE = "()" @dataclass @@ -25,6 +33,20 @@ class Identifier: name: str +@dataclass +class FunctionTypeUsage: + arguments: List["TypeUsage"] # Specified if it is a function, this is how you tell if it's a function + return_type: "TypeUsage" + + +@dataclass +class DataTypeUsage: + name: Identifier + + +TypeUsage = Union[FunctionTypeUsage, DataTypeUsage] + + class Operator(enum.Enum): mult = "mult" div = "div" @@ -34,59 +56,74 @@ class Operator(enum.Enum): @dataclass class LiteralInt: + id: str value: int @dataclass class FunctionCall: - name: Identifier - arguments: List['Expression'] = field(default_factory=list) + id: str + source: Expression + arguments: List["Expression"] @dataclass class Operation: - left: 'Expression' + id: str + left: "Expression" op: Operator - right: 'Expression' + right: "Expression" + + +@dataclass +class VariableUsage: + id: str + name: Identifier @dataclass class Expression: - expression: Union[LiteralInt,FunctionCall,Identifier,Operation,'Expression'] + id: str + expression: Union[LiteralInt, FunctionCall, VariableUsage, Operation] @dataclass class LetStatement: + id: str variable_name: Identifier + type: Optional[TypeUsage] expression: Expression -@dataclass -class Statement: - statement: Union[LetStatement, Expression] +Statement = Union[LetStatement, Expression] @dataclass class Block: + id: str statements: List[Statement] @dataclass class VariableDeclaration: + id: str name: Identifier + type: TypeUsage @dataclass class Function: + id: str name: Identifier arguments: List[VariableDeclaration] block: Block + return_type: TypeUsage @dataclass class Module: - functions: Function - + id: str + functions: List[Function] boring_grammar = r""" @@ -97,13 +134,16 @@ boring_grammar = r""" literal_int: SIGNED_NUMBER identifier : NAME - function_call : identifier "(" [expression ("," expression)*] ")" + + function_call : expression "(" [expression ("," expression)*] ")" add_expression : expression plus factor sub_expression : expression minus factor mult_expression : expression mult term div_expression : expression div term + variable_usage : identifier + expression : add_expression | sub_expression | factor @@ -113,20 +153,36 @@ boring_grammar = r""" | term term : literal_int - | identifier + | variable_usage | function_call | "(" expression ")" let_statement : "let" identifier "=" expression ";" + | "let" identifier ":" type_usage "=" expression ";" statement : let_statement | expression block : "{" (statement)* "}" - variable_declaration : identifier + data_type : identifier - function : "fn" identifier "(" [variable_declaration ("," variable_declaration)*] ")" block + function_type : "fn" "(" (type_usage)* ")" + + function_type_with_return : "fn" "(" (type_usage)* ")" ":" type_usage + + type_usage : data_type + | function_type + | function_type_with_return + + variable_declaration : identifier ":" type_usage + + function_without_return : "fn" identifier "(" [variable_declaration ("," variable_declaration)*] ")" block + + function_with_return : "fn" identifier "(" [variable_declaration ("," variable_declaration)*] ")" ":" type_usage block + + function : function_with_return + | function_without_return module : (function)* @@ -136,7 +192,12 @@ boring_grammar = r""" %ignore WS """ + class TreeToBoring(Transformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.id = 0 + def plus(self, p) -> Operator: return Operator.plus @@ -151,68 +212,139 @@ class TreeToBoring(Transformer): def literal_int(self, n) -> LiteralInt: (n,) = n - return LiteralInt(value=int(n)) + self.id += 1 + return LiteralInt(value=int(n), id=str(self.id)) def identifier(self, i) -> Identifier: (i,) = i return Identifier(name=str(i)) + def variable_usage(self, variable) -> VariableUsage: + (variable,) = variable + self.id += 1 + return VariableUsage(name=variable, id=str(self.id)) + def function_call(self, call) -> FunctionCall: - return FunctionCall(name=call[0], arguments=call[1:]) + self.id += 1 + return FunctionCall(source=call[0], arguments=call[1:], id=str(self.id)) def add_expression(self, ae) -> Operation: - return Operation(left=ae[0], op=ae[1], right=ae[2]) + self.id += 1 + return Operation(left=ae[0], op=ae[1], right=ae[2], id=str(self.id)) def sub_expression(self, se) -> Operation: - return Operation(left=se[0], op=se[1], right=se[2]) + self.id += 1 + return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id)) def mult_expression(self, se) -> Operation: - return Operation(left=se[0], op=se[1], right=se[2]) + self.id += 1 + return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id)) def div_expression(self, se) -> Operation: - return Operation(left=se[0], op=se[1], right=se[2]) + self.id += 1 + return Operation(left=se[0], op=se[1], right=se[2], id=str(self.id)) def expression(self, exp) -> Expression: (exp,) = exp if isinstance(exp, Expression): return exp - return Expression(expression=exp) + self.id += 1 + return Expression(expression=exp, id=str(self.id)) def factor(self, factor) -> Expression: (factor,) = factor if isinstance(factor, Expression): return factor - return Expression(factor) + self.id += 1 + return Expression(expression=factor, id=str(self.id)) def term(self, term) -> Expression: (term,) = term - return Expression(term) + self.id += 1 + return Expression(expression=term, id=str(self.id)) def let_statement(self, let_statement) -> LetStatement: + self.id += 1 + if len(let_statement) == 3: + (variable_name, type_usage, expression) = let_statement + return LetStatement( + variable_name=variable_name, + type=type_usage, + expression=expression, + id=str(self.id), + ) (variable_name, expression) = let_statement - return LetStatement(variable_name=variable_name, expression=expression) + return LetStatement( + variable_name=variable_name, + type=None, + expression=expression, + id=str(self.id), + ) - def statement(self, statement) -> Statement: + def statement(self, statement): (statement,) = statement - return Statement(statement=statement) + return statement def block(self, block) -> Block: - return Block(statements=block) + self.id += 1 + return Block(statements=block, id=str(self.id)) + + def data_type(self, name) -> TypeUsage: + (name,) = name + self.id += 1 + return DataTypeUsage(name=name) + + def function_type(self, type_usage) -> TypeUsage: + self.id += 1 + return FunctionTypeUsage(arguments=type_usage, return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE))) + + def function_type_with_return(self, type_usage) -> TypeUsage: + self.id += 1 + return FunctionTypeUsage( + arguments=type_usage[0:-1], return_type=type_usage[-1] + ) + + def type_usage(self, type_usage): + (type_usage,) = type_usage + return type_usage def variable_declaration(self, identifier) -> VariableDeclaration: - (identifier,) = identifier - return VariableDeclaration(name=identifier) + (identifier, type_usage) = identifier + self.id += 1 + return VariableDeclaration(name=identifier, type=type_usage, id=str(self.id)) - def function(self, function) -> Function: - return Function(name=function[0], arguments=function[1:-1], block=function[-1]) + def function_without_return(self, function) -> Function: + self.id += 1 + return Function( + id=str(self.id), + name=function[0], + arguments=function[1:-1], + return_type=DataTypeUsage(name=Identifier(name=UNIT_TYPE)), + block=function[-1], + ) + + def function_with_return(self, function) -> Function: + self.id += 1 + return Function( + id=str(self.id), + name=function[0], + arguments=function[1:-2], + return_type=function[-2], + block=function[-1], + ) + + def function(self, function): + (function,) = function + return function def module(self, functions) -> Module: - return Module(functions=functions) + self.id += 1 + return Module(id=str(self.id), functions=functions) -boring_parser = Lark(boring_grammar, start='module', lexer='standard') +boring_parser = Lark(boring_grammar, start="module", lexer="standard") -if __name__ == '__main__': +if __name__ == "__main__": with open(sys.argv[1]) as f: tree = boring_parser.parse(f.read()) # print(tree) diff --git a/boring/type_checker.py b/boring/type_checker.py new file mode 100644 index 0000000..40a5070 --- /dev/null +++ b/boring/type_checker.py @@ -0,0 +1,254 @@ +from dataclasses import dataclass +from typing import List, Dict, Optional, Union + + +from boring import parse + + +@dataclass +class EqualityTypeComparison: + from_id: Optional[str] + to_id: str + type_usage: Optional[parse.TypeUsage] + # constraints: List[Constraint] + + +@dataclass +class FunctionCallTypeComparison: + from_id: str + to_id: str + type_usage: Optional[parse.TypeUsage] + +@dataclass +class FunctionArgumentTypeComparison: + from_id: str + to_id: str + argument_id: int + type_usage: Optional[parse.TypeUsage] + + +TypeComparison = Union[EqualityTypeComparison, FunctionCallTypeComparison, FunctionArgumentTypeComparison] +Environment = Dict[str, str] + + +class TypeCheckTableBuilder: + def with_module( + self, env: Environment, table: List[TypeComparison], module: parse.Module + ): + for function in module.functions: + env[function.name.name] = function.id + type_usage = parse.FunctionTypeUsage( + arguments=[arg.type for arg in function.arguments], + return_type=function.return_type, + ) + table.append( + EqualityTypeComparison(from_id=None, to_id=function.id, type_usage=type_usage) + ) + for function in module.functions: + self.with_function(env, table, function) + + def with_function( + self, env: Environment, table: List[TypeComparison], function: parse.Function + ): + function_env = env.copy() + for argument in function.arguments: + function_env[argument.name.name] = argument.id + table.append( + EqualityTypeComparison(from_id=None, to_id=argument.id, type_usage=argument.type) + ) + table.append( + EqualityTypeComparison(from_id=None, to_id=function.block.id, type_usage=function.return_type) + ) + self.with_block(function_env, table, function.block) + + # Skip variable VariableDeclaration + + def with_block( + self, env: Environment, table: List[TypeComparison], block: parse.Block + ): + block_env = env.copy() + # if parent is void, must be statement + # if parent is type, must be expression + if isinstance(block.statements[-1], parse.Expression): + table.append( + EqualityTypeComparison(from_id=block.statements[-1].id, to_id=block.id, type_usage=None) + ) + else: + table.append( + EqualityTypeComparison(from_id=None, to_id=block.id, type_usage=parse.DataTypeUsage(name=parse.Identifier(name=parse.UNIT_TYPE))) + ) + for statement in block.statements: + print(statement) + self.with_statement(block_env, table, statement) + + def with_statement( + self, env: Environment, table: List[TypeComparison], statement: parse.Statement + ): + if isinstance(statement, parse.LetStatement): + self.with_let_statement(env, table, statement) + elif isinstance(statement, parse.Expression): # expression + self.with_expression(env, table, statement) + else: + assert False + + def with_let_statement( + self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement + ): + env[let_statement.variable_name.name] = let_statement.id + table.append( + EqualityTypeComparison(from_id=let_statement.expression.id, to_id=let_statement.id, type_usage=let_statement.type) + ) + self.with_expression(env, table, let_statement.expression) + + def with_expression( + self, env: Environment, table: List[TypeComparison], expression: parse.Expression + ): + if isinstance(expression.expression, parse.LiteralInt): + table.append( + EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + ) + self.with_literal_int(env, table, expression.expression) + elif isinstance(expression.expression, parse.FunctionCall): + table.append( + EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + ) + self.with_function_call(env, table, expression.expression) + elif isinstance(expression.expression, parse.VariableUsage): + table.append( + EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + ) + self.with_variable_usage(env, table, expression.expression) + elif isinstance(expression.expression, parse.Operation): + table.append( + EqualityTypeComparison(from_id=expression.expression.id, to_id=expression.id, type_usage=None) + ) + self.with_operation(env, table, expression.expression) + else: + assert False + + def with_variable_usage( + self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage + ): + print('%%%%%%%%%%%%%%%%%%%%%') + print(env[variable_usage.name.name]) + print(variable_usage.id) + table.append( + EqualityTypeComparison(from_id=env[variable_usage.name.name], to_id=variable_usage.id, type_usage=None) + ) + + def with_operation( + self, env: Environment, table: List[TypeComparison], operation: parse.Operation + ): + table.append( + EqualityTypeComparison(from_id=operation.left.id, to_id=operation.id, type_usage=None) + ) + table.append( + EqualityTypeComparison(from_id=operation.right.id, to_id=operation.id, type_usage=None) + ) + self.with_expression(env, table, operation.left) + self.with_expression(env, table, operation.right) + + def with_function_call( + self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall + ): + table.append( + EqualityTypeComparison(to_id=function_call.id, from_id=function_call.source.id, type_usage=None) + ) + self.with_expression(env, table, function_call.source) + + for i, argument in enumerate(function_call.arguments): + # table.append( + # FunctionArgumentTypeComparison(from_id=env[function_call.name.name], to_id=function_call.id, type_usage=None) + # ) + # FunctionArgumentTypeComparison + self.with_expression(env, table, argument) + + + def with_literal_int( + self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt + ): + table.append( + EqualityTypeComparison(from_id=None, to_id=literal_int.id, type_usage=parse.DataTypeUsage( + name=parse.Identifier(name="u32"), + )) + ) + + +def check_types(table: List[TypeComparison]): + found = True + while found: + found = False + for entry in table: + for other in table: + if other.to_id == entry.to_id: + if other.type_usage is None and entry.type_usage is None: + pass + elif other.type_usage is None and entry.type_usage is not None: + other.type_usage = entry.type_usage + found = True + elif other.type_usage is not None and entry.type_usage is None: + entry.type_usage = other.type_usage + found = True + else: + assert entry.type_usage == other.type_usage + if other.from_id == entry.to_id: + # let a = || {4} entry + # let b = a() other + if other.type_usage is None and entry.type_usage is None: + pass + elif other.type_usage is None and entry.type_usage is not None: + if isinstance(other, EqualityTypeComparison): + other.type_usage = entry.type_usage + found = True + elif isinstance(other, FunctionCallTypeComparison): + assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called' + other.type_usage = entry.type_usage.return_type + found = True + elif other.type_usage is not None and entry.type_usage is None: + if isinstance(other, EqualityTypeComparison): + entry.type_usage = other.type_usage + found = True + elif isinstance(other, FunctionCallTypeComparison): + pass # can't reverse a function + else: + if isinstance(other, EqualityTypeComparison): + assert other.type_usage == entry.type_usage + elif isinstance(other, FunctionCallTypeComparison): + assert isinstance(entry.type_usage, parse.FunctionTypeUsage), 'non function called' + assert other.type_usage == entry.type_usage.return_type + + + + + + + + # if other.to_id == entry.from_id: + # # let a = || {4} other + # # let b = a() entry + # if other.type_usage is None and entry.type_usage is None: + # pass + # elif other.type_usage is None and entry.type_usage is not None: + # if isinstance(entry, EqualityTypeComparison): + # other.type_usage = entry.type_usage + # found = True + # elif isinstance(entry, FunctionCallTypeComparison): + # pass # can't reverse a function + # elif other.type_usage is not None and entry.type_usage is None: + # if isinstance(entry, EqualityTypeComparison): + # entry.type_usage = other.type_usage + # found = True + # elif isinstance(entry, FunctionCallTypeComparison): + # entry.type_usage = other.type_usage.return_type + # found = True + # if other.from_id == entry.from_id and entry.from_id is not None: + # if other.type_usage is None and entry.type_usage is None: + # pass + # elif other.type_usage is None and entry.type_usage is not None: + # other.type_usage = entry.type_usage + # found = True + # elif other.type_usage is not None and entry.type_usage is None: + # entry.type_usage = other.type_usage + # found = True + # else: + # assert entry.type_usage == other.type_usage, f"{entry.from_id} {other.from_id} {entry.type_usage} == {other.type_usage}" diff --git a/boring/typing.py b/boring/typing.py new file mode 100644 index 0000000..2688789 --- /dev/null +++ b/boring/typing.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import List, Dict, Optional, Union + + +from boring import parse + + +Identified = Union[parse.LetStatement, parse.Function] +Environment = Dict[str, Identified] + + +class TypeChecker: + def with_module( + self, env: Environment, table: List[TypeComparison], module: parse.Module + ) -> bool: + pass + + def with_function( + self, env: Environment, table: List[TypeComparison], function: parse.Function + ) -> bool: + pass + + # Skip variable VariableDeclaration + + def with_block( + self, env: Environment, table: List[TypeComparison], block: parse.Block + ) -> bool: + pass + + def with_statement( + self, env: Environment, table: List[TypeComparison], statement: parse.Statement + ) -> bool: + pass + + def with_let_statement( + self, env: Environment, table: List[TypeComparison], let_statement: parse.LetStatement + ) -> bool: + pass + + def with_expression( + self, env: Environment, table: List[TypeComparison], expression: parse.Expression + ) -> bool: + pass + + def with_variable_usage( + self, env: Environment, table: List[TypeComparison], variable_usage: parse.VariableUsage + ) -> bool: + pass + + def with_operation( + self, env: Environment, table: List[TypeComparison], operation: parse.Operation + ) -> bool: + pass + + def with_function_call( + self, env: Environment, table: List[TypeComparison], function_call: parse.FunctionCall + ) -> bool: + pass + + + def with_literal_int( + self, env: Environment, table: List[TypeComparison], literal_int: parse.LiteralInt + ) -> bool: + pass diff --git a/docker-compose.yml b/docker-compose.yml index a929423..4d1cfc5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,9 +1,9 @@ version: "3" services: - boring: - build: . - volumes: - - .:/code/ + # boring: + # build: . + # volumes: + # - .:/code/ boring-python: build: context: . diff --git a/examples/math/main.bl b/examples/math/main.bl index 9427e28..fe6edc6 100644 --- a/examples/math/main.bl +++ b/examples/math/main.bl @@ -1,12 +1,12 @@ -fn add(a, b) { +fn add(a: u32, b: u32): u32 { let foo = 4; - a + b + a + b + foo } -fn subtract(a, b) { +fn subtract(a: u32, b: u32): u32 { a - b } -fn main() { +fn main(): u32 { add(4, subtract(5, 2)) } diff --git a/notes.txt b/notes.txt new file mode 100644 index 0000000..00d75d2 --- /dev/null +++ b/notes.txt @@ -0,0 +1,28 @@ + + + + +# On types +Type Usage != Type Definition + + +type List[T] struct { + +} + + +fn add[T: addable](a: T, b: T): T { + return a + b; +} + +type usages: +List[Int64] +fn(int, int): List[Int64] + + + +@dataclass +class TypeUsage: + result: Identifier # Result of useage - either is the type, or is the return value if it's a function + type_args: List[Type] # Generics + arguments: Optional[List[Type]] # Specified if it is a function, this is how you tell if it's a function diff --git a/src/types.rs b/src/types.rs index 5d6a2ba..7d8a947 100644 --- a/src/types.rs +++ b/src/types.rs @@ -62,14 +62,3 @@ pub enum SpecifiedType { Unknown, Type(Type), } - - -// Env table -// name => ast_id -// -// Type Table -// ast_id => type_id -// -// -// TypeDef Table -// type_id => type_def