diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..25acc7b --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +max_width=140 # Not ideal diff --git a/Dockerfile b/Dockerfile index 76040e5..3af2e60 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,5 @@ FROM rust:1.54 RUN apt update && apt-get install -y llvm clang +RUN rustup component add rustfmt WORKDIR /code diff --git a/README.md b/README.md index 2f86ef2..4f20fec 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ This language is under active development, progress will be marked here as the l - [x] Declaration - [x] Use - [ ] Traits + - [x] Basic + - [ ] Default Functions - [ ] Generics - [ ] Basic - [ ] Higher kinded types diff --git a/examples/main.bl b/examples/main.bl index 90819c4..ce03029 100644 --- a/examples/main.bl +++ b/examples/main.bl @@ -81,19 +81,19 @@ fn main(): i64 { } } -// type TestTrait trait { -// fn class_method(id: i64): Self; -// fn instance_method(self: Self): i64; -// fn default_impl(self: Self): i64 { -// return self.instance_method(); -// } -// } -// -// impl TestTrait for User { -// fn class_method(id: i64): Self { -// return User{id: id,}; -// } -// fn instance_method(self: Self): i64 { -// return self.get_id(); -// } -// } +type TestTrait trait { + fn class_method(id: i64): Self; + fn instance_method(self: Self): i64; + fn default_impl(self: Self): i64 { + return self.instance_method(); + } +} + +impl TestTrait for User { + fn class_method(id: i64): Self { + return User{id: id,}; + } + fn instance_method(self: Self): i64 { + return self.get_id(); + } +} diff --git a/src/ast.rs b/src/ast.rs index f1e5611..ef25d75 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -6,9 +6,7 @@ pub struct IdGenerator { impl IdGenerator { pub fn new() -> Self { - IdGenerator { - counter: RefCell::new(0), - } + IdGenerator { counter: RefCell::new(0) } } pub fn next(&self) -> String { @@ -76,15 +74,11 @@ pub enum TypeUsage { impl TypeUsage { pub fn new_unknown(id_gen: &IdGenerator) -> TypeUsage { - return TypeUsage::Unknown(UnknownTypeUsage { - name: id_gen.next(), - }); + return TypeUsage::Unknown(UnknownTypeUsage { name: id_gen.next() }); } pub fn new_named(identifier: Identifier) -> TypeUsage { - return TypeUsage::Named(NamedTypeUsage { - name: identifier.clone(), - }); + return TypeUsage::Named(NamedTypeUsage { name: identifier.clone() }); } pub fn new_builtin(name: String) -> TypeUsage { @@ -100,9 +94,7 @@ impl TypeUsage { pub fn new_function(arg_count: usize, id_gen: &IdGenerator) -> TypeUsage { return TypeUsage::Function(FunctionTypeUsage { - arguments: (0..arg_count) - .map(|_| TypeUsage::new_unknown(&id_gen)) - .collect(), + arguments: (0..arg_count).map(|_| TypeUsage::new_unknown(&id_gen)).collect(), return_type: Box::new(TypeUsage::new_unknown(&id_gen)), }); } @@ -252,6 +244,25 @@ pub struct FunctionDeclaration { pub return_type: TypeUsage, } +impl FunctionDeclaration { + pub fn to_type(&self) -> TypeUsage { + TypeUsage::Function(FunctionTypeUsage { + arguments: self.arguments.iter().map(|arg| arg.type_.clone()).collect(), + return_type: Box::new(self.return_type.clone()), + }) + } + + pub fn to_method_type(&self) -> TypeUsage { + TypeUsage::Function(FunctionTypeUsage { + arguments: self.arguments[1..self.arguments.len()] + .iter() + .map(|arg| arg.type_.clone()) + .collect(), + return_type: Box::new(self.return_type.clone()), + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Function { pub declaration: FunctionDeclaration, @@ -275,6 +286,18 @@ pub struct StructTypeDeclaration { pub fields: Vec, } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum TraitItem { + FunctionDeclaration(FunctionDeclaration), + Function(Function), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TraitTypeDeclaration { + pub name: Identifier, + pub functions: Vec, +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct AliasTypeDeclaration { pub name: Identifier, @@ -286,10 +309,12 @@ pub enum TypeDeclaration { Struct(StructTypeDeclaration), Primitive(PrimitiveTypeDeclaration), Alias(AliasTypeDeclaration), + Trait(TraitTypeDeclaration), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Impl { + pub trait_: Option, pub struct_name: Identifier, pub functions: Vec, } diff --git a/src/compiler.rs b/src/compiler.rs index 56ab103..48fc3e0 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -28,28 +28,18 @@ impl<'ctx> ModuleCodeGen<'ctx> { } pub fn gen_literal_int(&mut self, literal_int: &ast::LiteralInt) -> IntValue<'ctx> { - self.context.i64_type().const_int( - unsafe { mem::transmute::(literal_int.value) }, - true, - ) + self.context + .i64_type() + .const_int(unsafe { mem::transmute::(literal_int.value) }, true) } - pub fn gen_op_expression( - &mut self, - scope: &Scope<'ctx>, - operation: &ast::Operation, - ) -> IntValue<'ctx> { + pub fn gen_op_expression(&mut self, scope: &Scope<'ctx>, operation: &ast::Operation) -> IntValue<'ctx> { let lhs_result = self.gen_expression(scope, &operation.left); let rhs_result = self.gen_expression(scope, &operation.right); self.gen_op_int(&lhs_result, &rhs_result, &operation.op) } - pub fn gen_op_int( - &mut self, - lhs: &IntValue<'ctx>, - rhs: &IntValue<'ctx>, - op: &ast::Operator, - ) -> IntValue<'ctx> { + pub fn gen_op_int(&mut self, lhs: &IntValue<'ctx>, rhs: &IntValue<'ctx>, op: &ast::Operator) -> IntValue<'ctx> { match *op { ast::Operator::Plus => self.builder.build_int_add(*lhs, *rhs, "add"), ast::Operator::Minus => self.builder.build_int_sub(*lhs, *rhs, "sub"), @@ -58,35 +48,23 @@ impl<'ctx> ModuleCodeGen<'ctx> { } } - pub fn gen_expression( - &mut self, - scope: &Scope<'ctx>, - expression: &Box, - ) -> IntValue<'ctx> { + pub fn gen_expression(&mut self, scope: &Scope<'ctx>, expression: &Box) -> IntValue<'ctx> { match &**expression { ast::Expression::LiteralInt(literal_int) => self.gen_literal_int(&literal_int), ast::Expression::Identifier(identifier) => match scope[&identifier.name] { BasicValueEnum::IntValue(value) => value, _ => panic!("function returned type other than int, no types yet"), }, - ast::Expression::FunctionCall(function_call) => { - self.gen_function_call(scope, &function_call) - } + ast::Expression::FunctionCall(function_call) => self.gen_function_call(scope, &function_call), ast::Expression::Op(operation) => self.gen_op_expression(scope, &operation), } } - pub fn gen_function_call( - &mut self, - scope: &Scope<'ctx>, - function_call: &ast::FunctionCall, - ) -> IntValue<'ctx> { + pub fn gen_function_call(&mut self, scope: &Scope<'ctx>, function_call: &ast::FunctionCall) -> IntValue<'ctx> { let fn_value = self.module.get_function(&function_call.name.name).unwrap(); let mut arguments = Vec::new(); for expression in (&function_call.arguments).into_iter() { - arguments.push(BasicValueEnum::IntValue( - self.gen_expression(scope, &expression), - )); + arguments.push(BasicValueEnum::IntValue(self.gen_expression(scope, &expression))); } let result = self @@ -123,10 +101,7 @@ impl<'ctx> ModuleCodeGen<'ctx> { let mut scope = self.scope.clone(); for (i, param) in (&function.arguments).into_iter().enumerate() { - scope.insert( - param.name.name.to_string(), - fn_value.get_nth_param(i.try_into().unwrap()).unwrap(), - ); + scope.insert(param.name.name.to_string(), fn_value.get_nth_param(i.try_into().unwrap()).unwrap()); } let body = &function.block; let return_value = self.gen_expression(&scope, &body.expression); diff --git a/src/errors.rs b/src/errors.rs index de15350..3b8c90f 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -18,6 +18,10 @@ pub enum TypingError { }, #[error("unknown field name")] UnknownFieldName { identifier: ast::Identifier }, + #[error("cannot assign to method")] + CannotAssignToMethod { identifier: ast::Identifier }, + #[error("multiple field name matches")] + MultipleFieldName { identifier: ast::Identifier }, #[error("attribute gotten of non-struct")] AttributeOfNonstruct { identifier: ast::Identifier }, #[error("name is not a struct, cannot instaniate")] @@ -27,6 +31,15 @@ pub enum TypingError { struct_name: ast::Identifier, struct_definition_name: ast::Identifier, }, + #[error("missing trait function")] + MissingTraitFunction { + struct_name: ast::Identifier, + function_name: ast::Identifier, + }, + #[error("function not in trait")] + FunctionNotInTrait { function_name: ast::Identifier }, + #[error("impl trait must be trait")] + ImplTraitMustBeTrait { trait_name: ast::Identifier }, #[error("function call used with non-function")] FunctionCallNotAFunction { // TODO: add position diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop index 0ae60e6..1f54e79 100644 --- a/src/grammar.lalrpop +++ b/src/grammar.lalrpop @@ -25,7 +25,9 @@ match { "if", "else", "=", + "for", "type", + "trait", "struct", "impl", ",", @@ -208,13 +210,24 @@ pub AliasTypeDeclaration: ast::AliasTypeDeclaration = { "type" "=" ";" => ast::AliasTypeDeclaration{name: i, replaces: t} }; +pub TraitItem: ast::TraitItem = { + ";" => ast::TraitItem::FunctionDeclaration(fd), + => ast::TraitItem::Function(f), +}; + +pub TraitTypeDeclaration: ast::TraitTypeDeclaration = { + "type" "trait" "{" "}" => ast::TraitTypeDeclaration{name: i, functions: ti}, +}; + pub TypeDeclaration: ast::TypeDeclaration = { => ast::TypeDeclaration::Struct(s), => ast::TypeDeclaration::Alias(a), + => ast::TypeDeclaration::Trait(t), }; pub Impl: ast::Impl = { - "impl" "{" "}" => ast::Impl{struct_name: i, functions: f} + "impl" "{" "}" => ast::Impl{trait_: None, struct_name: i, functions: f}, + "impl" "for" "{" "}" => ast::Impl{trait_: Some(t), struct_name: i, functions: f}, }; pub ModuleItem: ast::ModuleItem = { diff --git a/src/interpreter.rs b/src/interpreter.rs index da90165..8c92b2c 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -58,8 +58,7 @@ struct Context { impl Context { fn set_variable(&mut self, name: String, value: &Value) { - self.environment - .insert(name.to_string(), NamedEntity::Variable(value.clone())); + self.environment.insert(name.to_string(), NamedEntity::Variable(value.clone())); } fn new_env(&self) -> Context { @@ -94,8 +93,7 @@ impl Context { ); } ast::ModuleItem::Impl(impl_) => { - ctx.impls - .insert(impl_.struct_name.name.value.to_string(), impl_.clone()); + ctx.impls.insert(impl_.struct_name.name.value.to_string(), impl_.clone()); } _ => {} } @@ -108,127 +106,97 @@ fn create_builtins() -> HashMap { let mut result = HashMap::new(); result.insert( "i8".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i8".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i8".to_string(), + })), ); result.insert( "i16".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i16".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i16".to_string(), + })), ); result.insert( "i32".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i32".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i32".to_string(), + })), ); result.insert( "i64".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i64".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i64".to_string(), + })), ); result.insert( "isize".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "isize".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "isize".to_string(), + })), ); result.insert( "u8".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u8".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u8".to_string(), + })), ); result.insert( "u16".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u16".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u16".to_string(), + })), ); result.insert( "u32".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u32".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u32".to_string(), + })), ); result.insert( "u64".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u64".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u64".to_string(), + })), ); result.insert( "usize".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "usize".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "usize".to_string(), + })), ); result.insert( "f32".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "f32".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "f32".to_string(), + })), ); result.insert( "f64".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "f64".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "f64".to_string(), + })), ); result.insert( "bool".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "bool".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "bool".to_string(), + })), ); result.insert( "!".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "!".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "!".to_string(), + })), ); result.insert( "unit".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "!".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "!".to_string(), + })), ); return result; @@ -277,11 +245,7 @@ impl TreeWalkInterpreter { return last; } - fn with_statement( - self: &Self, - ctx: &mut Context, - statement: &ast::Statement, - ) -> ExpressionResult { + fn with_statement(self: &Self, ctx: &mut Context, statement: &ast::Statement) -> ExpressionResult { match statement { ast::Statement::Return(return_statement) => { let result = match self.with_expression(ctx, &return_statement.source) { @@ -311,11 +275,7 @@ impl TreeWalkInterpreter { } } - fn with_assignment_statement( - self: &Self, - ctx: &mut Context, - statement: &ast::AssignmentStatement, - ) -> ExpressionResult { + fn with_assignment_statement(self: &Self, ctx: &mut Context, statement: &ast::AssignmentStatement) -> ExpressionResult { let result = match self.with_expression(ctx, &statement.expression) { ExpressionResult::Value(r) => r, ExpressionResult::Return(r) => { @@ -336,9 +296,7 @@ impl TreeWalkInterpreter { match &mut source { Value::Struct(s) => { let mut struct_ = s.lock().unwrap(); - struct_ - .fields - .insert(struct_attr.attribute.name.value.clone(), result); + struct_.fields.insert(struct_attr.attribute.name.value.clone(), result); } _ => panic!("set attr on nonstruct, should never happen due to type system"), } @@ -347,11 +305,7 @@ impl TreeWalkInterpreter { return ExpressionResult::Value(Value::Unit); } - fn with_expression( - self: &Self, - ctx: &mut Context, - expression: &ast::Expression, - ) -> ExpressionResult { + fn with_expression(self: &Self, ctx: &mut Context, expression: &ast::Expression) -> ExpressionResult { match &*expression.subexpression { ast::Subexpression::LiteralInt(literal_int) => { let value: i64 = literal_int.value.value.parse().unwrap(); @@ -362,18 +316,12 @@ impl TreeWalkInterpreter { return ExpressionResult::Value(Value::Numeric(NumericValue::F64(value))); } ast::Subexpression::LiteralBool(literal_bool) => { - let value: bool = if &literal_bool.value.value == "true" { - true - } else { - false - }; + let value: bool = if &literal_bool.value.value == "true" { true } else { false }; return ExpressionResult::Value(Value::Bool(value)); } ast::Subexpression::LiteralStruct(literal_struct) => { let declaration = match &ctx.environment[&literal_struct.name.name.value] { - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct(declaration)) => { - declaration.clone() - } + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct(declaration)) => declaration.clone(), _ => panic!("not a struct"), }; @@ -416,17 +364,10 @@ impl TreeWalkInterpreter { match &source { Value::Function(Function::User(user_function)) => { let mut fn_ctx = ctx.new_env(); - for (i, function_arg) in - user_function.declaration.arguments.iter().enumerate() - { - fn_ctx.set_variable( - function_arg.name.name.value.to_string(), - &argument_values[i].clone(), - ); + for (i, function_arg) in user_function.declaration.arguments.iter().enumerate() { + fn_ctx.set_variable(function_arg.name.name.value.to_string(), &argument_values[i].clone()); } - return ExpressionResult::Value( - self.with_function(&mut fn_ctx, user_function), - ); + return ExpressionResult::Value(self.with_function(&mut fn_ctx, user_function)); } Value::Function(Function::Builtin(builtin_function)) => { return ExpressionResult::Value(builtin_function(argument_values)); @@ -473,9 +414,7 @@ impl TreeWalkInterpreter { match &source { Value::Struct(struct_) => { let s = struct_.lock().unwrap(); - return ExpressionResult::Value( - s.fields[&struct_getter.attribute.name.value].clone(), - ); + return ExpressionResult::Value(s.fields[&struct_getter.attribute.name.value].clone()); } _ => { panic!("TypeError: struct getter used with non-struct"); @@ -500,217 +439,137 @@ impl TreeWalkInterpreter { }; let result = match (&left, &op.op, &right) { //I - ( - Value::Numeric(NumericValue::I8(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::I8(r)), - ) => Value::Numeric(NumericValue::I8(l + r)), - ( - Value::Numeric(NumericValue::I8(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::I8(r)), - ) => Value::Numeric(NumericValue::I8(l - r)), - ( - Value::Numeric(NumericValue::I8(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::I8(r)), - ) => Value::Numeric(NumericValue::I8(l * r)), - ( - Value::Numeric(NumericValue::I8(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::I8(r)), - ) => Value::Numeric(NumericValue::I8(l / r)), + (Value::Numeric(NumericValue::I8(l)), ast::Operator::Plus, Value::Numeric(NumericValue::I8(r))) => { + Value::Numeric(NumericValue::I8(l + r)) + } + (Value::Numeric(NumericValue::I8(l)), ast::Operator::Minus, Value::Numeric(NumericValue::I8(r))) => { + Value::Numeric(NumericValue::I8(l - r)) + } + (Value::Numeric(NumericValue::I8(l)), ast::Operator::Mul, Value::Numeric(NumericValue::I8(r))) => { + Value::Numeric(NumericValue::I8(l * r)) + } + (Value::Numeric(NumericValue::I8(l)), ast::Operator::Div, Value::Numeric(NumericValue::I8(r))) => { + Value::Numeric(NumericValue::I8(l / r)) + } - ( - Value::Numeric(NumericValue::I16(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::I16(r)), - ) => Value::Numeric(NumericValue::I16(l + r)), - ( - Value::Numeric(NumericValue::I16(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::I16(r)), - ) => Value::Numeric(NumericValue::I16(l - r)), - ( - Value::Numeric(NumericValue::I16(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::I16(r)), - ) => Value::Numeric(NumericValue::I16(l * r)), - ( - Value::Numeric(NumericValue::I16(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::I16(r)), - ) => Value::Numeric(NumericValue::I16(l / r)), + (Value::Numeric(NumericValue::I16(l)), ast::Operator::Plus, Value::Numeric(NumericValue::I16(r))) => { + Value::Numeric(NumericValue::I16(l + r)) + } + (Value::Numeric(NumericValue::I16(l)), ast::Operator::Minus, Value::Numeric(NumericValue::I16(r))) => { + Value::Numeric(NumericValue::I16(l - r)) + } + (Value::Numeric(NumericValue::I16(l)), ast::Operator::Mul, Value::Numeric(NumericValue::I16(r))) => { + Value::Numeric(NumericValue::I16(l * r)) + } + (Value::Numeric(NumericValue::I16(l)), ast::Operator::Div, Value::Numeric(NumericValue::I16(r))) => { + Value::Numeric(NumericValue::I16(l / r)) + } - ( - Value::Numeric(NumericValue::I32(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::I32(r)), - ) => Value::Numeric(NumericValue::I32(l + r)), - ( - Value::Numeric(NumericValue::I32(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::I32(r)), - ) => Value::Numeric(NumericValue::I32(l - r)), - ( - Value::Numeric(NumericValue::I32(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::I32(r)), - ) => Value::Numeric(NumericValue::I32(l * r)), - ( - Value::Numeric(NumericValue::I32(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::I32(r)), - ) => Value::Numeric(NumericValue::I32(l / r)), + (Value::Numeric(NumericValue::I32(l)), ast::Operator::Plus, Value::Numeric(NumericValue::I32(r))) => { + Value::Numeric(NumericValue::I32(l + r)) + } + (Value::Numeric(NumericValue::I32(l)), ast::Operator::Minus, Value::Numeric(NumericValue::I32(r))) => { + Value::Numeric(NumericValue::I32(l - r)) + } + (Value::Numeric(NumericValue::I32(l)), ast::Operator::Mul, Value::Numeric(NumericValue::I32(r))) => { + Value::Numeric(NumericValue::I32(l * r)) + } + (Value::Numeric(NumericValue::I32(l)), ast::Operator::Div, Value::Numeric(NumericValue::I32(r))) => { + Value::Numeric(NumericValue::I32(l / r)) + } - ( - Value::Numeric(NumericValue::I64(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::I64(r)), - ) => Value::Numeric(NumericValue::I64(l + r)), - ( - Value::Numeric(NumericValue::I64(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::I64(r)), - ) => Value::Numeric(NumericValue::I64(l - r)), - ( - Value::Numeric(NumericValue::I64(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::I64(r)), - ) => Value::Numeric(NumericValue::I64(l * r)), - ( - Value::Numeric(NumericValue::I64(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::I64(r)), - ) => Value::Numeric(NumericValue::I64(l / r)), + (Value::Numeric(NumericValue::I64(l)), ast::Operator::Plus, Value::Numeric(NumericValue::I64(r))) => { + Value::Numeric(NumericValue::I64(l + r)) + } + (Value::Numeric(NumericValue::I64(l)), ast::Operator::Minus, Value::Numeric(NumericValue::I64(r))) => { + Value::Numeric(NumericValue::I64(l - r)) + } + (Value::Numeric(NumericValue::I64(l)), ast::Operator::Mul, Value::Numeric(NumericValue::I64(r))) => { + Value::Numeric(NumericValue::I64(l * r)) + } + (Value::Numeric(NumericValue::I64(l)), ast::Operator::Div, Value::Numeric(NumericValue::I64(r))) => { + Value::Numeric(NumericValue::I64(l / r)) + } //U - ( - Value::Numeric(NumericValue::U8(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::U8(r)), - ) => Value::Numeric(NumericValue::U8(l + r)), - ( - Value::Numeric(NumericValue::U8(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::U8(r)), - ) => Value::Numeric(NumericValue::U8(l - r)), - ( - Value::Numeric(NumericValue::U8(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::U8(r)), - ) => Value::Numeric(NumericValue::U8(l * r)), - ( - Value::Numeric(NumericValue::U8(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::U8(r)), - ) => Value::Numeric(NumericValue::U8(l / r)), + (Value::Numeric(NumericValue::U8(l)), ast::Operator::Plus, Value::Numeric(NumericValue::U8(r))) => { + Value::Numeric(NumericValue::U8(l + r)) + } + (Value::Numeric(NumericValue::U8(l)), ast::Operator::Minus, Value::Numeric(NumericValue::U8(r))) => { + Value::Numeric(NumericValue::U8(l - r)) + } + (Value::Numeric(NumericValue::U8(l)), ast::Operator::Mul, Value::Numeric(NumericValue::U8(r))) => { + Value::Numeric(NumericValue::U8(l * r)) + } + (Value::Numeric(NumericValue::U8(l)), ast::Operator::Div, Value::Numeric(NumericValue::U8(r))) => { + Value::Numeric(NumericValue::U8(l / r)) + } - ( - Value::Numeric(NumericValue::U16(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::U16(r)), - ) => Value::Numeric(NumericValue::U16(l + r)), - ( - Value::Numeric(NumericValue::U16(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::U16(r)), - ) => Value::Numeric(NumericValue::U16(l - r)), - ( - Value::Numeric(NumericValue::U16(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::U16(r)), - ) => Value::Numeric(NumericValue::U16(l * r)), - ( - Value::Numeric(NumericValue::U16(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::U16(r)), - ) => Value::Numeric(NumericValue::U16(l / r)), + (Value::Numeric(NumericValue::U16(l)), ast::Operator::Plus, Value::Numeric(NumericValue::U16(r))) => { + Value::Numeric(NumericValue::U16(l + r)) + } + (Value::Numeric(NumericValue::U16(l)), ast::Operator::Minus, Value::Numeric(NumericValue::U16(r))) => { + Value::Numeric(NumericValue::U16(l - r)) + } + (Value::Numeric(NumericValue::U16(l)), ast::Operator::Mul, Value::Numeric(NumericValue::U16(r))) => { + Value::Numeric(NumericValue::U16(l * r)) + } + (Value::Numeric(NumericValue::U16(l)), ast::Operator::Div, Value::Numeric(NumericValue::U16(r))) => { + Value::Numeric(NumericValue::U16(l / r)) + } - ( - Value::Numeric(NumericValue::U32(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::U32(r)), - ) => Value::Numeric(NumericValue::U32(l + r)), - ( - Value::Numeric(NumericValue::U32(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::U32(r)), - ) => Value::Numeric(NumericValue::U32(l - r)), - ( - Value::Numeric(NumericValue::U32(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::U32(r)), - ) => Value::Numeric(NumericValue::U32(l * r)), - ( - Value::Numeric(NumericValue::U32(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::U32(r)), - ) => Value::Numeric(NumericValue::U32(l / r)), + (Value::Numeric(NumericValue::U32(l)), ast::Operator::Plus, Value::Numeric(NumericValue::U32(r))) => { + Value::Numeric(NumericValue::U32(l + r)) + } + (Value::Numeric(NumericValue::U32(l)), ast::Operator::Minus, Value::Numeric(NumericValue::U32(r))) => { + Value::Numeric(NumericValue::U32(l - r)) + } + (Value::Numeric(NumericValue::U32(l)), ast::Operator::Mul, Value::Numeric(NumericValue::U32(r))) => { + Value::Numeric(NumericValue::U32(l * r)) + } + (Value::Numeric(NumericValue::U32(l)), ast::Operator::Div, Value::Numeric(NumericValue::U32(r))) => { + Value::Numeric(NumericValue::U32(l / r)) + } - ( - Value::Numeric(NumericValue::U64(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::U64(r)), - ) => Value::Numeric(NumericValue::U64(l + r)), - ( - Value::Numeric(NumericValue::U64(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::U64(r)), - ) => Value::Numeric(NumericValue::U64(l - r)), - ( - Value::Numeric(NumericValue::U64(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::U64(r)), - ) => Value::Numeric(NumericValue::U64(l * r)), - ( - Value::Numeric(NumericValue::U64(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::U64(r)), - ) => Value::Numeric(NumericValue::U64(l / r)), + (Value::Numeric(NumericValue::U64(l)), ast::Operator::Plus, Value::Numeric(NumericValue::U64(r))) => { + Value::Numeric(NumericValue::U64(l + r)) + } + (Value::Numeric(NumericValue::U64(l)), ast::Operator::Minus, Value::Numeric(NumericValue::U64(r))) => { + Value::Numeric(NumericValue::U64(l - r)) + } + (Value::Numeric(NumericValue::U64(l)), ast::Operator::Mul, Value::Numeric(NumericValue::U64(r))) => { + Value::Numeric(NumericValue::U64(l * r)) + } + (Value::Numeric(NumericValue::U64(l)), ast::Operator::Div, Value::Numeric(NumericValue::U64(r))) => { + Value::Numeric(NumericValue::U64(l / r)) + } //F - ( - Value::Numeric(NumericValue::F32(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::F32(r)), - ) => Value::Numeric(NumericValue::F32(l + r)), - ( - Value::Numeric(NumericValue::F32(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::F32(r)), - ) => Value::Numeric(NumericValue::F32(l - r)), - ( - Value::Numeric(NumericValue::F32(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::F32(r)), - ) => Value::Numeric(NumericValue::F32(l * r)), - ( - Value::Numeric(NumericValue::F32(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::F32(r)), - ) => Value::Numeric(NumericValue::F32(l / r)), + (Value::Numeric(NumericValue::F32(l)), ast::Operator::Plus, Value::Numeric(NumericValue::F32(r))) => { + Value::Numeric(NumericValue::F32(l + r)) + } + (Value::Numeric(NumericValue::F32(l)), ast::Operator::Minus, Value::Numeric(NumericValue::F32(r))) => { + Value::Numeric(NumericValue::F32(l - r)) + } + (Value::Numeric(NumericValue::F32(l)), ast::Operator::Mul, Value::Numeric(NumericValue::F32(r))) => { + Value::Numeric(NumericValue::F32(l * r)) + } + (Value::Numeric(NumericValue::F32(l)), ast::Operator::Div, Value::Numeric(NumericValue::F32(r))) => { + Value::Numeric(NumericValue::F32(l / r)) + } - ( - Value::Numeric(NumericValue::F64(l)), - ast::Operator::Plus, - Value::Numeric(NumericValue::F64(r)), - ) => Value::Numeric(NumericValue::F64(l + r)), - ( - Value::Numeric(NumericValue::F64(l)), - ast::Operator::Minus, - Value::Numeric(NumericValue::F64(r)), - ) => Value::Numeric(NumericValue::F64(l - r)), - ( - Value::Numeric(NumericValue::F64(l)), - ast::Operator::Mul, - Value::Numeric(NumericValue::F64(r)), - ) => Value::Numeric(NumericValue::F64(l * r)), - ( - Value::Numeric(NumericValue::F64(l)), - ast::Operator::Div, - Value::Numeric(NumericValue::F64(r)), - ) => Value::Numeric(NumericValue::F64(l / r)), + (Value::Numeric(NumericValue::F64(l)), ast::Operator::Plus, Value::Numeric(NumericValue::F64(r))) => { + Value::Numeric(NumericValue::F64(l + r)) + } + (Value::Numeric(NumericValue::F64(l)), ast::Operator::Minus, Value::Numeric(NumericValue::F64(r))) => { + Value::Numeric(NumericValue::F64(l - r)) + } + (Value::Numeric(NumericValue::F64(l)), ast::Operator::Mul, Value::Numeric(NumericValue::F64(r))) => { + Value::Numeric(NumericValue::F64(l * r)) + } + (Value::Numeric(NumericValue::F64(l)), ast::Operator::Div, Value::Numeric(NumericValue::F64(r))) => { + Value::Numeric(NumericValue::F64(l / r)) + } //fail _ => panic!(""), diff --git a/src/main.rs b/src/main.rs index 7071e73..368e1bb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,35 +27,18 @@ fn main() { .help("Sets an output file") .takes_value(true), ) - .arg( - Arg::with_name("INPUT") - .help("Sets the input file") - .required(true) - .index(1), - ) - .arg( - Arg::with_name("v") - .short("v") - .multiple(true) - .help("Sets the level of verbosity"), - ) + .arg(Arg::with_name("INPUT").help("Sets the input file").required(true).index(1)) + .arg(Arg::with_name("v").short("v").multiple(true).help("Sets the level of verbosity")) .get_matches(); let input = matches.value_of("INPUT").unwrap(); - let default_output = input - .rsplitn(2, ".") - .collect::>() - .last() - .unwrap() - .clone(); + let default_output = input.rsplitn(2, ".").collect::>().last().unwrap().clone(); let _output = matches.value_of("OUTPUT").unwrap_or(default_output); let contents = fs::read_to_string(input).expect("input file not found"); let unknown_id_gen = ast::IdGenerator::new(); - let module_ast = grammar::ModuleParser::new() - .parse(&unknown_id_gen, &contents) - .unwrap(); //TODO: convert to error - // println!("ast: {:#?}", &module_ast); + let module_ast = grammar::ModuleParser::new().parse(&unknown_id_gen, &contents).unwrap(); //TODO: convert to error + // println!("ast: {:#?}", &module_ast); let alias_resolver = type_alias_resolution::TypeAliasResolver {}; let resolved_ast = alias_resolver.with_module(&module_ast); // println!("resolved ast: {:#?}", &resolved_ast); @@ -85,47 +68,25 @@ fn main() { #[test] fn grammar() { let id_gen = ast::IdGenerator::new(); - assert!(grammar::LiteralIntParser::new() - .parse(&id_gen, "22") - .is_ok()); - assert!(grammar::IdentifierParser::new() - .parse(&id_gen, "foo") - .is_ok()); - assert!(grammar::LiteralIntParser::new() - .parse(&id_gen, "2a") - .is_err()); + assert!(grammar::LiteralIntParser::new().parse(&id_gen, "22").is_ok()); + assert!(grammar::IdentifierParser::new().parse(&id_gen, "foo").is_ok()); + assert!(grammar::LiteralIntParser::new().parse(&id_gen, "2a").is_err()); assert!(grammar::TermParser::new().parse(&id_gen, "22").is_ok()); assert!(grammar::TermParser::new().parse(&id_gen, "foo").is_ok()); - assert!(grammar::ExpressionParser::new() - .parse(&id_gen, "22 * foo") - .is_ok()); - assert!(grammar::ExpressionParser::new() - .parse(&id_gen, "22 * 33") - .is_ok()); - assert!(grammar::ExpressionParser::new() - .parse(&id_gen, "(22 * 33) + 24") - .is_ok()); + assert!(grammar::ExpressionParser::new().parse(&id_gen, "22 * foo").is_ok()); + assert!(grammar::ExpressionParser::new().parse(&id_gen, "22 * 33").is_ok()); + assert!(grammar::ExpressionParser::new().parse(&id_gen, "(22 * 33) + 24").is_ok()); - assert!(grammar::BlockParser::new() - .parse(&id_gen, "{ (22 * 33) + 24 }") - .is_ok()); - assert!(grammar::BlockParser::new() - .parse(&id_gen, "{ (22 * 33) + 24; 25 }") - .is_ok()); + assert!(grammar::BlockParser::new().parse(&id_gen, "{ (22 * 33) + 24 }").is_ok()); + assert!(grammar::BlockParser::new().parse(&id_gen, "{ (22 * 33) + 24; 25 }").is_ok()); // assert!(grammar::BlockParser::new().parse("{ (22 * 33) + 24\n 24 }").is_ok()); assert!(grammar::BlockParser::new().parse(&id_gen, "{ }").is_ok()); - assert!(grammar::VariableDeclarationParser::new() - .parse(&id_gen, "foo: i32") - .is_ok()); - assert!(grammar::VariableDeclarationParser::new() - .parse(&id_gen, "foo") - .is_err()); - assert!(grammar::VariableDeclarationParser::new() - .parse(&id_gen, "1234") - .is_err()); + assert!(grammar::VariableDeclarationParser::new().parse(&id_gen, "foo: i32").is_ok()); + assert!(grammar::VariableDeclarationParser::new().parse(&id_gen, "foo").is_err()); + assert!(grammar::VariableDeclarationParser::new().parse(&id_gen, "1234").is_err()); assert!(grammar::FunctionParser::new() .parse(&id_gen, "fn add(a: i32, b: i32): i32 { a + b }") @@ -140,9 +101,7 @@ fn grammar() { .parse(&id_gen, "fn add(a: i32, b: i32): i32") .is_err()); - assert!(grammar::FunctionCallParser::new() - .parse(&id_gen, "foo(1, 2)") - .is_ok()); + assert!(grammar::FunctionCallParser::new().parse(&id_gen, "foo(1, 2)").is_ok()); assert!(grammar::ModuleParser::new() .parse(&id_gen, "fn add(a: i32, b: i32): i32 { a + b }") diff --git a/src/type_alias_resolution.rs b/src/type_alias_resolution.rs index c2cc489..6a7e298 100644 --- a/src/type_alias_resolution.rs +++ b/src/type_alias_resolution.rs @@ -33,11 +33,7 @@ fn process_type(ctx: &Context, type_: &ast::TypeUsage) -> ast::TypeUsage { } ast::TypeUsage::Function(function) => { return ast::TypeUsage::Function(ast::FunctionTypeUsage { - arguments: function - .arguments - .iter() - .map(|a| process_type(ctx, &a.clone())) - .collect(), + arguments: function.arguments.iter().map(|a| process_type(ctx, &a.clone())).collect(), return_type: Box::new(process_type(ctx, &function.return_type.clone())), }); } @@ -51,9 +47,7 @@ pub struct TypeAliasResolver {} impl TypeAliasResolver { pub fn with_module(self: &Self, module: &ast::Module) -> ast::Module { - let mut ctx = Context { - type_aliases: vec![], - }; + let mut ctx = Context { type_aliases: vec![] }; for item in module.items.iter() { match item { ast::ModuleItem::TypeDeclaration(ast::TypeDeclaration::Alias(alias)) => { @@ -68,17 +62,11 @@ impl TypeAliasResolver { .items .iter() .map(|item| match item { - ast::ModuleItem::Function(function) => { - ast::ModuleItem::Function(self.with_function(&ctx, function)) - } + ast::ModuleItem::Function(function) => ast::ModuleItem::Function(self.with_function(&ctx, function)), ast::ModuleItem::TypeDeclaration(type_declaration) => { - ast::ModuleItem::TypeDeclaration( - self.with_type_declaration(&ctx, type_declaration), - ) - } - ast::ModuleItem::Impl(impl_) => { - ast::ModuleItem::Impl(self.with_impl(&ctx, impl_)) + ast::ModuleItem::TypeDeclaration(self.with_type_declaration(&ctx, type_declaration)) } + ast::ModuleItem::Impl(impl_) => ast::ModuleItem::Impl(self.with_impl(&ctx, impl_)), }) .collect(), }; @@ -86,28 +74,27 @@ impl TypeAliasResolver { fn with_function(self: &Self, ctx: &Context, function: &ast::Function) -> ast::Function { return ast::Function { - declaration: ast::FunctionDeclaration { - name: function.declaration.name.clone(), - arguments: function - .declaration - .arguments - .iter() - .map(|arg| ast::VariableDeclaration { - name: arg.name.clone(), - type_: process_type(ctx, &arg.type_), - }) - .collect(), - return_type: process_type(ctx, &function.declaration.return_type), - }, + declaration: self.with_function_declaration(ctx, &function.declaration), block: self.with_block(ctx, &function.block), }; } - fn with_type_declaration( - self: &Self, - ctx: &Context, - type_declaration: &ast::TypeDeclaration, - ) -> ast::TypeDeclaration { + fn with_function_declaration(self: &Self, ctx: &Context, declaration: &ast::FunctionDeclaration) -> ast::FunctionDeclaration { + return ast::FunctionDeclaration { + name: declaration.name.clone(), + arguments: declaration + .arguments + .iter() + .map(|arg| ast::VariableDeclaration { + name: arg.name.clone(), + type_: process_type(ctx, &arg.type_), + }) + .collect(), + return_type: process_type(ctx, &declaration.return_type), + }; + } + + fn with_type_declaration(self: &Self, ctx: &Context, type_declaration: &ast::TypeDeclaration) -> ast::TypeDeclaration { match type_declaration { ast::TypeDeclaration::Struct(struct_) => { return ast::TypeDeclaration::Struct(self.with_struct_declaration(ctx, struct_)); @@ -118,14 +105,13 @@ impl TypeAliasResolver { ast::TypeDeclaration::Alias(alias) => { return ast::TypeDeclaration::Alias(alias.clone()); } + ast::TypeDeclaration::Trait(trait_) => { + return ast::TypeDeclaration::Trait(self.with_trait(ctx, trait_)); + } } } - fn with_struct_declaration( - self: &Self, - ctx: &Context, - struct_: &ast::StructTypeDeclaration, - ) -> ast::StructTypeDeclaration { + fn with_struct_declaration(self: &Self, ctx: &Context, struct_: &ast::StructTypeDeclaration) -> ast::StructTypeDeclaration { return ast::StructTypeDeclaration { name: struct_.name.clone(), fields: struct_ @@ -139,6 +125,34 @@ impl TypeAliasResolver { }; } + fn with_trait(self: &Self, ctx: &Context, trait_: &ast::TraitTypeDeclaration) -> ast::TraitTypeDeclaration { + let mut trait_ctx = ctx.clone(); + trait_ctx.type_aliases.push(ast::AliasTypeDeclaration { + name: ast::Identifier { + name: ast::Spanned { + span: ast::Span { left: 0, right: 0 }, //todo: figure out a sane value for these + value: "Self".to_string(), + }, + }, + replaces: ast::TypeUsage::Named(ast::NamedTypeUsage { name: trait_.name.clone() }), + }); + return ast::TraitTypeDeclaration { + name: trait_.name.clone(), + functions: trait_ + .functions + .iter() + .map(|f| match f { + ast::TraitItem::Function(function) => { + ast::TraitItem::Function(self.with_function(&trait_ctx, function)) + }, + ast::TraitItem::FunctionDeclaration(function_declaration) => { + ast::TraitItem::FunctionDeclaration(self.with_function_declaration(&trait_ctx, function_declaration)) + } + }) + .collect(), + }; + } + fn with_impl(self: &Self, ctx: &Context, impl_: &ast::Impl) -> ast::Impl { let mut impl_ctx = ctx.clone(); impl_ctx.type_aliases.push(ast::AliasTypeDeclaration { @@ -153,22 +167,15 @@ impl TypeAliasResolver { }), }); return ast::Impl { + trait_: impl_.trait_.clone(), struct_name: impl_.struct_name.clone(), - functions: impl_ - .functions - .iter() - .map(|f| self.with_function(&impl_ctx, f)) - .collect(), + functions: impl_.functions.iter().map(|f| self.with_function(&impl_ctx, f)).collect(), }; } fn with_block(self: &Self, ctx: &Context, block: &ast::Block) -> ast::Block { return ast::Block { - statements: block - .statements - .iter() - .map(|s| self.with_statement(ctx, s)) - .collect(), + statements: block.statements.iter().map(|s| self.with_statement(ctx, s)).collect(), type_: process_type(ctx, &block.type_), }; } @@ -182,9 +189,7 @@ impl TypeAliasResolver { return ast::Statement::Let(self.with_let_statement(ctx, let_statement)); } ast::Statement::Assignment(assignment_statement) => { - return ast::Statement::Assignment( - self.with_assignment_statement(ctx, assignment_statement), - ); + return ast::Statement::Assignment(self.with_assignment_statement(ctx, assignment_statement)); } ast::Statement::Expression(expression) => { return ast::Statement::Expression(self.with_expression(ctx, expression)); @@ -192,21 +197,13 @@ impl TypeAliasResolver { } } - fn with_return_statement( - self: &Self, - ctx: &Context, - statement: &ast::ReturnStatement, - ) -> ast::ReturnStatement { + fn with_return_statement(self: &Self, ctx: &Context, statement: &ast::ReturnStatement) -> ast::ReturnStatement { return ast::ReturnStatement { source: self.with_expression(ctx, &statement.source), }; } - fn with_let_statement( - self: &Self, - ctx: &Context, - statement: &ast::LetStatement, - ) -> ast::LetStatement { + fn with_let_statement(self: &Self, ctx: &Context, statement: &ast::LetStatement) -> ast::LetStatement { return ast::LetStatement { variable_name: statement.variable_name.clone(), expression: self.with_expression(ctx, &statement.expression), @@ -214,56 +211,38 @@ impl TypeAliasResolver { }; } - fn with_assignment_statement( - self: &Self, - ctx: &Context, - statement: &ast::AssignmentStatement, - ) -> ast::AssignmentStatement { + fn with_assignment_statement(self: &Self, ctx: &Context, statement: &ast::AssignmentStatement) -> ast::AssignmentStatement { return ast::AssignmentStatement { source: match &statement.source { - ast::AssignmentTarget::Variable(variable) => { - ast::AssignmentTarget::Variable(ast::VariableUsage { - name: variable.name.clone(), - type_: process_type(ctx, &variable.type_), - }) - } - ast::AssignmentTarget::StructAttr(struct_attr) => { - ast::AssignmentTarget::StructAttr(ast::StructGetter { - source: self.with_expression(ctx, &struct_attr.source), - attribute: struct_attr.attribute.clone(), - type_: process_type(ctx, &struct_attr.type_), - }) - } + ast::AssignmentTarget::Variable(variable) => ast::AssignmentTarget::Variable(ast::VariableUsage { + name: variable.name.clone(), + type_: process_type(ctx, &variable.type_), + }), + ast::AssignmentTarget::StructAttr(struct_attr) => ast::AssignmentTarget::StructAttr(ast::StructGetter { + source: self.with_expression(ctx, &struct_attr.source), + attribute: struct_attr.attribute.clone(), + type_: process_type(ctx, &struct_attr.type_), + }), }, expression: self.with_expression(ctx, &statement.expression), }; } - fn with_expression( - self: &Self, - ctx: &Context, - expression: &ast::Expression, - ) -> ast::Expression { + fn with_expression(self: &Self, ctx: &Context, expression: &ast::Expression) -> ast::Expression { return ast::Expression { subexpression: Box::new(match &*expression.subexpression { - ast::Subexpression::LiteralInt(literal_int) => { - ast::Subexpression::LiteralInt(ast::LiteralInt { - value: literal_int.value.clone(), - type_: process_type(ctx, &literal_int.type_), - }) - } - ast::Subexpression::LiteralFloat(literal_float) => { - ast::Subexpression::LiteralFloat(ast::LiteralFloat { - value: literal_float.value.clone(), - type_: process_type(ctx, &literal_float.type_), - }) - } - ast::Subexpression::LiteralBool(literal_bool) => { - ast::Subexpression::LiteralBool(ast::LiteralBool { - value: literal_bool.value.clone(), - type_: process_type(ctx, &literal_bool.type_), - }) - } + ast::Subexpression::LiteralInt(literal_int) => ast::Subexpression::LiteralInt(ast::LiteralInt { + value: literal_int.value.clone(), + type_: process_type(ctx, &literal_int.type_), + }), + ast::Subexpression::LiteralFloat(literal_float) => ast::Subexpression::LiteralFloat(ast::LiteralFloat { + value: literal_float.value.clone(), + type_: process_type(ctx, &literal_float.type_), + }), + ast::Subexpression::LiteralBool(literal_bool) => ast::Subexpression::LiteralBool(ast::LiteralBool { + value: literal_bool.value.clone(), + type_: process_type(ctx, &literal_bool.type_), + }), ast::Subexpression::LiteralStruct(literal_struct) => { let result = resolve_type( ctx, @@ -285,44 +264,30 @@ impl TypeAliasResolver { type_: process_type(ctx, &literal_struct.type_), }) } - ast::Subexpression::FunctionCall(function_call) => { - ast::Subexpression::FunctionCall(ast::FunctionCall { - source: self.with_expression(ctx, &function_call.source), - arguments: function_call - .arguments - .iter() - .map(|arg| self.with_expression(ctx, arg)) - .collect(), - type_: process_type(ctx, &function_call.type_), - }) - } - ast::Subexpression::VariableUsage(variable_usage) => { - ast::Subexpression::VariableUsage(ast::VariableUsage { - name: variable_usage.name.clone(), - type_: process_type(ctx, &variable_usage.type_), - }) - } - ast::Subexpression::If(if_expression) => { - ast::Subexpression::If(ast::IfExpression { - condition: self.with_expression(ctx, &if_expression.condition), - block: self.with_block(ctx, &if_expression.block), - else_: match &if_expression.else_ { - Some(else_) => Some(self.with_block(ctx, else_)), - None => None, - }, - type_: process_type(ctx, &if_expression.type_), - }) - } - ast::Subexpression::StructGetter(struct_getter) => { - ast::Subexpression::StructGetter(ast::StructGetter { - source: self.with_expression(ctx, &struct_getter.source), - attribute: struct_getter.attribute.clone(), - type_: process_type(ctx, &struct_getter.type_), - }) - } - ast::Subexpression::Block(block) => { - ast::Subexpression::Block(self.with_block(ctx, &block)) - } + ast::Subexpression::FunctionCall(function_call) => ast::Subexpression::FunctionCall(ast::FunctionCall { + source: self.with_expression(ctx, &function_call.source), + arguments: function_call.arguments.iter().map(|arg| self.with_expression(ctx, arg)).collect(), + type_: process_type(ctx, &function_call.type_), + }), + ast::Subexpression::VariableUsage(variable_usage) => ast::Subexpression::VariableUsage(ast::VariableUsage { + name: variable_usage.name.clone(), + type_: process_type(ctx, &variable_usage.type_), + }), + ast::Subexpression::If(if_expression) => ast::Subexpression::If(ast::IfExpression { + condition: self.with_expression(ctx, &if_expression.condition), + block: self.with_block(ctx, &if_expression.block), + else_: match &if_expression.else_ { + Some(else_) => Some(self.with_block(ctx, else_)), + None => None, + }, + type_: process_type(ctx, &if_expression.type_), + }), + ast::Subexpression::StructGetter(struct_getter) => ast::Subexpression::StructGetter(ast::StructGetter { + source: self.with_expression(ctx, &struct_getter.source), + attribute: struct_getter.attribute.clone(), + type_: process_type(ctx, &struct_getter.type_), + }), + ast::Subexpression::Block(block) => ast::Subexpression::Block(self.with_block(ctx, &block)), ast::Subexpression::Op(op) => ast::Subexpression::Op(ast::Operation { left: self.with_expression(ctx, &op.left), op: op.op.clone(), diff --git a/src/type_checking.rs b/src/type_checking.rs index f79b86d..ea6ad77 100644 --- a/src/type_checking.rs +++ b/src/type_checking.rs @@ -15,7 +15,7 @@ pub enum NamedEntity { #[derive(Debug, Clone, PartialEq, Eq)] struct Context { pub current_function_return: Option, - pub impls: HashMap, + pub impls: Vec, pub environment: HashMap, } @@ -23,137 +23,275 @@ fn create_builtins() -> HashMap { let mut result = HashMap::new(); result.insert( "i8".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i8".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i8".to_string(), + })), ); result.insert( "i16".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i16".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i16".to_string(), + })), ); result.insert( "i32".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i32".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i32".to_string(), + })), ); result.insert( "i64".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "i64".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "i64".to_string(), + })), ); result.insert( "isize".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "isize".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "isize".to_string(), + })), ); result.insert( "u8".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u8".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u8".to_string(), + })), ); result.insert( "u16".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u16".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u16".to_string(), + })), ); result.insert( "u32".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u32".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u32".to_string(), + })), ); result.insert( "u64".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "u64".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "u64".to_string(), + })), ); result.insert( "usize".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "usize".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "usize".to_string(), + })), ); result.insert( "f32".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "f32".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "f32".to_string(), + })), ); result.insert( "f64".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "f64".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "f64".to_string(), + })), ); result.insert( "bool".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "bool".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "bool".to_string(), + })), ); result.insert( "!".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "!".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "!".to_string(), + })), ); result.insert( "unit".to_string(), - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( - ast::PrimitiveTypeDeclaration { - name: "unit".to_string(), - }, - )), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive(ast::PrimitiveTypeDeclaration { + name: "unit".to_string(), + })), ); return result; } +enum StructAttr { + Field(ast::TypeUsage), + Method(ast::TypeUsage), +} + +fn get_struct_attr(ctx: &Context, struct_declaration: &ast::StructTypeDeclaration, attribute: &ast::Identifier) -> Result { + for field in struct_declaration.fields.iter() { + if field.name.name.value == attribute.name.value { + return Ok(StructAttr::Field(field.type_.clone())); + } + } + + let mut result = Vec::new(); + for impl_ in ctx.impls.iter() { + if &struct_declaration.name.name.value != &impl_.struct_name.name.value { + continue; + } + for method in impl_.functions.iter() { + if method.declaration.name.name.value == attribute.name.value { + let mut function_type = method.declaration.to_type(); + + // if the name of the type of the first argument == the class, remove the first arg + if method.declaration.arguments.len() > 0 { + match &method.declaration.arguments[0].type_ { + ast::TypeUsage::Named(named) => { + if named.name.name.value == struct_declaration.name.name.value { + function_type = method.declaration.to_method_type(); + } + } + _ => {} + }; + } + result.push(function_type); + } + } + } + // TODO: default trait impls + if result.len() == 0 { + return Err(errors::TypingError::UnknownFieldName { + identifier: attribute.clone(), + }); + } + if result.len() > 1 { + return Err(errors::TypingError::MultipleFieldName { + identifier: attribute.clone(), + }); + } + return Ok(StructAttr::Method(result[0].clone())); +} + +fn get_trait_attr(ctx: &Context, trait_declaration: &ast::TraitTypeDeclaration, attribute: &ast::Identifier) -> Result { + let mut result = Vec::new(); + for trait_item in trait_declaration.functions.iter() { + let declaration = match trait_item { + ast::TraitItem::Function(function) => &function.declaration, + ast::TraitItem::FunctionDeclaration(declaration) => declaration, + }; + if declaration.name.name.value == attribute.name.value { + let mut function_type = declaration.to_type(); + println!("foo: {:?}", declaration); + // if the name of the type of the first argument == the class, remove the first arg + if declaration.arguments.len() > 0 { + match &declaration.arguments[0].type_ { + ast::TypeUsage::Named(named) => { + if named.name.name.value == trait_declaration.name.name.value { + function_type = declaration.to_method_type(); + } + } + _ => {} + }; + } + result.push(function_type); + } + } + if result.len() == 0 { + return Err(errors::TypingError::UnknownFieldName { + identifier: attribute.clone(), + }); + } + if result.len() > 1 { + return Err(errors::TypingError::MultipleFieldName { + identifier: attribute.clone(), + }); + } + return Ok(StructAttr::Method(result[0].clone())); +} + +fn get_attr(ctx: &Context, source_type: &ast::TypeUsage, attribute: &ast::Identifier) -> Result { + match source_type { + ast::TypeUsage::Named(named) => { + match &ctx.environment[&named.name.name.value] { + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct(type_declaration)) => { + return get_struct_attr(ctx, type_declaration, attribute); + } + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Trait(type_declaration)) => { + return get_trait_attr(ctx, type_declaration, attribute); + } + _ => { + return Err(errors::TypingError::AttributeOfNonstruct { + identifier: attribute.clone(), + }); + // TODO: support builtins - float, int, etc. + } + } + } + ast::TypeUsage::Function(_) => { + return Err(errors::TypingError::NotAStructLiteral { + identifier: attribute.clone(), + }); + } + _ => { + panic!("tried to get attr of unknown"); + } + }; +} + +fn compare_struct_trait( + struct_: &ast::TypeUsage, + trait_: &ast::TypeUsage, + struct_name: &ast::Identifier, + trait_name: &ast::Identifier, +) -> Result<()> { + match struct_ { + ast::TypeUsage::Named(named) => match trait_ { + ast::TypeUsage::Named(trait_named) => { + if named.name.name.value == trait_named.name.name.value + || (named.name.name.value == struct_name.name.value && trait_named.name.name.value == trait_name.name.value) + { + return Ok(()); + } + return Err(errors::TypingError::TypeMismatch { + type_one: struct_.clone(), + type_two: trait_.clone(), + }); + } + ast::TypeUsage::Function(_) => { + return Err(errors::TypingError::TypeMismatch { + type_one: struct_.clone(), + type_two: trait_.clone(), + }); + } + _ => panic!("Unknown in function definition"), + }, + ast::TypeUsage::Function(function) => match trait_ { + ast::TypeUsage::Named(_) => { + return Err(errors::TypingError::TypeMismatch { + type_one: struct_.clone(), + type_two: trait_.clone(), + }); + } + ast::TypeUsage::Function(trait_function) => { + if function.arguments.len() != trait_function.arguments.len() { + return Err(errors::TypingError::TypeMismatch { + type_one: struct_.clone(), + type_two: trait_.clone(), + }); + } + for (i, _) in function.arguments.iter().enumerate() { + compare_struct_trait(&function.arguments[i], &trait_function.arguments[i], struct_name, trait_name)?; + } + compare_struct_trait(&function.return_type, &trait_function.return_type, struct_name, trait_name)?; + return Ok(()); + } + _ => panic!("Unknown in function definition"), + }, + _ => panic!("Unknown in function definition"), + } +} + impl Context { fn add_variable(&self, name: String, type_usage: &ast::TypeUsage) -> Context { let mut ctx = self.clone(); - ctx.environment - .insert(name.to_string(), NamedEntity::Variable(type_usage.clone())); + ctx.environment.insert(name.to_string(), NamedEntity::Variable(type_usage.clone())); return ctx; } @@ -204,11 +342,7 @@ fn type_exists(ctx: &Context, type_: &ast::TypeUsage) -> Result<()> { return Ok(result); } -fn apply_substitution( - ctx: &Context, - substitution: &SubstitutionMap, - type_: &ast::TypeUsage, -) -> Result { +fn apply_substitution(ctx: &Context, substitution: &SubstitutionMap, type_: &ast::TypeUsage) -> Result { let result = match type_ { ast::TypeUsage::Named(named) => ast::TypeUsage::Named(named.clone()), ast::TypeUsage::Unknown(unknown) => { @@ -225,11 +359,7 @@ fn apply_substitution( } ast::TypeUsage::Function(ast::FunctionTypeUsage { arguments: arguments, - return_type: Box::new(apply_substitution( - ctx, - substitution, - &function.return_type, - )?), + return_type: Box::new(apply_substitution(ctx, substitution, &function.return_type)?), }) } }; @@ -237,20 +367,12 @@ fn apply_substitution( return Ok(result); } -fn compose_substitutions( - ctx: &Context, - s1: &SubstitutionMap, - s2: &SubstitutionMap, -) -> Result { +fn compose_substitutions(ctx: &Context, s1: &SubstitutionMap, s2: &SubstitutionMap) -> Result { let mut result = SubstitutionMap::new(); for k in s2.keys() { result.insert(k.to_string(), apply_substitution(ctx, s1, &s2[k])?); } - return Ok(s1 - .into_iter() - .map(|(k, v)| (k.clone(), v.clone())) - .chain(result) - .collect()); + return Ok(s1.into_iter().map(|(k, v)| (k.clone(), v.clone())).chain(result).collect()); } fn unify(ctx: &Context, t1: &ast::TypeUsage, t2: &ast::TypeUsage) -> Result { @@ -339,13 +461,10 @@ fn contains(t: &ast::TypeUsage, name: &str) -> bool { pub struct TypeChecker {} impl TypeChecker { - pub fn with_module( - self: &Self, - module: &ast::Module, - ) -> Result<(ast::Module, SubstitutionMap)> { + pub fn with_module(self: &Self, module: &ast::Module) -> Result<(ast::Module, SubstitutionMap)> { let mut ctx = Context { environment: create_builtins(), - impls: HashMap::new(), + impls: Vec::new(), current_function_return: None, }; @@ -363,14 +482,15 @@ impl TypeChecker { NamedEntity::TypeDeclaration(ast::TypeDeclaration::Alias(alias.clone())), ); } + ast::ModuleItem::TypeDeclaration(ast::TypeDeclaration::Trait(trait_)) => { + ctx.environment.insert( + trait_.name.name.value.to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Trait(trait_.clone())), + ); + } ast::ModuleItem::Function(function) => { let function_type = ast::FunctionTypeUsage { - arguments: function - .declaration - .arguments - .iter() - .map(|arg| arg.type_.clone()) - .collect(), + arguments: function.declaration.arguments.iter().map(|arg| arg.type_.clone()).collect(), return_type: Box::new(function.declaration.return_type.clone()), }; ctx.environment.insert( @@ -379,8 +499,7 @@ impl TypeChecker { ); } ast::ModuleItem::Impl(impl_) => { - ctx.impls - .insert(impl_.struct_name.name.value.to_string(), impl_.clone()); + ctx.impls.push(impl_.clone()); } _ => {} } @@ -396,7 +515,7 @@ impl TypeChecker { ast::ModuleItem::Function(func) } ast::ModuleItem::TypeDeclaration(type_declaration) => { - let (ty_decl, ty_subst) = self.with_type_declaration(&ctx, type_declaration)?; + let (ty_decl, ty_subst) = self.with_type_declaration(&ctx, &subst, type_declaration)?; subst = compose_substitutions(&ctx, &subst, &ty_subst)?; ast::ModuleItem::TypeDeclaration(ty_decl) } @@ -411,37 +530,40 @@ impl TypeChecker { return Ok((result, subst)); } + fn with_function_declaration( + self: &Self, + ctx: &Context, + declaration: &ast::FunctionDeclaration, + ) -> Result { + for arg in declaration.arguments.iter() { + type_exists(ctx, &arg.type_)?; + } + type_exists(ctx, &declaration.return_type)?; + return Ok(declaration.clone()); + } + fn with_function( self: &Self, ctx: &Context, incoming_substitutions: &SubstitutionMap, function: &ast::Function, ) -> Result<(ast::Function, SubstitutionMap)> { + let declaration = self.with_function_declaration(ctx, &function.declaration)?; // add args to env - let mut function_ctx = - ctx.set_current_function_return(&function.declaration.return_type.clone()); - for arg in function.declaration.arguments.iter() { - type_exists(ctx, &arg.type_)?; - function_ctx = - function_ctx.add_variable(arg.name.name.value.to_string(), &arg.type_.clone()); + let mut function_ctx = ctx.set_current_function_return(&declaration.return_type.clone()); + for arg in declaration.arguments.iter() { + function_ctx = function_ctx.add_variable(arg.name.name.value.to_string(), &arg.type_.clone()); } - type_exists(ctx, &function.declaration.return_type)?; - let (block, substitution) = - self.with_block(&function_ctx, incoming_substitutions, &function.block)?; - let mut substitution = - compose_substitutions(&function_ctx, incoming_substitutions, &substitution)?; + let (block, substitution) = self.with_block(&function_ctx, incoming_substitutions, &function.block)?; + let mut substitution = compose_substitutions(&function_ctx, incoming_substitutions, &substitution)?; match &block.type_ { ast::TypeUsage::Named(named) => { if named.name.name.value != "!" { substitution = compose_substitutions( &function_ctx, &substitution, - &unify( - &function_ctx, - &function.declaration.return_type, - &block.type_, - )?, + &unify(&function_ctx, &declaration.return_type, &block.type_)?, )?; } } @@ -449,11 +571,7 @@ impl TypeChecker { substitution = compose_substitutions( &function_ctx, &substitution, - &unify( - &function_ctx, - &function.declaration.return_type, - &block.type_, - )?, + &unify(&function_ctx, &declaration.return_type, &block.type_)?, )?; } } @@ -461,14 +579,9 @@ impl TypeChecker { return Ok(( ast::Function { declaration: ast::FunctionDeclaration { - name: function.declaration.name.clone(), - arguments: function - .declaration - .arguments - .iter() - .map(|arg| arg.clone()) - .collect(), - return_type: function.declaration.return_type.clone(), + name: declaration.name.clone(), + arguments: declaration.arguments.iter().map(|arg| arg.clone()).collect(), + return_type: declaration.return_type.clone(), }, block: block, }, @@ -479,6 +592,7 @@ impl TypeChecker { fn with_type_declaration( self: &Self, ctx: &Context, + incoming_substitutions: &SubstitutionMap, type_declaration: &ast::TypeDeclaration, ) -> Result<(ast::TypeDeclaration, SubstitutionMap)> { match type_declaration { @@ -487,25 +601,49 @@ impl TypeChecker { return Ok((ast::TypeDeclaration::Struct(result), SubstitutionMap::new())); } ast::TypeDeclaration::Primitive(primitive) => { - return Ok(( - ast::TypeDeclaration::Primitive(primitive.clone()), - SubstitutionMap::new(), - )); + return Ok((ast::TypeDeclaration::Primitive(primitive.clone()), SubstitutionMap::new())); } ast::TypeDeclaration::Alias(alias) => { - return Ok(( - ast::TypeDeclaration::Alias(alias.clone()), - SubstitutionMap::new(), - )); + return Ok((ast::TypeDeclaration::Alias(alias.clone()), SubstitutionMap::new())); + } + ast::TypeDeclaration::Trait(trait_) => { + let (result, subst) = self.with_trait_declaration(ctx, incoming_substitutions, trait_)?; + return Ok((ast::TypeDeclaration::Trait(result), subst)); } } } - fn with_struct_declaration( + fn with_trait_declaration( self: &Self, ctx: &Context, - struct_: &ast::StructTypeDeclaration, - ) -> Result { + incoming_substitutions: &SubstitutionMap, + trait_: &ast::TraitTypeDeclaration, + ) -> Result<(ast::TraitTypeDeclaration, SubstitutionMap)> { + let mut substitutions = incoming_substitutions.clone(); + let mut result_functions = vec!(); + for item in &trait_.functions { + match item { + ast::TraitItem::FunctionDeclaration(declaration) => { + let result_declaration = self.with_function_declaration(ctx, declaration)?; + result_functions.push(ast::TraitItem::FunctionDeclaration(result_declaration)); + } + ast::TraitItem::Function(function) => { + let (function_result, susbt) = self.with_function(ctx, incoming_substitutions, function)?; + substitutions = compose_substitutions(ctx, &substitutions, &susbt)?; + result_functions.push(ast::TraitItem::Function(function_result)); + } + } + } + Ok(( + ast::TraitTypeDeclaration { + name: trait_.name.clone(), + functions: result_functions, + }, + substitutions, + )) + } + + fn with_struct_declaration(self: &Self, ctx: &Context, struct_: &ast::StructTypeDeclaration) -> Result { let mut fields = vec![]; for field in struct_.fields.iter() { type_exists(ctx, &field.type_)?; @@ -534,8 +672,75 @@ impl TypeChecker { substitutions = compose_substitutions(ctx, &substitutions, &function_subs)?; functions.push(result); } + // See if trait actually matches + match &impl_.trait_ { + Some(trait_) => { + // assert trait functions satisfied + if !ctx.environment.contains_key(&trait_.name.value) { + return Err(errors::TypingError::TypeDoesNotExist { + identifier: trait_.clone(), + }); + } + let trait_declaration = match &ctx.environment[&trait_.name.value] { + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Trait(declaration)) => declaration, + _ => { + return Err(errors::TypingError::ImplTraitMustBeTrait { + trait_name: trait_.clone(), + }); + } + }; + for trait_item in trait_declaration.functions.iter() { + match trait_item { + ast::TraitItem::FunctionDeclaration(declaration) => { + let mut found = false; + for impl_function in impl_.functions.iter() { + if impl_function.declaration.name.name.value == declaration.name.name.value { + found = true; + compare_struct_trait(&impl_function.declaration.to_type(), &declaration.to_type(), &impl_.struct_name, &trait_)?; + } + } + if found == false { + return Err(errors::TypingError::MissingTraitFunction { + struct_name: impl_.struct_name.clone(), + function_name: declaration.name.clone(), + }); + } + } + ast::TraitItem::Function(function) => { + // skip found check because it has a default + for impl_function in impl_.functions.iter() { + if impl_function.declaration.name.name.value == function.declaration.name.name.value { + compare_struct_trait(&impl_function.declaration.to_type(), &function.declaration.to_type(), &impl_.struct_name, &trait_)?; + } + } + } + } + } + // assert all functions are in trait + for impl_function in impl_.functions.iter() { + let mut found = false; + for trait_item in trait_declaration.functions.iter() { + let declaration = match trait_item { + ast::TraitItem::Function(function) => &function.declaration, + ast::TraitItem::FunctionDeclaration(declaration) => declaration, + }; + if impl_function.declaration.name.name.value == declaration.name.name.value { + found = true; + break; + } + } + if found == false { + return Err(errors::TypingError::FunctionNotInTrait { + function_name: impl_function.declaration.name.clone(), + }); + } + } + } + None => {} + } return Ok(( ast::Impl { + trait_: impl_.trait_.clone(), struct_name: impl_.struct_name.clone(), functions: functions, }, @@ -564,28 +769,18 @@ impl TypeChecker { } let mut statements = vec![]; for s in block.statements.iter() { - let (statement_ctx, result, statement_substitutions) = - self.with_statement(&block_ctx, &substitutions, s)?; + let (statement_ctx, result, statement_substitutions) = self.with_statement(&block_ctx, &substitutions, s)?; block_ctx = statement_ctx; - substitutions = - compose_substitutions(&block_ctx, &substitutions, &statement_substitutions)?; + substitutions = compose_substitutions(&block_ctx, &substitutions, &statement_substitutions)?; statements.push(result); } if !has_return { match block.statements.last().unwrap() { ast::Statement::Expression(expr) => { - substitutions = compose_substitutions( - &block_ctx, - &substitutions, - &unify(&block_ctx, &block.type_, &expr.type_)?, - )?; + substitutions = compose_substitutions(&block_ctx, &substitutions, &unify(&block_ctx, &block.type_, &expr.type_)?)?; } _ => { - substitutions = compose_substitutions( - &block_ctx, - &substitutions, - &unify(&block_ctx, &block.type_, &ast::new_unit())?, - )?; + substitutions = compose_substitutions(&block_ctx, &substitutions, &unify(&block_ctx, &block.type_, &ast::new_unit())?)?; } } } @@ -609,29 +804,22 @@ impl TypeChecker { ) -> Result<(Context, ast::Statement, SubstitutionMap)> { match statement { ast::Statement::Return(return_statement) => { - let (result, subst) = - self.with_return_statement(ctx, incoming_substitutions, return_statement)?; + let (result, subst) = self.with_return_statement(ctx, incoming_substitutions, return_statement)?; let subst = compose_substitutions(ctx, &incoming_substitutions, &subst)?; return Ok((ctx.clone(), ast::Statement::Return(result), subst)); } ast::Statement::Let(let_statement) => { - let (let_ctx, result, subst) = - self.with_let_statement(ctx, incoming_substitutions, let_statement)?; + let (let_ctx, result, subst) = self.with_let_statement(ctx, incoming_substitutions, let_statement)?; let subst = compose_substitutions(ctx, &incoming_substitutions, &subst)?; return Ok((let_ctx, ast::Statement::Let(result), subst)); } ast::Statement::Assignment(assignment_statement) => { - let (result, subst) = self.with_assignment_statement( - ctx, - incoming_substitutions, - assignment_statement, - )?; + let (result, subst) = self.with_assignment_statement(ctx, incoming_substitutions, assignment_statement)?; let subst = compose_substitutions(ctx, &incoming_substitutions, &subst)?; return Ok((ctx.clone(), ast::Statement::Assignment(result), subst)); } ast::Statement::Expression(expression) => { - let (result, subst) = - self.with_expression(ctx, incoming_substitutions, expression)?; + let (result, subst) = self.with_expression(ctx, incoming_substitutions, expression)?; let subst = compose_substitutions(ctx, &incoming_substitutions, &subst)?; return Ok((ctx.clone(), ast::Statement::Expression(result), subst)); } @@ -644,8 +832,7 @@ impl TypeChecker { incoming_substitutions: &SubstitutionMap, statement: &ast::ReturnStatement, ) -> Result<(ast::ReturnStatement, SubstitutionMap)> { - let (result, subst) = - self.with_expression(ctx, incoming_substitutions, &statement.source)?; + let (result, subst) = self.with_expression(ctx, incoming_substitutions, &statement.source)?; let mut substitution = compose_substitutions(ctx, &incoming_substitutions, &subst)?; let mut is_never = false; match &result.type_ { @@ -660,11 +847,7 @@ impl TypeChecker { substitution = compose_substitutions( ctx, &subst, - &unify( - ctx, - &ctx.current_function_return.as_ref().unwrap(), - &result.type_, - )?, + &unify(ctx, &ctx.current_function_return.as_ref().unwrap(), &result.type_)?, )?; } @@ -677,11 +860,9 @@ impl TypeChecker { incoming_substitutions: &SubstitutionMap, statement: &ast::LetStatement, ) -> Result<(Context, ast::LetStatement, SubstitutionMap)> { - let (result, subst) = - self.with_expression(ctx, incoming_substitutions, &statement.expression)?; + let (result, subst) = self.with_expression(ctx, incoming_substitutions, &statement.expression)?; let let_ctx = ctx.add_variable(statement.variable_name.name.value.clone(), &result.type_); - let substitution = - compose_substitutions(ctx, &subst, &unify(ctx, &statement.type_, &result.type_)?)?; + let substitution = compose_substitutions(ctx, &subst, &unify(ctx, &statement.type_, &result.type_)?)?; return Ok(( let_ctx, ast::LetStatement { @@ -699,66 +880,32 @@ impl TypeChecker { incoming_substitutions: &SubstitutionMap, statement: &ast::AssignmentStatement, ) -> Result<(ast::AssignmentStatement, SubstitutionMap)> { - let (expr, subst) = - self.with_expression(ctx, incoming_substitutions, &statement.expression)?; + let (expr, subst) = self.with_expression(ctx, incoming_substitutions, &statement.expression)?; let mut substitution = compose_substitutions(ctx, &incoming_substitutions, &subst)?; let result_as = ast::AssignmentStatement { source: match &statement.source { ast::AssignmentTarget::Variable(variable) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &variable.type_, &expr.type_)?, - )?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &variable.type_, &expr.type_)?)?; ast::AssignmentTarget::Variable(ast::VariableUsage { name: variable.name.clone(), type_: apply_substitution(ctx, &substitution, &variable.type_)?, }) } ast::AssignmentTarget::StructAttr(struct_attr) => { - let (source, subst) = - self.with_expression(ctx, &substitution, &struct_attr.source)?; + let (source, subst) = self.with_expression(ctx, &substitution, &struct_attr.source)?; let mut subst = subst.clone(); - match &source.type_ { - ast::TypeUsage::Named(named) => { - match &ctx.environment[&named.name.name.value] { - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct( - type_declaration, - )) => { - let mut found = false; - for field in type_declaration.fields.iter() { - if field.name.name.value == struct_attr.attribute.name.value - { - found = true; - subst = compose_substitutions( - ctx, - &subst, - &unify(ctx, &struct_attr.type_, &field.type_)?, - )?; - } - } - if !found { - return Err(errors::TypingError::UnknownFieldName { - identifier: struct_attr.attribute.clone(), - }); - } - } - _ => { - return Err(errors::TypingError::AttributeOfNonstruct { - identifier: struct_attr.attribute.clone(), - }); - } - } - } - ast::TypeUsage::Function(_) => { - return Err(errors::TypingError::NotAStructLiteral { + let field_type = match get_attr(ctx, &source.type_, &struct_attr.attribute)? { + StructAttr::Field(type_) => type_, + StructAttr::Method(_) => { + return Err(errors::TypingError::CannotAssignToMethod { identifier: struct_attr.attribute.clone(), - }); + }) } - _ => {} // skip unifying if struct type is unknown1 - } + }; + + subst = compose_substitutions(ctx, &subst, &unify(ctx, &struct_attr.type_, &field_type)?)?; let substitution = compose_substitutions( ctx, @@ -786,422 +933,65 @@ impl TypeChecker { let mut substitution = incoming_substitutions.clone(); let subexpression = Box::new(match &*expression.subexpression { ast::Subexpression::LiteralInt(literal_int) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &literal_int.type_)?, - )?; - ast::Subexpression::LiteralInt(ast::LiteralInt { - value: literal_int.value.clone(), - type_: apply_substitution(ctx, &substitution, &literal_int.type_)?, - }) + let (result, subst) = self.with_literal_int(ctx, &substitution, literal_int)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.type_)?)?; + ast::Subexpression::LiteralInt(result) } ast::Subexpression::LiteralFloat(literal_float) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &literal_float.type_)?, - )?; - ast::Subexpression::LiteralFloat(ast::LiteralFloat { - value: literal_float.value.clone(), - type_: apply_substitution(ctx, &substitution, &literal_float.type_)?, - }) + let (result, subst) = self.with_literal_float(ctx, &substitution, literal_float)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.type_)?)?; + ast::Subexpression::LiteralFloat(result) } ast::Subexpression::LiteralBool(literal_bool) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &literal_bool.type_)?, - )?; - ast::Subexpression::LiteralBool(ast::LiteralBool { - value: literal_bool.value.clone(), - type_: apply_substitution(ctx, &substitution, &literal_bool.type_)?, - }) + let (result, subst) = self.with_literal_bool(ctx, &substitution, literal_bool)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.type_)?)?; + ast::Subexpression::LiteralBool(result) } ast::Subexpression::LiteralStruct(literal_struct) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &literal_struct.type_)?, - )?; - let type_declaration = match &ctx.environment[&literal_struct.name.name.value] { - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct( - type_declaration, - )) => type_declaration, - _ => { - return Err(errors::TypingError::NotAStructLiteral { - identifier: literal_struct.name.clone(), - }); - } - }; - if type_declaration.fields.len() != literal_struct.fields.len() { - return Err(errors::TypingError::StructLiteralFieldsMismatch { - struct_name: literal_struct.name.clone(), - struct_definition_name: type_declaration.name.clone(), - }); - } - let mut fields = vec![]; - for type_field in type_declaration.fields.iter() { - let mut found = false; - let mut field_expression: Option = None; - for field in literal_struct.fields.iter() { - if type_field.name.name.value == field.0.name.value { - found = true; - let (result, subst) = - self.with_expression(ctx, &substitution, &field.1)?; - substitution = compose_substitutions(ctx, &substitution, &subst)?; - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &type_field.type_, &result.type_)?, - )?; - field_expression = Some(result); - } - } - if !found { - return Err(errors::TypingError::StructLiteralFieldsMismatch { - struct_name: literal_struct.name.clone(), - struct_definition_name: type_field.name.clone(), - }); - } - fields.push((type_field.name.clone(), field_expression.unwrap())); - } - ast::Subexpression::LiteralStruct(ast::LiteralStruct { - name: literal_struct.name.clone(), - fields: fields, - type_: apply_substitution(ctx, &substitution, &literal_struct.type_)?, - }) + let (result, subst) = self.with_literal_struct(ctx, &substitution, literal_struct)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.type_)?)?; + ast::Subexpression::LiteralStruct(result) } ast::Subexpression::FunctionCall(function_call) => { - let (source, subst) = - self.with_expression(ctx, &substitution, &function_call.source)?; + let (result, subst) = self.with_function_call(ctx, &substitution, function_call)?; substitution = compose_substitutions(ctx, &substitution, &subst)?; - match &source.type_ { - ast::TypeUsage::Function(fn_type) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &function_call.type_, &*fn_type.return_type)?, - )?; - if function_call.arguments.len() != fn_type.arguments.len() { - return Err(errors::TypingError::ArgumentLengthMismatch {}); - } - } - ast::TypeUsage::Named(_) => { - return Err(errors::TypingError::FunctionCallNotAFunction {}); - } - _ => {} - } - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &function_call.type_)?, - )?; - let mut arguments = vec![]; - for (i, arg) in function_call.arguments.iter().enumerate() { - let (result, subst) = self.with_expression(ctx, &substitution, arg)?; - substitution = compose_substitutions(ctx, &substitution, &subst)?; - - match &source.type_ { - ast::TypeUsage::Function(fn_type) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &fn_type.arguments[i], &result.type_)?, - )?; - } - ast::TypeUsage::Named(_) => { - return Err(errors::TypingError::FunctionCallNotAFunction {}); - } - _ => {} - } - arguments.push(result); - } - ast::Subexpression::FunctionCall(ast::FunctionCall { - source: source.clone(), - arguments: arguments, - type_: apply_substitution(ctx, &substitution, &function_call.type_)?, - }) + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &function_call.type_)?)?; + ast::Subexpression::FunctionCall(result) } ast::Subexpression::VariableUsage(variable_usage) => { - match &ctx.environment[&variable_usage.name.name.value] { - NamedEntity::TypeDeclaration(_) => { - panic!("Using types not yet supported"); - } - NamedEntity::Variable(variable) => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &variable_usage.type_, &variable)?, - )?; - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &variable_usage.type_)?, - )?; - } - } - ast::Subexpression::VariableUsage(ast::VariableUsage { - name: variable_usage.name.clone(), - type_: apply_substitution(ctx, &substitution, &variable_usage.type_)?, - }) + let (result, subst) = self.with_variable_usage(ctx, &substitution, variable_usage)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &variable_usage.type_)?)?; + ast::Subexpression::VariableUsage(result) } ast::Subexpression::If(if_expression) => { - let (condition, subst) = - self.with_expression(ctx, &substitution, &if_expression.condition)?; + let (result, subst) = self.with_if(ctx, &substitution, if_expression)?; substitution = compose_substitutions(ctx, &substitution, &subst)?; - - let (block_result, subst) = - self.with_block(ctx, &substitution, &if_expression.block)?; - substitution = compose_substitutions(ctx, &substitution, &subst)?; - - let else_ = match &if_expression.else_ { - Some(else_) => { - let (result, subst) = self.with_block(ctx, &substitution, else_)?; - substitution = compose_substitutions(ctx, &substitution, &subst)?; - Some(result) - } - None => None, - }; - - match &condition.type_ { - ast::TypeUsage::Named(named) => { - if named.name.name.value != "bool" { - return Err(errors::TypingError::IfConditionMustBeBool {}); - } - } - ast::TypeUsage::Function(_) => { - return Err(errors::TypingError::IfConditionMustBeBool {}); - } - _ => {} - }; - - let mut never_count = 0; - match &block_result.type_ { - ast::TypeUsage::Named(named) => { - if named.name.name.value != "!" { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &if_expression.type_, &block_result.type_)?, - )?; - } else { - never_count += 1; - } - } - _ => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &if_expression.type_, &block_result.type_)?, - )?; - } - }; - - match &else_ { - Some(else_block) => { - match &else_block.type_ { - ast::TypeUsage::Named(named) => { - if named.name.name.value != "!" { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &if_expression.type_, &else_block.type_)?, - )?; - } else { - never_count += 1; - } - } - _ => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &if_expression.type_, &else_block.type_)?, - )?; - } - }; - } - None => { - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &if_expression.type_, &ast::new_unit())?, - )?; - } - } - - let result_type = if never_count == 2 { - ast::new_never() - } else { - apply_substitution(ctx, &substitution, &if_expression.type_)? - }; - - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &result_type)?, - )?; - - ast::Subexpression::If(ast::IfExpression { - condition: condition, - block: block_result, - else_: else_, - type_: result_type, - }) + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.type_)?)?; + ast::Subexpression::If(result) } ast::Subexpression::StructGetter(struct_getter) => { - let (source, subst) = - self.with_expression(ctx, &substitution, &struct_getter.source)?; + let (result, subst) = self.with_struct_getter(ctx, &substitution, struct_getter)?; substitution = compose_substitutions(ctx, &substitution, &subst)?; - - match &source.type_ { - ast::TypeUsage::Named(named) => { - match &ctx.environment[&named.name.name.value] { - NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct( - type_declaration, - )) => { - let mut found = false; - for field in type_declaration.fields.iter() { - if field.name.name.value == struct_getter.attribute.name.value { - found = true; - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &struct_getter.type_, &field.type_)?, - )?; - } - } - if !found { - for method in ctx.impls[&type_declaration.name.name.value] - .functions - .iter() - { - if method.declaration.name.name.value - == struct_getter.attribute.name.value - { - let mut function_type = ast::FunctionTypeUsage { - arguments: method - .declaration - .arguments - .iter() - .map(|arg| arg.type_.clone()) - .collect(), - return_type: Box::new( - method.declaration.return_type.clone(), - ), - }; - // if the name of the type of the first argument == the class, remove the first arg - if function_type.arguments.len() > 0 { - match &function_type.arguments[0] { - ast::TypeUsage::Named(named) => { - if named.name.name.value - == type_declaration.name.name.value - { - function_type = - ast::FunctionTypeUsage { - arguments: method - .declaration - .arguments - [1..method - .declaration - .arguments - .len()] - .iter() - .map(|arg| { - arg.type_.clone() - }) - .collect(), - return_type: Box::new( - method - .declaration - .return_type - .clone(), - ), - }; - } - } - _ => {} - }; - } - - substitution = compose_substitutions( - ctx, - &substitution, - &unify( - ctx, - &struct_getter.type_, - &ast::TypeUsage::Function(function_type), - )?, - )?; - found = true; - } - } - } - if !found { - return Err(errors::TypingError::UnknownFieldName { - identifier: struct_getter.attribute.clone(), - }); - } - } - _ => { - return Err(errors::TypingError::AttributeOfNonstruct { - identifier: struct_getter.attribute.clone(), - }); - // TODO: support builtins - } - } - } - ast::TypeUsage::Function(_) => { - return Err(errors::TypingError::NotAStructLiteral { - identifier: struct_getter.attribute.clone(), - }); - } - _ => {} // skip unifying if struct type is unknown1 - } - - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &struct_getter.type_)?, - )?; - - ast::Subexpression::StructGetter(ast::StructGetter { - source: source, - attribute: struct_getter.attribute.clone(), - type_: apply_substitution(ctx, &substitution, &struct_getter.type_)?, - }) + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.type_)?)?; + ast::Subexpression::StructGetter(result) } ast::Subexpression::Block(block) => { - let (result, subst) = self.with_block(ctx, &substitution, &block)?; + let (result, subst) = self.with_block_expression(ctx, &substitution, block)?; substitution = compose_substitutions(ctx, &substitution, &subst)?; - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &result.type_)?, - )?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.type_)?)?; ast::Subexpression::Block(result) } ast::Subexpression::Op(op) => { - let (expr_left, subst_left) = self.with_expression(ctx, &substitution, &op.left)?; - let (expr_right, subst_right) = - self.with_expression(ctx, &substitution, &op.right)?; - substitution = compose_substitutions(ctx, &substitution, &subst_left)?; - substitution = compose_substitutions(ctx, &substitution, &subst_right)?; - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &expr_left.type_)?, - )?; - substitution = compose_substitutions( - ctx, - &substitution, - &unify(ctx, &expression.type_, &expr_right.type_)?, - )?; - ast::Subexpression::Op(ast::Operation { - left: expr_left, - op: op.op.clone(), - right: expr_right, - }) + let (result, subst) = self.with_op(ctx, &substitution, op)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.left.type_)?)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &expression.type_, &result.right.type_)?)?; + ast::Subexpression::Op(result) } }); @@ -1211,4 +1001,309 @@ impl TypeChecker { }; return Ok((expr, substitution)); } + + fn with_literal_int( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + literal_int: &ast::LiteralInt, + ) -> Result<(ast::LiteralInt, SubstitutionMap)> { + Ok(( + ast::LiteralInt { + value: literal_int.value.clone(), + type_: apply_substitution(ctx, &substitution, &literal_int.type_)?, + }, + substitution.clone(), + )) + } + + fn with_literal_float( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + literal_float: &ast::LiteralFloat, + ) -> Result<(ast::LiteralFloat, SubstitutionMap)> { + Ok(( + ast::LiteralFloat { + value: literal_float.value.clone(), + type_: apply_substitution(ctx, &substitution, &literal_float.type_)?, + }, + substitution.clone(), + )) + } + + fn with_literal_bool( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + literal_bool: &ast::LiteralBool, + ) -> Result<(ast::LiteralBool, SubstitutionMap)> { + Ok(( + ast::LiteralBool { + value: literal_bool.value.clone(), + type_: apply_substitution(ctx, &substitution, &literal_bool.type_)?, + }, + substitution.clone(), + )) + } + + fn with_literal_struct( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + literal_struct: &ast::LiteralStruct, + ) -> Result<(ast::LiteralStruct, SubstitutionMap)> { + let mut substitution = substitution.clone(); + let type_declaration = match &ctx.environment[&literal_struct.name.name.value] { + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct(type_declaration)) => type_declaration, + _ => { + return Err(errors::TypingError::NotAStructLiteral { + identifier: literal_struct.name.clone(), + }); + } + }; + if type_declaration.fields.len() != literal_struct.fields.len() { + return Err(errors::TypingError::StructLiteralFieldsMismatch { + struct_name: literal_struct.name.clone(), + struct_definition_name: type_declaration.name.clone(), + }); + } + let mut fields = vec![]; + for type_field in type_declaration.fields.iter() { + let mut found = false; + let mut field_expression: Option = None; + for field in literal_struct.fields.iter() { + if type_field.name.name.value == field.0.name.value { + found = true; + let (result, subst) = self.with_expression(ctx, &substitution, &field.1)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &type_field.type_, &result.type_)?)?; + field_expression = Some(result); + } + } + if !found { + return Err(errors::TypingError::StructLiteralFieldsMismatch { + struct_name: literal_struct.name.clone(), + struct_definition_name: type_field.name.clone(), + }); + } + fields.push((type_field.name.clone(), field_expression.unwrap())); + } + Ok(( + ast::LiteralStruct { + name: literal_struct.name.clone(), + fields: fields, + type_: apply_substitution(ctx, &substitution, &literal_struct.type_)?, + }, + substitution, + )) + } + + fn with_function_call( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + function_call: &ast::FunctionCall, + ) -> Result<(ast::FunctionCall, SubstitutionMap)> { + let mut substitution = substitution.clone(); + let (source, subst) = self.with_expression(ctx, &substitution, &function_call.source)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + match &source.type_ { + ast::TypeUsage::Function(fn_type) => { + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &function_call.type_, &*fn_type.return_type)?)?; + if function_call.arguments.len() != fn_type.arguments.len() { + println!("{:?}\n{:?}", &function_call, &fn_type); + return Err(errors::TypingError::ArgumentLengthMismatch {}); + } + } + ast::TypeUsage::Named(_) => { + return Err(errors::TypingError::FunctionCallNotAFunction {}); + } + _ => {} + } + let mut arguments = vec![]; + for (i, arg) in function_call.arguments.iter().enumerate() { + let (result, subst) = self.with_expression(ctx, &substitution, arg)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + + match &source.type_ { + ast::TypeUsage::Function(fn_type) => { + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &fn_type.arguments[i], &result.type_)?)?; + } + ast::TypeUsage::Named(_) => { + return Err(errors::TypingError::FunctionCallNotAFunction {}); + } + _ => {} + } + arguments.push(result); + } + Ok(( + ast::FunctionCall { + source: source.clone(), + arguments: arguments, + type_: apply_substitution(ctx, &substitution, &function_call.type_)?, + }, + substitution, + )) + } + + fn with_variable_usage( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + variable_usage: &ast::VariableUsage, + ) -> Result<(ast::VariableUsage, SubstitutionMap)> { + let mut substitution = substitution.clone(); + match &ctx.environment[&variable_usage.name.name.value] { + NamedEntity::TypeDeclaration(_) => { + panic!("Using types not yet supported"); + } + NamedEntity::Variable(variable) => { + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &variable_usage.type_, &variable)?)?; + } + } + Ok(( + ast::VariableUsage { + name: variable_usage.name.clone(), + type_: apply_substitution(ctx, &substitution, &variable_usage.type_)?, + }, + substitution, + )) + } + + fn with_if( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + if_expression: &ast::IfExpression, + ) -> Result<(ast::IfExpression, SubstitutionMap)> { + let mut substitution = substitution.clone(); + let (condition, subst) = self.with_expression(ctx, &substitution, &if_expression.condition)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + + let (block_result, subst) = self.with_block(ctx, &substitution, &if_expression.block)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + + let else_ = match &if_expression.else_ { + Some(else_) => { + let (result, subst) = self.with_block(ctx, &substitution, else_)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + Some(result) + } + None => None, + }; + + match &condition.type_ { + ast::TypeUsage::Named(named) => { + if named.name.name.value != "bool" { + return Err(errors::TypingError::IfConditionMustBeBool {}); + } + } + ast::TypeUsage::Function(_) => { + return Err(errors::TypingError::IfConditionMustBeBool {}); + } + _ => {} + }; + + let mut never_count = 0; + match &block_result.type_ { + ast::TypeUsage::Named(named) => { + if named.name.name.value != "!" { + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &if_expression.type_, &block_result.type_)?)?; + } else { + never_count += 1; + } + } + _ => { + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &if_expression.type_, &block_result.type_)?)?; + } + }; + + match &else_ { + Some(else_block) => { + match &else_block.type_ { + ast::TypeUsage::Named(named) => { + if named.name.name.value != "!" { + substitution = + compose_substitutions(ctx, &substitution, &unify(ctx, &if_expression.type_, &else_block.type_)?)?; + } else { + never_count += 1; + } + } + _ => { + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &if_expression.type_, &else_block.type_)?)?; + } + }; + } + None => { + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &if_expression.type_, &ast::new_unit())?)?; + } + } + + let result_type = if never_count == 2 { + ast::new_never() + } else { + apply_substitution(ctx, &substitution, &if_expression.type_)? + }; + + Ok(( + ast::IfExpression { + condition: condition, + block: block_result, + else_: else_, + type_: result_type, + }, + substitution, + )) + } + + fn with_struct_getter( + self: &Self, + ctx: &Context, + substitution: &SubstitutionMap, + struct_getter: &ast::StructGetter, + ) -> Result<(ast::StructGetter, SubstitutionMap)> { + let mut substitution = substitution.clone(); + let (source, subst) = self.with_expression(ctx, &substitution, &struct_getter.source)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + + let field_type = match get_attr(ctx, &source.type_, &struct_getter.attribute)? { + StructAttr::Field(type_) => type_, + StructAttr::Method(type_) => type_, + }; + + substitution = compose_substitutions(ctx, &substitution, &unify(ctx, &struct_getter.type_, &field_type)?)?; + + Ok(( + ast::StructGetter { + source: source, + attribute: struct_getter.attribute.clone(), + type_: apply_substitution(ctx, &substitution, &struct_getter.type_)?, + }, + substitution, + )) + } + + fn with_block_expression(self: &Self, ctx: &Context, substitution: &SubstitutionMap, block: &ast::Block) -> Result<(ast::Block, SubstitutionMap)> { + let mut substitution = substitution.clone(); + let (result, subst) = self.with_block(ctx, &substitution, &block)?; + substitution = compose_substitutions(ctx, &substitution, &subst)?; + Ok((result, substitution)) + } + + fn with_op(self: &Self, ctx: &Context, substitution: &SubstitutionMap, op: &ast::Operation) -> Result<(ast::Operation, SubstitutionMap)> { + let mut substitution = substitution.clone(); + let (expr_left, subst_left) = self.with_expression(ctx, &substitution, &op.left)?; + let (expr_right, subst_right) = self.with_expression(ctx, &substitution, &op.right)?; + substitution = compose_substitutions(ctx, &substitution, &subst_left)?; + substitution = compose_substitutions(ctx, &substitution, &subst_right)?; + Ok(( + ast::Operation { + left: expr_left, + op: op.op.clone(), + right: expr_right, + }, + substitution, + )) + } }