added struct getters

This commit is contained in:
Andrew Segavac
2021-06-12 12:26:53 -06:00
parent 421a5160fd
commit 5464179883
4 changed files with 59 additions and 12 deletions

View File

@@ -31,9 +31,7 @@ NEVER_TYPE = "!"
@dataclass
class FunctionTypeUsage:
arguments: List[
"TypeUsage"
]
arguments: List["TypeUsage"]
return_type: "TypeUsage"
@@ -68,6 +66,7 @@ class LiteralFloat:
value: float
type: TypeUsage
@dataclass
class LiteralStruct:
name: str
@@ -82,6 +81,13 @@ class FunctionCall:
type: TypeUsage
@dataclass
class StructGetter:
source: "Expression"
attribute: str
type: TypeUsage
@dataclass
class Operation:
left: "Expression"
@@ -109,6 +115,7 @@ class Expression:
LiteralFloat,
LiteralStruct,
FunctionCall,
StructGetter,
"Block",
ReturnStatement,
VariableUsage,
@@ -190,6 +197,8 @@ boring_grammar = r"""
function_call : expression "(" [expression ("," expression)*] ")"
struct_getter : expression "." identifier
add_expression : expression plus factor
sub_expression : expression minus factor
mult_expression : expression mult term
@@ -212,6 +221,7 @@ boring_grammar = r"""
| literal_struct
| variable_usage
| function_call
| struct_getter
| "(" expression ")"
| block
@@ -317,6 +327,10 @@ class TreeToBoring(Transformer):
def function_call(self, call) -> FunctionCall:
return FunctionCall(source=call[0], arguments=call[1:], type=UnknownTypeUsage())
def struct_getter(self, getter) -> StructGetter:
expression, attribute = getter
return StructGetter(expression, attribute, UnknownTypeUsage())
def add_expression(self, ae) -> Operation:
return Operation(left=ae[0], op=ae[1], right=ae[2], type=UnknownTypeUsage())

View File

@@ -28,7 +28,9 @@ def unify(ctx: Context, first, second) -> bool:
return changed
def type_compare(ctx: Context, first: parse.TypeUsage, second: parse.TypeUsage) -> Tuple[parse.TypeUsage, bool]:
def type_compare(
ctx: Context, first: parse.TypeUsage, second: parse.TypeUsage
) -> Tuple[parse.TypeUsage, bool]:
print(first, second)
if isinstance(first, parse.UnknownTypeUsage):
if not isinstance(second, parse.UnknownTypeUsage):
@@ -84,6 +86,7 @@ def assert_exists(ctx: Context, type: parse.TypeUsage):
for argument in type.arguments:
assert_exists(ctx, argument)
class TypeChecker:
def with_module(self, ctx: Context, module: parse.Module) -> bool:
for type_declaration in module.types:
@@ -234,6 +237,11 @@ class TypeChecker:
if unify(ctx, subexpression, expression):
changed = True
return changed
if isinstance(subexpression, parse.StructGetter):
changed = self.with_struct_getter(ctx, subexpression)
if unify(ctx, subexpression, expression):
changed = True
return changed
if isinstance(subexpression, parse.Block):
changed = self.with_block(ctx, subexpression)
if unify(ctx, subexpression, expression):
@@ -305,6 +313,22 @@ class TypeChecker:
changed = True
return changed
def with_struct_getter(
self, ctx: Context, struct_getter: parse.StructGetter
) -> bool:
changed = self.with_expression(ctx, struct_getter.source)
assert isinstance(struct_getter.source.type, parse.DataTypeUsage)
struct_declaration = ctx.environment[struct_getter.source.type.name]
assert isinstance(struct_declaration, parse.StructTypeDeclaration)
assert struct_getter.attribute in struct_declaration.fields
result_type, changed_getter = type_compare(
ctx, struct_getter.type, struct_declaration.fields[struct_getter.attribute]
)
if changed_getter:
changed = True
struct_getter.type = result_type
return changed
def with_literal_float(
self, ctx: Context, literal_float: parse.LiteralFloat
) -> bool:
@@ -321,7 +345,9 @@ class TypeChecker:
assert literal_int.type.name in ints, f"{literal_int.type}"
return False
def with_literal_struct(self, ctx: Context, literal_struct: parse.LiteralStruct) -> bool:
def with_literal_struct(
self, ctx: Context, literal_struct: parse.LiteralStruct
) -> bool:
assert literal_struct.name in ctx.environment
struct_declaration = ctx.environment[literal_struct.name]
assert isinstance(struct_declaration, parse.StructTypeDeclaration)
@@ -330,7 +356,9 @@ class TypeChecker:
assert name in literal_struct.fields
if self.with_expression(ctx, literal_struct.fields[name]):
changed = True
result_type, field_changed = type_compare(ctx, field_type, literal_struct.fields[name].type)
result_type, field_changed = type_compare(
ctx, field_type, literal_struct.fields[name].type
)
if field_changed:
literal_struct.fields[name].type = result_type
changed = True

View File

@@ -28,17 +28,22 @@ fn unit_function() {
let a: I32 = 4;
}
fn main(): I32 {
add(4, subtract(5, 2))
}
fn returns_user(): User {
return User{
id: 4,
};
}
fn main(): I32 {
add(4, subtract(5, 2))
fn get_user_id(): U64 {
let user = returns_user();
return user.id;
}
type User struct {
id: U64,
}

View File

@@ -40,12 +40,12 @@ TODO:
* Structs
* ~Define~
* ~Literal~
* Getter
* ~Getter~
* Setter
* Generics
* Enums
* Methods
* Traits
* Generics
* Enums
* Arrays
* Strings
* Lambdas