added struct definition

This commit is contained in:
Andrew Segavac
2021-06-11 20:59:51 -06:00
parent f05888a817
commit acbaf5f729
6 changed files with 186 additions and 66 deletions

View File

@@ -2,7 +2,25 @@ import sys
from typing import List from typing import List
from boring.parse import boring_parser, TreeToBoring, pretty_print from boring.parse import boring_parser, TreeToBoring, pretty_print
from boring.type_checking import TypeChecker, Context from boring.type_checking import TypeChecker, Context
from boring import typedefs from boring import typedefs, parse
# Initial environment handed to the type checker: every primitive type name
# mapped to its declaration node.
builtins = {
    name: parse.PrimitiveTypeDeclaration(name)
    for name in (
        # Unsigned integers
        "U8",
        "U16",
        "U32",
        "U64",
        "U128",
        # Signed integers
        "I8",
        "I16",
        "I32",
        "I64",
        "I128",
        # Floats
        "F32",
        "F64",
        "F128",
        "()",  # Unit
        "!",  # Never
    )
}
if __name__ == "__main__": if __name__ == "__main__":
with open(sys.argv[1]) as f: with open(sys.argv[1]) as f:
@@ -11,7 +29,7 @@ if __name__ == "__main__":
result = TreeToBoring().transform(tree) result = TreeToBoring().transform(tree)
# pretty_print(result) # pretty_print(result)
type_checker = TypeChecker() type_checker = TypeChecker()
while type_checker.with_module(Context({}, typedefs.builtins, None), result): while type_checker.with_module(Context(builtins, None), result):
print("loop") print("loop")
# type_checker.with_module({}, result) # type_checker.with_module({}, result)
pretty_print(result) pretty_print(result)

View File

@@ -1,6 +1,6 @@
import sys import sys
import enum import enum
from typing import Union, List, Optional from typing import Union, List, Optional, Dict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lark import Lark, Transformer from lark import Lark, Transformer
@@ -33,7 +33,7 @@ NEVER_TYPE = "!"
class FunctionTypeUsage: class FunctionTypeUsage:
arguments: List[ arguments: List[
"TypeUsage" "TypeUsage"
] # Specified if it is a function, this is how you tell if it's a function ]
return_type: "TypeUsage" return_type: "TypeUsage"
@@ -98,7 +98,15 @@ class ReturnStatement:
@dataclass @dataclass
class Expression: class Expression:
expression: Union[LiteralInt, LiteralFloat, FunctionCall, "Block", ReturnStatement, VariableUsage, Operation] expression: Union[
LiteralInt,
LiteralFloat,
FunctionCall,
"Block",
ReturnStatement,
VariableUsage,
Operation,
]
type: TypeUsage type: TypeUsage
@@ -140,9 +148,24 @@ class Function:
type: TypeUsage type: TypeUsage
@dataclass
class PrimitiveTypeDeclaration:
    """Declaration of a primitive type, identified only by its name."""

    # Name of the type.
    name: str
@dataclass
class StructTypeDeclaration:
    """Declaration of a named struct type and its fields."""

    # Name of the struct type.
    name: str
    # Field name -> type usage given for that field.
    fields: Dict[str, TypeUsage]
# Any type declaration: a struct or a primitive.
TypeDeclaration = Union[StructTypeDeclaration, PrimitiveTypeDeclaration]
@dataclass @dataclass
class Module: class Module:
functions: List[Function] functions: List[Function]
types: List[TypeDeclaration]
boring_grammar = r""" boring_grammar = r"""
@@ -212,7 +235,14 @@ boring_grammar = r"""
function : function_with_return function : function_with_return
| function_without_return | function_without_return
module : (function)*
struct_definition_field : identifier ":" type_usage
struct_type_declaration : "type" identifier "struct" "{" (struct_definition_field ",")* "}"
type_declaration : struct_type_declaration
module : (function|type_declaration)*
%import common.CNAME %import common.CNAME
%import common.SIGNED_INT %import common.SIGNED_INT
@@ -260,7 +290,9 @@ class TreeToBoring(Transformer):
def return_statement(self, return_expression) -> ReturnStatement: def return_statement(self, return_expression) -> ReturnStatement:
(return_expression,) = return_expression (return_expression,) = return_expression
return ReturnStatement(source=return_expression, type=DataTypeUsage(name=NEVER_TYPE)) return ReturnStatement(
source=return_expression, type=DataTypeUsage(name=NEVER_TYPE)
)
def function_call(self, call) -> FunctionCall: def function_call(self, call) -> FunctionCall:
return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage()) return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage())
@@ -371,8 +403,28 @@ class TreeToBoring(Transformer):
(function,) = function (function,) = function
return function return function
def module(self, functions) -> Module: def struct_definition_field(self, struct_definition_field):
return Module(functions=functions) (field, type_usage) = struct_definition_field
return (field, type_usage)
def struct_type_declaration(self, struct_type_declaration) -> StructTypeDeclaration:
    """Build a StructTypeDeclaration from the children of a struct node.

    The first child is the struct name; every remaining child is a
    (field name, type usage) pair.

    Raises:
        ValueError: if the same field name appears more than once.
    """
    name = struct_type_declaration[0]
    fields = {}
    for field_name, type_usage in struct_type_declaration[1:]:
        # Building the dict directly would silently keep only the last
        # definition of a repeated field, so reject duplicates explicitly.
        if field_name in fields:
            raise ValueError(f"duplicate field {field_name} in struct {name}")
        fields[field_name] = type_usage
    return StructTypeDeclaration(name=name, fields=fields)
def type_declaration(self, type_declaration):
    """Unwrap and return the single child of a ``type_declaration`` node."""
    [declaration] = type_declaration
    return declaration
def module(self, module_items) -> Module:
    """Split the top-level items into functions and type declarations.

    Every item that is not a ``Function`` is treated as a type declaration.
    The relative order of items is kept within each group.
    """
    functions = [item for item in module_items if isinstance(item, Function)]
    types = [item for item in module_items if not isinstance(item, Function)]
    return Module(functions=functions, types=types)
boring_parser = Lark(boring_grammar, start="module", lexer="standard") boring_parser = Lark(boring_grammar, start="module", lexer="standard")

View File

@@ -1,32 +1,34 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Dict, Optional, Union from typing import List, Dict, Optional, Union, Tuple
from boring import parse, typedefs from boring import parse, typedefs
Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration] Identified = Union[
parse.LetStatement, parse.Function, parse.VariableDeclaration, parse.TypeDeclaration
]
Environment = Dict[str, Identified] Environment = Dict[str, Identified]
TypeEnvironment = Dict[str, typedefs.TypeDef]
@dataclass @dataclass
class Context: class Context:
environment: Environment environment: Environment
type_environment: TypeEnvironment
current_function: Optional[parse.Function] current_function: Optional[parse.Function]
def copy(self): def copy(self):
return Context(self.environment.copy(), self.type_environment.copy(), self.current_function) return Context(self.environment.copy(), self.current_function)
def unify(ctx: Context, first, second) -> bool: def unify(ctx: Context, first, second) -> bool:
changed: bool
result, changed = type_compare(ctx, first.type, second.type) result, changed = type_compare(ctx, first.type, second.type)
first.type = result first.type = result
second.type = result second.type = result
return changed return changed
def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool): def type_compare(ctx: Context, first: parse.TypeUsage, second: parse.TypeUsage) -> Tuple[parse.TypeUsage, bool]:
print(first, second) print(first, second)
if isinstance(first, parse.UnknownTypeUsage): if isinstance(first, parse.UnknownTypeUsage):
if not isinstance(second, parse.UnknownTypeUsage): if not isinstance(second, parse.UnknownTypeUsage):
@@ -41,8 +43,18 @@ def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
second, parse.DataTypeUsage second, parse.DataTypeUsage
): ):
assert second == first assert second == first
assert first.name in ctx.type_environment assert first.name in ctx.environment # TODO: validate that it is a type
assert second.name in ctx.type_environment assert isinstance(
ctx.environment[first.name], parse.StructTypeDeclaration
) or isinstance(
ctx.environment[first.name], parse.PrimitiveTypeDeclaration
)
assert second.name in ctx.environment
assert isinstance(
ctx.environment[second.name], parse.StructTypeDeclaration
) or isinstance(
ctx.environment[second.name], parse.PrimitiveTypeDeclaration
)
return first, False return first, False
elif isinstance(first, parse.FunctionTypeUsage) and isinstance( elif isinstance(first, parse.FunctionTypeUsage) and isinstance(
second, parse.FunctionTypeUsage second, parse.FunctionTypeUsage
@@ -64,8 +76,22 @@ def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
assert False, f"mismatched types {first}, {second}" assert False, f"mismatched types {first}, {second}"
def assert_exists(ctx: Context, type: parse.TypeUsage):
    """Assert that every type named by ``type`` is declared in ``ctx``.

    A ``parse.DataTypeUsage`` must name an entry of ``ctx.environment`` that
    is a struct or primitive type declaration. A ``parse.FunctionTypeUsage``
    is checked recursively through its return type and each argument type.
    Any other kind of usage is accepted without checks.

    Raises:
        AssertionError: if a named type is missing from the environment, or
            the name is bound to something that is not a type declaration.
    """
    if isinstance(type, parse.DataTypeUsage):
        assert type.name in ctx.environment, f"unknown type {type.name}"
        # The environment also holds functions and variables, so finding the
        # name is not enough: it has to be bound to a type declaration.
        assert isinstance(
            ctx.environment[type.name],
            (parse.StructTypeDeclaration, parse.PrimitiveTypeDeclaration),
        ), f"{type.name} is not a type"
    elif isinstance(type, parse.FunctionTypeUsage):
        assert_exists(ctx, type.return_type)
        for argument in type.arguments:
            assert_exists(ctx, argument)
class TypeChecker: class TypeChecker:
def with_module(self, ctx: Context, module: parse.Module) -> bool: def with_module(self, ctx: Context, module: parse.Module) -> bool:
for type_declaration in module.types:
ctx.environment[type_declaration.name] = type_declaration
for type_declaration in module.types:
if isinstance(type_declaration, parse.StructTypeDeclaration):
for name, field in type_declaration.fields.items():
assert_exists(ctx, field)
for function in module.functions: for function in module.functions:
ctx.environment[function.name] = function ctx.environment[function.name] = function
changed = False changed = False
@@ -83,8 +109,13 @@ class TypeChecker:
changed = self.with_block(function_ctx, function.block) changed = self.with_block(function_ctx, function.block)
if not (isinstance(function.block.type, parse.DataTypeUsage) and function.block.type.name == parse.NEVER_TYPE): if not (
type, compare_changed = type_compare(function_ctx, function.block.type, function.type.return_type) isinstance(function.block.type, parse.DataTypeUsage)
and function.block.type.name == parse.NEVER_TYPE
):
type, compare_changed = type_compare(
function_ctx, function.block.type, function.type.return_type
)
function.block.type = type function.block.type = type
function.type.return_type = type function.type.return_type = type
if compare_changed is True: if compare_changed is True:
@@ -105,23 +136,15 @@ class TypeChecker:
if isinstance(final, parse.LetStatement): if isinstance(final, parse.LetStatement):
if isinstance(block.type, parse.UnknownTypeUsage): if isinstance(block.type, parse.UnknownTypeUsage):
changed = True changed = True
block.type = parse.DataTypeUsage( block.type = parse.DataTypeUsage(name=parse.UNIT_TYPE)
name=parse.UNIT_TYPE
)
else: else:
assert block.type == parse.DataTypeUsage( assert block.type == parse.DataTypeUsage(name=parse.UNIT_TYPE)
name=parse.UNIT_TYPE
)
elif isinstance(final, parse.ReturnStatement): elif isinstance(final, parse.ReturnStatement):
if isinstance(block.type, parse.UnknownTypeUsage): if isinstance(block.type, parse.UnknownTypeUsage):
changed = True changed = True
block.type = parse.DataTypeUsage( block.type = parse.DataTypeUsage(name=parse.NEVER_TYPE)
name=parse.NEVER_TYPE
)
else: else:
assert block.type == parse.DataTypeUsage( assert block.type == parse.DataTypeUsage(name=parse.NEVER_TYPE)
name=parse.NEVER_TYPE
)
elif isinstance(final, parse.Expression): elif isinstance(final, parse.Expression):
if unify(block_ctx, final, block): if unify(block_ctx, final, block):
changed = True changed = True
@@ -156,7 +179,11 @@ class TypeChecker:
changed = False changed = False
if self.with_expression(ctx, assignment_statement.expression): if self.with_expression(ctx, assignment_statement.expression):
changed = True changed = True
if unify(ctx, assignment_statement, ctx.environment[assignment_statement.variable_name]): if unify(
ctx,
assignment_statement,
ctx.environment[assignment_statement.variable_name],
):
changed = True changed = True
if unify(ctx, assignment_statement, assignment_statement.expression): if unify(ctx, assignment_statement, assignment_statement.expression):
changed = True changed = True
@@ -168,8 +195,15 @@ class TypeChecker:
changed = self.with_expression(ctx, return_statement.source) changed = self.with_expression(ctx, return_statement.source)
# Doesn't match on an unreachable return # Doesn't match on an unreachable return
if not (isinstance(return_statement.source.type, parse.DataTypeUsage) and return_statement.source.type.name == parse.NEVER_TYPE): if not (
type, compare_changed = type_compare(ctx, return_statement.source.type, ctx.current_function.type.return_type) isinstance(return_statement.source.type, parse.DataTypeUsage)
and return_statement.source.type.name == parse.NEVER_TYPE
):
assert isinstance(ctx.current_function, parse.Function)
assert isinstance(ctx.current_function.type, parse.FunctionTypeUsage)
type, compare_changed = type_compare(
ctx, return_statement.source.type, ctx.current_function.type.return_type
)
return_statement.source.type = type return_statement.source.type = type
ctx.current_function.type.return_type = type ctx.current_function.type.return_type = type
if compare_changed is True: if compare_changed is True:
@@ -245,6 +279,7 @@ class TypeChecker:
if self.with_expression(ctx, argument): if self.with_expression(ctx, argument):
changed = True changed = True
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
return_type, return_changed = type_compare( return_type, return_changed = type_compare(
ctx, function_call.type, function_call.source.type.return_type ctx, function_call.type, function_call.source.type.return_type
) )
@@ -265,14 +300,18 @@ class TypeChecker:
changed = True changed = True
return changed return changed
def with_literal_float(self, ctx: Context, literal_float: parse.LiteralFloat) -> bool: def with_literal_float(
self, ctx: Context, literal_float: parse.LiteralFloat
) -> bool:
floats = ["F32", "F64", "F128"] floats = ["F32", "F64", "F128"]
if not isinstance(literal_float.type, parse.UnknownTypeUsage): if not isinstance(literal_float.type, parse.UnknownTypeUsage):
assert isinstance(literal_float.type, parse.DataTypeUsage)
assert literal_float.type.name in floats, f"{literal_float.type}" assert literal_float.type.name in floats, f"{literal_float.type}"
return False return False
def with_literal_int(self, ctx: Context, literal_int: parse.LiteralInt) -> bool: def with_literal_int(self, ctx: Context, literal_int: parse.LiteralInt) -> bool:
ints = ["I8", "I16", "I32", "I64", "I128", "U8", "U16", "U32", "U64", "U128"] ints = ["I8", "I16", "I32", "I64", "I128", "U8", "U16", "U32", "U64", "U128"]
if not isinstance(literal_int.type, parse.UnknownTypeUsage): if not isinstance(literal_int.type, parse.UnknownTypeUsage):
assert isinstance(literal_int.type, parse.DataTypeUsage)
assert literal_int.type.name in ints, f"{literal_int.type}" assert literal_int.type.name in ints, f"{literal_int.type}"
return False return False

View File

@@ -4,22 +4,22 @@ from typing import List, Dict, Optional, Union
class IntBitness(enum.Enum): class IntBitness(enum.Enum):
X8 = 'X8' X8 = "X8"
X16 = 'X16' X16 = "X16"
X32 = 'X32' X32 = "X32"
X64 = 'X64' X64 = "X64"
X128 = 'X128' X128 = "X128"
class Signedness(enum.Enum): class Signedness(enum.Enum):
Signed = 'Signed' Signed = "Signed"
Unsigned = 'Unsigned' Unsigned = "Unsigned"
class FloatBitness(enum.Enum): class FloatBitness(enum.Enum):
X32 = 'X32' X32 = "X32"
X64 = 'X64' X64 = "X64"
X128 = 'X128' X128 = "X128"
@dataclass @dataclass
@@ -39,6 +39,11 @@ class FunctionTypeDef:
return_type: "TypeDef" return_type: "TypeDef"
@dataclass
class StructTypeDef:
    """Type definition of a struct, described by its fields."""

    # Field name -> type definition of that field.
    fields: Dict[str, "TypeDef"]
@dataclass @dataclass
class UnitTypeDef: class UnitTypeDef:
pass pass
@@ -49,26 +54,25 @@ class NeverTypeDef:
pass pass
TypeDef = Union[IntTypeDef, FloatTypeDef, FunctionTypeDef, UnitTypeDef, NeverTypeDef] TypeDef = Union[
IntTypeDef, FloatTypeDef, FunctionTypeDef, StructTypeDef, UnitTypeDef, NeverTypeDef
]
builtins: Dict[str, TypeDef] = { builtins: Dict[str, TypeDef] = {
'U8': IntTypeDef(Signedness.Unsigned, IntBitness.X8), "U8": IntTypeDef(Signedness.Unsigned, IntBitness.X8),
'U16': IntTypeDef(Signedness.Unsigned, IntBitness.X16), "U16": IntTypeDef(Signedness.Unsigned, IntBitness.X16),
'U32': IntTypeDef(Signedness.Unsigned, IntBitness.X32), "U32": IntTypeDef(Signedness.Unsigned, IntBitness.X32),
'U64': IntTypeDef(Signedness.Unsigned, IntBitness.X64), "U64": IntTypeDef(Signedness.Unsigned, IntBitness.X64),
'U128': IntTypeDef(Signedness.Unsigned, IntBitness.X128), "U128": IntTypeDef(Signedness.Unsigned, IntBitness.X128),
"I8": IntTypeDef(Signedness.Signed, IntBitness.X8),
'I8': IntTypeDef(Signedness.Signed, IntBitness.X8), "I16": IntTypeDef(Signedness.Signed, IntBitness.X16),
'I16': IntTypeDef(Signedness.Signed, IntBitness.X16), "I32": IntTypeDef(Signedness.Signed, IntBitness.X32),
'I32': IntTypeDef(Signedness.Signed, IntBitness.X32), "I64": IntTypeDef(Signedness.Signed, IntBitness.X64),
'I64': IntTypeDef(Signedness.Signed, IntBitness.X64), "I128": IntTypeDef(Signedness.Signed, IntBitness.X128),
'I128': IntTypeDef(Signedness.Signed, IntBitness.X128), "F32": FloatTypeDef(FloatBitness.X32),
"F64": FloatTypeDef(FloatBitness.X64),
'F32': FloatTypeDef(FloatBitness.X32), "F128": FloatTypeDef(FloatBitness.X128),
'F64': FloatTypeDef(FloatBitness.X64), "()": UnitTypeDef(),
'F128': FloatTypeDef(FloatBitness.X128), "!": NeverTypeDef(),
'()': UnitTypeDef(),
'!': NeverTypeDef(),
} }

View File

@@ -31,3 +31,8 @@ fn unit_function() {
fn main(): I32 { fn main(): I32 {
add(4, subtract(5, 2)) add(4, subtract(5, 2))
} }
type User struct {
id: U64,
}

View File

@@ -38,14 +38,16 @@ TODO:
* ~Return keyword~ * ~Return keyword~
* ~Normal assignment~ * ~Normal assignment~
* Structs * Structs
* Define * ~Define~
* Literal * Literal
* Getter * Getter
* Setter * Setter
* Generics * Generics
* Enums * Enums
* Methods * Methods
* Traits
* Arrays * Arrays
* Strings * Strings
* Traits * Lambdas
* Async * Async
* Imports