added struct definition

This commit is contained in:
Andrew Segavac
2021-06-11 20:59:51 -06:00
parent f05888a817
commit acbaf5f729
6 changed files with 186 additions and 66 deletions

View File

@@ -2,7 +2,25 @@ import sys
from typing import List
from boring.parse import boring_parser, TreeToBoring, pretty_print
from boring.type_checking import TypeChecker, Context
from boring import typedefs
from boring import typedefs, parse
# Built-in primitive type declarations, keyed by type name. Each entry is its
# own PrimitiveTypeDeclaration carrying the same name as its key.
builtins = {
    type_name: parse.PrimitiveTypeDeclaration(type_name)
    for type_name in (
        "U8",
        "U16",
        "U32",
        "U64",
        "U128",
        "I8",
        "I16",
        "I32",
        "I64",
        "I128",
        "F32",
        "F64",
        "F128",
        "()",  # Unit
        "!",  # Never
    )
}
if __name__ == "__main__":
with open(sys.argv[1]) as f:
@@ -11,7 +29,7 @@ if __name__ == "__main__":
result = TreeToBoring().transform(tree)
# pretty_print(result)
type_checker = TypeChecker()
while type_checker.with_module(Context({}, typedefs.builtins, None), result):
while type_checker.with_module(Context(builtins, None), result):
print("loop")
# type_checker.with_module({}, result)
pretty_print(result)

View File

@@ -1,6 +1,6 @@
import sys
import enum
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict
from dataclasses import dataclass, field
from lark import Lark, Transformer
@@ -33,7 +33,7 @@ NEVER_TYPE = "!"
class FunctionTypeUsage:
arguments: List[
"TypeUsage"
] # Specified if it is a function, this is how you tell if it's a function
]
return_type: "TypeUsage"
@@ -98,7 +98,15 @@ class ReturnStatement:
@dataclass
class Expression:
expression: Union[LiteralInt, LiteralFloat, FunctionCall, "Block", ReturnStatement, VariableUsage, Operation]
expression: Union[
LiteralInt,
LiteralFloat,
FunctionCall,
"Block",
ReturnStatement,
VariableUsage,
Operation,
]
type: TypeUsage
@@ -140,9 +148,24 @@ class Function:
type: TypeUsage
@dataclass
class PrimitiveTypeDeclaration:
    """Declaration of a primitive type, described only by its name."""

    name: str
@dataclass
class StructTypeDeclaration:
    """Declaration of a struct type: its name and a field-name -> type-usage mapping."""

    name: str
    fields: Dict[str, TypeUsage]
# A type declaration is either a struct declaration or a primitive declaration.
TypeDeclaration = Union[StructTypeDeclaration, PrimitiveTypeDeclaration]
@dataclass
class Module:
    """A parsed module: its functions and its type declarations, kept in separate lists."""

    functions: List[Function]
    types: List[TypeDeclaration]
boring_grammar = r"""
@@ -212,7 +235,14 @@ boring_grammar = r"""
function : function_with_return
| function_without_return
module : (function)*
struct_definition_field : identifier ":" type_usage
struct_type_declaration : "type" identifier "struct" "{" (struct_definition_field ",")* "}"
type_declaration : struct_type_declaration
module : (function|type_declaration)*
%import common.CNAME
%import common.SIGNED_INT
@@ -260,7 +290,9 @@ class TreeToBoring(Transformer):
def return_statement(self, return_expression) -> ReturnStatement:
(return_expression,) = return_expression
return ReturnStatement(source=return_expression, type=DataTypeUsage(name=NEVER_TYPE))
return ReturnStatement(
source=return_expression, type=DataTypeUsage(name=NEVER_TYPE)
)
def function_call(self, call) -> FunctionCall:
return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage())
@@ -371,8 +403,28 @@ class TreeToBoring(Transformer):
(function,) = function
return function
def module(self, functions) -> Module:
return Module(functions=functions)
def struct_definition_field(self, struct_definition_field):
    """Return the two children of a struct field rule as a ``(name, type_usage)`` tuple.

    Unpacking fails with ``ValueError`` unless there are exactly two children.
    """
    field_name, field_type = struct_definition_field
    return field_name, field_type
def struct_type_declaration(self, struct_type_declaration) -> StructTypeDeclaration:
    """Build a StructTypeDeclaration from the children of a struct declaration rule.

    The first child is the struct's name; each remaining child is a
    ``(field_name, type_usage)`` pair.

    Raises:
        ValueError: if the same field name is declared more than once.
    """
    name = struct_type_declaration[0]
    fields = {}
    for field_name, type_usage in struct_type_declaration[1:]:
        # Building the dict directly would silently keep only the last of
        # several same-named fields, so reject duplicates explicitly.
        if field_name in fields:
            raise ValueError(f"duplicate field {field_name!r} in struct {name!r}")
        fields[field_name] = type_usage
    return StructTypeDeclaration(name=name, fields=fields)
def type_declaration(self, type_declaration):
    """Unwrap the single child of a type declaration rule and return it unchanged.

    Unpacking fails with ``ValueError`` unless there is exactly one child.
    """
    [declaration] = type_declaration
    return declaration
def module(self, module_items) -> Module:
    """Split the module's top-level items into functions and type declarations.

    Items that are ``Function`` instances go into ``functions``; every other
    item goes into ``types``. Relative order within each list is preserved.
    """
    functions = [item for item in module_items if isinstance(item, Function)]
    types = [item for item in module_items if not isinstance(item, Function)]
    return Module(functions=functions, types=types)
# Lark parser built from boring_grammar; parsing starts at the `module` rule.
boring_parser = Lark(boring_grammar, start="module", lexer="standard")

View File

@@ -1,32 +1,34 @@
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from typing import List, Dict, Optional, Union, Tuple
from boring import parse, typedefs
Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration]
Identified = Union[
parse.LetStatement, parse.Function, parse.VariableDeclaration, parse.TypeDeclaration
]
Environment = Dict[str, Identified]
TypeEnvironment = Dict[str, typedefs.TypeDef]
@dataclass
class Context:
environment: Environment
type_environment: TypeEnvironment
current_function: Optional[parse.Function]
def copy(self):
return Context(self.environment.copy(), self.type_environment.copy(), self.current_function)
return Context(self.environment.copy(), self.current_function)
def unify(ctx: Context, first, second) -> bool:
    """Give ``first`` and ``second`` the same type and report whether anything changed.

    The ``type`` attributes of both nodes are compared with ``type_compare``;
    the resulting type is written back to both nodes, and the "changed" flag
    from ``type_compare`` is returned.
    """
    unified, changed = type_compare(ctx, first.type, second.type)
    first.type = second.type = unified
    return changed
def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
def type_compare(ctx: Context, first: parse.TypeUsage, second: parse.TypeUsage) -> Tuple[parse.TypeUsage, bool]:
print(first, second)
if isinstance(first, parse.UnknownTypeUsage):
if not isinstance(second, parse.UnknownTypeUsage):
@@ -41,8 +43,18 @@ def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
second, parse.DataTypeUsage
):
assert second == first
assert first.name in ctx.type_environment
assert second.name in ctx.type_environment
assert first.name in ctx.environment # TODO: validate that it is a type
assert isinstance(
ctx.environment[first.name], parse.StructTypeDeclaration
) or isinstance(
ctx.environment[first.name], parse.PrimitiveTypeDeclaration
)
assert second.name in ctx.environment
assert isinstance(
ctx.environment[second.name], parse.StructTypeDeclaration
) or isinstance(
ctx.environment[second.name], parse.PrimitiveTypeDeclaration
)
return first, False
elif isinstance(first, parse.FunctionTypeUsage) and isinstance(
second, parse.FunctionTypeUsage
@@ -64,8 +76,22 @@ def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
assert False, f"mismatched types {first}, {second}"
def assert_exists(ctx: Context, type: parse.TypeUsage):
    """Assert that every data type named by ``type`` is a declared type in ``ctx``.

    For a ``DataTypeUsage`` the name must be present in ``ctx.environment``
    and must refer to a struct or primitive type declaration. For a
    ``FunctionTypeUsage`` the return type and every argument type are checked
    recursively. Any other kind of type usage is accepted without checks.

    Raises:
        AssertionError: if a named type is missing or is not a type declaration.
    """
    if isinstance(type, parse.DataTypeUsage):
        assert type.name in ctx.environment, f"unknown type {type.name}"
        # The environment also holds functions and variables, so presence of
        # the name alone does not prove it refers to a type.
        assert isinstance(
            ctx.environment[type.name],
            (parse.StructTypeDeclaration, parse.PrimitiveTypeDeclaration),
        ), f"{type.name} is not a type"
    elif isinstance(type, parse.FunctionTypeUsage):
        assert_exists(ctx, type.return_type)
        for argument in type.arguments:
            assert_exists(ctx, argument)
class TypeChecker:
def with_module(self, ctx: Context, module: parse.Module) -> bool:
for type_declaration in module.types:
ctx.environment[type_declaration.name] = type_declaration
for type_declaration in module.types:
if isinstance(type_declaration, parse.StructTypeDeclaration):
for name, field in type_declaration.fields.items():
assert_exists(ctx, field)
for function in module.functions:
ctx.environment[function.name] = function
changed = False
@@ -83,8 +109,13 @@ class TypeChecker:
changed = self.with_block(function_ctx, function.block)
if not (isinstance(function.block.type, parse.DataTypeUsage) and function.block.type.name == parse.NEVER_TYPE):
type, compare_changed = type_compare(function_ctx, function.block.type, function.type.return_type)
if not (
isinstance(function.block.type, parse.DataTypeUsage)
and function.block.type.name == parse.NEVER_TYPE
):
type, compare_changed = type_compare(
function_ctx, function.block.type, function.type.return_type
)
function.block.type = type
function.type.return_type = type
if compare_changed is True:
@@ -105,23 +136,15 @@ class TypeChecker:
if isinstance(final, parse.LetStatement):
if isinstance(block.type, parse.UnknownTypeUsage):
changed = True
block.type = parse.DataTypeUsage(
name=parse.UNIT_TYPE
)
block.type = parse.DataTypeUsage(name=parse.UNIT_TYPE)
else:
assert block.type == parse.DataTypeUsage(
name=parse.UNIT_TYPE
)
assert block.type == parse.DataTypeUsage(name=parse.UNIT_TYPE)
elif isinstance(final, parse.ReturnStatement):
if isinstance(block.type, parse.UnknownTypeUsage):
changed = True
block.type = parse.DataTypeUsage(
name=parse.NEVER_TYPE
)
block.type = parse.DataTypeUsage(name=parse.NEVER_TYPE)
else:
assert block.type == parse.DataTypeUsage(
name=parse.NEVER_TYPE
)
assert block.type == parse.DataTypeUsage(name=parse.NEVER_TYPE)
elif isinstance(final, parse.Expression):
if unify(block_ctx, final, block):
changed = True
@@ -156,7 +179,11 @@ class TypeChecker:
changed = False
if self.with_expression(ctx, assignment_statement.expression):
changed = True
if unify(ctx, assignment_statement, ctx.environment[assignment_statement.variable_name]):
if unify(
ctx,
assignment_statement,
ctx.environment[assignment_statement.variable_name],
):
changed = True
if unify(ctx, assignment_statement, assignment_statement.expression):
changed = True
@@ -168,8 +195,15 @@ class TypeChecker:
changed = self.with_expression(ctx, return_statement.source)
# Doesn't match on an unreachable return
if not (isinstance(return_statement.source.type, parse.DataTypeUsage) and return_statement.source.type.name == parse.NEVER_TYPE):
type, compare_changed = type_compare(ctx, return_statement.source.type, ctx.current_function.type.return_type)
if not (
isinstance(return_statement.source.type, parse.DataTypeUsage)
and return_statement.source.type.name == parse.NEVER_TYPE
):
assert isinstance(ctx.current_function, parse.Function)
assert isinstance(ctx.current_function.type, parse.FunctionTypeUsage)
type, compare_changed = type_compare(
ctx, return_statement.source.type, ctx.current_function.type.return_type
)
return_statement.source.type = type
ctx.current_function.type.return_type = type
if compare_changed is True:
@@ -245,6 +279,7 @@ class TypeChecker:
if self.with_expression(ctx, argument):
changed = True
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
return_type, return_changed = type_compare(
ctx, function_call.type, function_call.source.type.return_type
)
@@ -265,14 +300,18 @@ class TypeChecker:
changed = True
return changed
def with_literal_float(self, ctx: Context, literal_float: parse.LiteralFloat) -> bool:
def with_literal_float(
    self, ctx: Context, literal_float: parse.LiteralFloat
) -> bool:
    """Validate the type of a float literal; never reports a change.

    A literal whose type is still unknown is left alone. Otherwise its type
    must be a ``DataTypeUsage`` named F32, F64 or F128.
    """
    if isinstance(literal_float.type, parse.UnknownTypeUsage):
        return False
    assert isinstance(literal_float.type, parse.DataTypeUsage)
    assert literal_float.type.name in ("F32", "F64", "F128"), f"{literal_float.type}"
    return False
def with_literal_int(self, ctx: Context, literal_int: parse.LiteralInt) -> bool:
    """Validate the type of an integer literal; never reports a change.

    A literal whose type is still unknown is left alone. Otherwise its type
    must be a ``DataTypeUsage`` naming one of the signed or unsigned integer
    types listed below.
    """
    if isinstance(literal_int.type, parse.UnknownTypeUsage):
        return False
    allowed_names = (
        "I8",
        "I16",
        "I32",
        "I64",
        "I128",
        "U8",
        "U16",
        "U32",
        "U64",
        "U128",
    )
    assert isinstance(literal_int.type, parse.DataTypeUsage)
    assert literal_int.type.name in allowed_names, f"{literal_int.type}"
    return False

View File

@@ -4,22 +4,22 @@ from typing import List, Dict, Optional, Union
class IntBitness(enum.Enum):
X8 = 'X8'
X16 = 'X16'
X32 = 'X32'
X64 = 'X64'
X128 = 'X128'
X8 = "X8"
X16 = "X16"
X32 = "X32"
X64 = "X64"
X128 = "X128"
class Signedness(enum.Enum):
Signed = 'Signed'
Unsigned = 'Unsigned'
Signed = "Signed"
Unsigned = "Unsigned"
class FloatBitness(enum.Enum):
X32 = 'X32'
X64 = 'X64'
X128 = 'X128'
X32 = "X32"
X64 = "X64"
X128 = "X128"
@dataclass
@@ -39,6 +39,11 @@ class FunctionTypeDef:
return_type: "TypeDef"
@dataclass
class StructTypeDef:
fields: Dict[str, "TypeDef"]
@dataclass
class UnitTypeDef:
pass
@@ -49,26 +54,25 @@ class NeverTypeDef:
pass
TypeDef = Union[IntTypeDef, FloatTypeDef, FunctionTypeDef, UnitTypeDef, NeverTypeDef]
TypeDef = Union[
IntTypeDef, FloatTypeDef, FunctionTypeDef, StructTypeDef, UnitTypeDef, NeverTypeDef
]
builtins: Dict[str, TypeDef] = {
'U8': IntTypeDef(Signedness.Unsigned, IntBitness.X8),
'U16': IntTypeDef(Signedness.Unsigned, IntBitness.X16),
'U32': IntTypeDef(Signedness.Unsigned, IntBitness.X32),
'U64': IntTypeDef(Signedness.Unsigned, IntBitness.X64),
'U128': IntTypeDef(Signedness.Unsigned, IntBitness.X128),
'I8': IntTypeDef(Signedness.Signed, IntBitness.X8),
'I16': IntTypeDef(Signedness.Signed, IntBitness.X16),
'I32': IntTypeDef(Signedness.Signed, IntBitness.X32),
'I64': IntTypeDef(Signedness.Signed, IntBitness.X64),
'I128': IntTypeDef(Signedness.Signed, IntBitness.X128),
'F32': FloatTypeDef(FloatBitness.X32),
'F64': FloatTypeDef(FloatBitness.X64),
'F128': FloatTypeDef(FloatBitness.X128),
'()': UnitTypeDef(),
'!': NeverTypeDef(),
"U8": IntTypeDef(Signedness.Unsigned, IntBitness.X8),
"U16": IntTypeDef(Signedness.Unsigned, IntBitness.X16),
"U32": IntTypeDef(Signedness.Unsigned, IntBitness.X32),
"U64": IntTypeDef(Signedness.Unsigned, IntBitness.X64),
"U128": IntTypeDef(Signedness.Unsigned, IntBitness.X128),
"I8": IntTypeDef(Signedness.Signed, IntBitness.X8),
"I16": IntTypeDef(Signedness.Signed, IntBitness.X16),
"I32": IntTypeDef(Signedness.Signed, IntBitness.X32),
"I64": IntTypeDef(Signedness.Signed, IntBitness.X64),
"I128": IntTypeDef(Signedness.Signed, IntBitness.X128),
"F32": FloatTypeDef(FloatBitness.X32),
"F64": FloatTypeDef(FloatBitness.X64),
"F128": FloatTypeDef(FloatBitness.X128),
"()": UnitTypeDef(),
"!": NeverTypeDef(),
}

View File

@@ -31,3 +31,8 @@ fn unit_function() {
fn main(): I32 {
add(4, subtract(5, 2))
}
type User struct {
id: U64,
}

View File

@@ -38,14 +38,16 @@ TODO:
* ~Return keyword~
* ~Normal assignment~
* Structs
* Define
* ~Define~
* Literal
* Getter
* Setter
* Generics
* Enums
* Methods
* Traits
* Arrays
* Strings
* Traits
* Lambdas
* Async
* Imports