added struct definition
This commit is contained in:
@@ -2,7 +2,25 @@ import sys
|
||||
from typing import List
|
||||
from boring.parse import boring_parser, TreeToBoring, pretty_print
|
||||
from boring.type_checking import TypeChecker, Context
|
||||
from boring import typedefs
|
||||
from boring import typedefs, parse
|
||||
|
||||
builtins = {
|
||||
"U8": parse.PrimitiveTypeDeclaration("U8"),
|
||||
"U16": parse.PrimitiveTypeDeclaration("U16"),
|
||||
"U32": parse.PrimitiveTypeDeclaration("U32"),
|
||||
"U64": parse.PrimitiveTypeDeclaration("U64"),
|
||||
"U128": parse.PrimitiveTypeDeclaration("U128"),
|
||||
"I8": parse.PrimitiveTypeDeclaration("I8"),
|
||||
"I16": parse.PrimitiveTypeDeclaration("I16"),
|
||||
"I32": parse.PrimitiveTypeDeclaration("I32"),
|
||||
"I64": parse.PrimitiveTypeDeclaration("I64"),
|
||||
"I128": parse.PrimitiveTypeDeclaration("I128"),
|
||||
"F32": parse.PrimitiveTypeDeclaration("F32"),
|
||||
"F64": parse.PrimitiveTypeDeclaration("F64"),
|
||||
"F128": parse.PrimitiveTypeDeclaration("F128"),
|
||||
"()": parse.PrimitiveTypeDeclaration("()"), # Unit
|
||||
"!": parse.PrimitiveTypeDeclaration("!"), # Never
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open(sys.argv[1]) as f:
|
||||
@@ -11,7 +29,7 @@ if __name__ == "__main__":
|
||||
result = TreeToBoring().transform(tree)
|
||||
# pretty_print(result)
|
||||
type_checker = TypeChecker()
|
||||
while type_checker.with_module(Context({}, typedefs.builtins, None), result):
|
||||
while type_checker.with_module(Context(builtins, None), result):
|
||||
print("loop")
|
||||
# type_checker.with_module({}, result)
|
||||
pretty_print(result)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import sys
|
||||
import enum
|
||||
from typing import Union, List, Optional
|
||||
from typing import Union, List, Optional, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from lark import Lark, Transformer
|
||||
|
||||
@@ -33,7 +33,7 @@ NEVER_TYPE = "!"
|
||||
class FunctionTypeUsage:
|
||||
arguments: List[
|
||||
"TypeUsage"
|
||||
] # Specified if it is a function, this is how you tell if it's a function
|
||||
]
|
||||
return_type: "TypeUsage"
|
||||
|
||||
|
||||
@@ -98,7 +98,15 @@ class ReturnStatement:
|
||||
|
||||
@dataclass
|
||||
class Expression:
|
||||
expression: Union[LiteralInt, LiteralFloat, FunctionCall, "Block", ReturnStatement, VariableUsage, Operation]
|
||||
expression: Union[
|
||||
LiteralInt,
|
||||
LiteralFloat,
|
||||
FunctionCall,
|
||||
"Block",
|
||||
ReturnStatement,
|
||||
VariableUsage,
|
||||
Operation,
|
||||
]
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@@ -140,9 +148,24 @@ class Function:
|
||||
type: TypeUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrimitiveTypeDeclaration:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructTypeDeclaration:
|
||||
name: str
|
||||
fields: Dict[str, TypeUsage]
|
||||
|
||||
|
||||
TypeDeclaration = Union[StructTypeDeclaration, PrimitiveTypeDeclaration]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module:
|
||||
functions: List[Function]
|
||||
types: List[TypeDeclaration]
|
||||
|
||||
|
||||
boring_grammar = r"""
|
||||
@@ -212,7 +235,14 @@ boring_grammar = r"""
|
||||
function : function_with_return
|
||||
| function_without_return
|
||||
|
||||
module : (function)*
|
||||
|
||||
struct_definition_field : identifier ":" type_usage
|
||||
|
||||
struct_type_declaration : "type" identifier "struct" "{" (struct_definition_field ",")* "}"
|
||||
|
||||
type_declaration : struct_type_declaration
|
||||
|
||||
module : (function|type_declaration)*
|
||||
|
||||
%import common.CNAME
|
||||
%import common.SIGNED_INT
|
||||
@@ -260,7 +290,9 @@ class TreeToBoring(Transformer):
|
||||
|
||||
def return_statement(self, return_expression) -> ReturnStatement:
|
||||
(return_expression,) = return_expression
|
||||
return ReturnStatement(source=return_expression, type=DataTypeUsage(name=NEVER_TYPE))
|
||||
return ReturnStatement(
|
||||
source=return_expression, type=DataTypeUsage(name=NEVER_TYPE)
|
||||
)
|
||||
|
||||
def function_call(self, call) -> FunctionCall:
|
||||
return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage())
|
||||
@@ -371,8 +403,28 @@ class TreeToBoring(Transformer):
|
||||
(function,) = function
|
||||
return function
|
||||
|
||||
def module(self, functions) -> Module:
|
||||
return Module(functions=functions)
|
||||
def struct_definition_field(self, struct_definition_field):
|
||||
(field, type_usage) = struct_definition_field
|
||||
return (field, type_usage)
|
||||
|
||||
def struct_type_declaration(self, struct_type_declaration) -> StructTypeDeclaration:
|
||||
name = struct_type_declaration[0]
|
||||
fields = {key: value for (key, value) in struct_type_declaration[1:]}
|
||||
return StructTypeDeclaration(name=name, fields=fields)
|
||||
|
||||
def type_declaration(self, type_declaration):
|
||||
(type_declaration,) = type_declaration
|
||||
return type_declaration
|
||||
|
||||
def module(self, module_items) -> Module:
|
||||
functions = []
|
||||
types = []
|
||||
for item in module_items:
|
||||
if isinstance(item, Function):
|
||||
functions.append(item)
|
||||
else:
|
||||
types.append(item)
|
||||
return Module(functions=functions, types=types)
|
||||
|
||||
|
||||
boring_parser = Lark(boring_grammar, start="module", lexer="standard")
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional, Union
|
||||
from typing import List, Dict, Optional, Union, Tuple
|
||||
|
||||
|
||||
from boring import parse, typedefs
|
||||
|
||||
|
||||
Identified = Union[parse.LetStatement, parse.Function, parse.VariableDeclaration]
|
||||
Identified = Union[
|
||||
parse.LetStatement, parse.Function, parse.VariableDeclaration, parse.TypeDeclaration
|
||||
]
|
||||
Environment = Dict[str, Identified]
|
||||
TypeEnvironment = Dict[str, typedefs.TypeDef]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
environment: Environment
|
||||
type_environment: TypeEnvironment
|
||||
current_function: Optional[parse.Function]
|
||||
|
||||
def copy(self):
|
||||
return Context(self.environment.copy(), self.type_environment.copy(), self.current_function)
|
||||
return Context(self.environment.copy(), self.current_function)
|
||||
|
||||
|
||||
def unify(ctx: Context, first, second) -> bool:
|
||||
changed: bool
|
||||
result, changed = type_compare(ctx, first.type, second.type)
|
||||
first.type = result
|
||||
second.type = result
|
||||
return changed
|
||||
|
||||
|
||||
def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
|
||||
def type_compare(ctx: Context, first: parse.TypeUsage, second: parse.TypeUsage) -> Tuple[parse.TypeUsage, bool]:
|
||||
print(first, second)
|
||||
if isinstance(first, parse.UnknownTypeUsage):
|
||||
if not isinstance(second, parse.UnknownTypeUsage):
|
||||
@@ -41,8 +43,18 @@ def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
|
||||
second, parse.DataTypeUsage
|
||||
):
|
||||
assert second == first
|
||||
assert first.name in ctx.type_environment
|
||||
assert second.name in ctx.type_environment
|
||||
assert first.name in ctx.environment # TODO: validate that it is a type
|
||||
assert isinstance(
|
||||
ctx.environment[first.name], parse.StructTypeDeclaration
|
||||
) or isinstance(
|
||||
ctx.environment[first.name], parse.PrimitiveTypeDeclaration
|
||||
)
|
||||
assert second.name in ctx.environment
|
||||
assert isinstance(
|
||||
ctx.environment[second.name], parse.StructTypeDeclaration
|
||||
) or isinstance(
|
||||
ctx.environment[second.name], parse.PrimitiveTypeDeclaration
|
||||
)
|
||||
return first, False
|
||||
elif isinstance(first, parse.FunctionTypeUsage) and isinstance(
|
||||
second, parse.FunctionTypeUsage
|
||||
@@ -64,8 +76,22 @@ def type_compare(ctx: Context, first, second) -> (parse.TypeUsage, bool):
|
||||
assert False, f"mismatched types {first}, {second}"
|
||||
|
||||
|
||||
def assert_exists(ctx: Context, type: parse.TypeUsage):
|
||||
if isinstance(type, parse.DataTypeUsage):
|
||||
assert type.name in ctx.environment
|
||||
elif isinstance(type, parse.FunctionTypeUsage):
|
||||
assert_exists(ctx, type.return_type)
|
||||
for argument in type.arguments:
|
||||
assert_exists(ctx, argument)
|
||||
|
||||
class TypeChecker:
|
||||
def with_module(self, ctx: Context, module: parse.Module) -> bool:
|
||||
for type_declaration in module.types:
|
||||
ctx.environment[type_declaration.name] = type_declaration
|
||||
for type_declaration in module.types:
|
||||
if isinstance(type_declaration, parse.StructTypeDeclaration):
|
||||
for name, field in type_declaration.fields.items():
|
||||
assert_exists(ctx, field)
|
||||
for function in module.functions:
|
||||
ctx.environment[function.name] = function
|
||||
changed = False
|
||||
@@ -83,8 +109,13 @@ class TypeChecker:
|
||||
|
||||
changed = self.with_block(function_ctx, function.block)
|
||||
|
||||
if not (isinstance(function.block.type, parse.DataTypeUsage) and function.block.type.name == parse.NEVER_TYPE):
|
||||
type, compare_changed = type_compare(function_ctx, function.block.type, function.type.return_type)
|
||||
if not (
|
||||
isinstance(function.block.type, parse.DataTypeUsage)
|
||||
and function.block.type.name == parse.NEVER_TYPE
|
||||
):
|
||||
type, compare_changed = type_compare(
|
||||
function_ctx, function.block.type, function.type.return_type
|
||||
)
|
||||
function.block.type = type
|
||||
function.type.return_type = type
|
||||
if compare_changed is True:
|
||||
@@ -105,23 +136,15 @@ class TypeChecker:
|
||||
if isinstance(final, parse.LetStatement):
|
||||
if isinstance(block.type, parse.UnknownTypeUsage):
|
||||
changed = True
|
||||
block.type = parse.DataTypeUsage(
|
||||
name=parse.UNIT_TYPE
|
||||
)
|
||||
block.type = parse.DataTypeUsage(name=parse.UNIT_TYPE)
|
||||
else:
|
||||
assert block.type == parse.DataTypeUsage(
|
||||
name=parse.UNIT_TYPE
|
||||
)
|
||||
assert block.type == parse.DataTypeUsage(name=parse.UNIT_TYPE)
|
||||
elif isinstance(final, parse.ReturnStatement):
|
||||
if isinstance(block.type, parse.UnknownTypeUsage):
|
||||
changed = True
|
||||
block.type = parse.DataTypeUsage(
|
||||
name=parse.NEVER_TYPE
|
||||
)
|
||||
block.type = parse.DataTypeUsage(name=parse.NEVER_TYPE)
|
||||
else:
|
||||
assert block.type == parse.DataTypeUsage(
|
||||
name=parse.NEVER_TYPE
|
||||
)
|
||||
assert block.type == parse.DataTypeUsage(name=parse.NEVER_TYPE)
|
||||
elif isinstance(final, parse.Expression):
|
||||
if unify(block_ctx, final, block):
|
||||
changed = True
|
||||
@@ -156,7 +179,11 @@ class TypeChecker:
|
||||
changed = False
|
||||
if self.with_expression(ctx, assignment_statement.expression):
|
||||
changed = True
|
||||
if unify(ctx, assignment_statement, ctx.environment[assignment_statement.variable_name]):
|
||||
if unify(
|
||||
ctx,
|
||||
assignment_statement,
|
||||
ctx.environment[assignment_statement.variable_name],
|
||||
):
|
||||
changed = True
|
||||
if unify(ctx, assignment_statement, assignment_statement.expression):
|
||||
changed = True
|
||||
@@ -168,8 +195,15 @@ class TypeChecker:
|
||||
changed = self.with_expression(ctx, return_statement.source)
|
||||
|
||||
# Doesn't match on an unreachable return
|
||||
if not (isinstance(return_statement.source.type, parse.DataTypeUsage) and return_statement.source.type.name == parse.NEVER_TYPE):
|
||||
type, compare_changed = type_compare(ctx, return_statement.source.type, ctx.current_function.type.return_type)
|
||||
if not (
|
||||
isinstance(return_statement.source.type, parse.DataTypeUsage)
|
||||
and return_statement.source.type.name == parse.NEVER_TYPE
|
||||
):
|
||||
assert isinstance(ctx.current_function, parse.Function)
|
||||
assert isinstance(ctx.current_function.type, parse.FunctionTypeUsage)
|
||||
type, compare_changed = type_compare(
|
||||
ctx, return_statement.source.type, ctx.current_function.type.return_type
|
||||
)
|
||||
return_statement.source.type = type
|
||||
ctx.current_function.type.return_type = type
|
||||
if compare_changed is True:
|
||||
@@ -245,6 +279,7 @@ class TypeChecker:
|
||||
if self.with_expression(ctx, argument):
|
||||
changed = True
|
||||
|
||||
assert isinstance(function_call.source.type, parse.FunctionTypeUsage)
|
||||
return_type, return_changed = type_compare(
|
||||
ctx, function_call.type, function_call.source.type.return_type
|
||||
)
|
||||
@@ -265,14 +300,18 @@ class TypeChecker:
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def with_literal_float(self, ctx: Context, literal_float: parse.LiteralFloat) -> bool:
|
||||
def with_literal_float(
|
||||
self, ctx: Context, literal_float: parse.LiteralFloat
|
||||
) -> bool:
|
||||
floats = ["F32", "F64", "F128"]
|
||||
if not isinstance(literal_float.type, parse.UnknownTypeUsage):
|
||||
assert isinstance(literal_float.type, parse.DataTypeUsage)
|
||||
assert literal_float.type.name in floats, f"{literal_float.type}"
|
||||
return False
|
||||
|
||||
def with_literal_int(self, ctx: Context, literal_int: parse.LiteralInt) -> bool:
|
||||
ints = ["I8", "I16", "I32", "I64", "I128", "U8", "U16", "U32", "U64", "U128"]
|
||||
if not isinstance(literal_int.type, parse.UnknownTypeUsage):
|
||||
assert isinstance(literal_int.type, parse.DataTypeUsage)
|
||||
assert literal_int.type.name in ints, f"{literal_int.type}"
|
||||
return False
|
||||
|
||||
@@ -4,22 +4,22 @@ from typing import List, Dict, Optional, Union
|
||||
|
||||
|
||||
class IntBitness(enum.Enum):
|
||||
X8 = 'X8'
|
||||
X16 = 'X16'
|
||||
X32 = 'X32'
|
||||
X64 = 'X64'
|
||||
X128 = 'X128'
|
||||
X8 = "X8"
|
||||
X16 = "X16"
|
||||
X32 = "X32"
|
||||
X64 = "X64"
|
||||
X128 = "X128"
|
||||
|
||||
|
||||
class Signedness(enum.Enum):
|
||||
Signed = 'Signed'
|
||||
Unsigned = 'Unsigned'
|
||||
Signed = "Signed"
|
||||
Unsigned = "Unsigned"
|
||||
|
||||
|
||||
class FloatBitness(enum.Enum):
|
||||
X32 = 'X32'
|
||||
X64 = 'X64'
|
||||
X128 = 'X128'
|
||||
X32 = "X32"
|
||||
X64 = "X64"
|
||||
X128 = "X128"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,6 +39,11 @@ class FunctionTypeDef:
|
||||
return_type: "TypeDef"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructTypeDef:
|
||||
fields: Dict[str, "TypeDef"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnitTypeDef:
|
||||
pass
|
||||
@@ -49,26 +54,25 @@ class NeverTypeDef:
|
||||
pass
|
||||
|
||||
|
||||
TypeDef = Union[IntTypeDef, FloatTypeDef, FunctionTypeDef, UnitTypeDef, NeverTypeDef]
|
||||
TypeDef = Union[
|
||||
IntTypeDef, FloatTypeDef, FunctionTypeDef, StructTypeDef, UnitTypeDef, NeverTypeDef
|
||||
]
|
||||
|
||||
|
||||
builtins: Dict[str, TypeDef] = {
|
||||
'U8': IntTypeDef(Signedness.Unsigned, IntBitness.X8),
|
||||
'U16': IntTypeDef(Signedness.Unsigned, IntBitness.X16),
|
||||
'U32': IntTypeDef(Signedness.Unsigned, IntBitness.X32),
|
||||
'U64': IntTypeDef(Signedness.Unsigned, IntBitness.X64),
|
||||
'U128': IntTypeDef(Signedness.Unsigned, IntBitness.X128),
|
||||
|
||||
'I8': IntTypeDef(Signedness.Signed, IntBitness.X8),
|
||||
'I16': IntTypeDef(Signedness.Signed, IntBitness.X16),
|
||||
'I32': IntTypeDef(Signedness.Signed, IntBitness.X32),
|
||||
'I64': IntTypeDef(Signedness.Signed, IntBitness.X64),
|
||||
'I128': IntTypeDef(Signedness.Signed, IntBitness.X128),
|
||||
|
||||
'F32': FloatTypeDef(FloatBitness.X32),
|
||||
'F64': FloatTypeDef(FloatBitness.X64),
|
||||
'F128': FloatTypeDef(FloatBitness.X128),
|
||||
|
||||
'()': UnitTypeDef(),
|
||||
'!': NeverTypeDef(),
|
||||
"U8": IntTypeDef(Signedness.Unsigned, IntBitness.X8),
|
||||
"U16": IntTypeDef(Signedness.Unsigned, IntBitness.X16),
|
||||
"U32": IntTypeDef(Signedness.Unsigned, IntBitness.X32),
|
||||
"U64": IntTypeDef(Signedness.Unsigned, IntBitness.X64),
|
||||
"U128": IntTypeDef(Signedness.Unsigned, IntBitness.X128),
|
||||
"I8": IntTypeDef(Signedness.Signed, IntBitness.X8),
|
||||
"I16": IntTypeDef(Signedness.Signed, IntBitness.X16),
|
||||
"I32": IntTypeDef(Signedness.Signed, IntBitness.X32),
|
||||
"I64": IntTypeDef(Signedness.Signed, IntBitness.X64),
|
||||
"I128": IntTypeDef(Signedness.Signed, IntBitness.X128),
|
||||
"F32": FloatTypeDef(FloatBitness.X32),
|
||||
"F64": FloatTypeDef(FloatBitness.X64),
|
||||
"F128": FloatTypeDef(FloatBitness.X128),
|
||||
"()": UnitTypeDef(),
|
||||
"!": NeverTypeDef(),
|
||||
}
|
||||
|
||||
@@ -31,3 +31,8 @@ fn unit_function() {
|
||||
fn main(): I32 {
|
||||
add(4, subtract(5, 2))
|
||||
}
|
||||
|
||||
|
||||
type User struct {
|
||||
id: U64,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user