diff --git a/examples/main.bl b/examples/main.bl index 6b86a2f..6a417ed 100644 --- a/examples/main.bl +++ b/examples/main.bl @@ -65,6 +65,10 @@ impl User { } } +fn main(): i64 { + add(4, 4) +} + // type TestTrait trait { // fn classMethod(id: i64): Self; // fn instanceMethod(self: Self): i64; diff --git a/src/interpreter.rs b/src/interpreter.rs new file mode 100644 index 0000000..155b57f --- /dev/null +++ b/src/interpreter.rs @@ -0,0 +1,682 @@ +use crate::ast; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Clone)] +pub enum NumericValue { + I8(i8), + I16(i16), + I32(i32), + I64(i64), + ISize(isize), + + U8(u8), + U16(u16), + U32(u32), + U64(u64), + USize(usize), + + F32(f32), + F64(f64), +} + +#[derive(Debug, Clone)] +pub struct StructValue { + source: ast::StructTypeDeclaration, + fields: HashMap, +} + +type BuiltinFunction = fn(Vec) -> Value; + +#[derive(Debug, Clone)] +pub enum Function { + User(ast::Function), + Builtin(BuiltinFunction), +} + +#[derive(Debug, Clone)] +pub enum Value { + Numeric(NumericValue), + Function(Function), + Struct(Arc>), + Unit, +} + +#[derive(Debug, Clone)] +pub enum NamedEntity { + TypeDeclaration(ast::TypeDeclaration), + Variable(Value), +} + +#[derive(Debug, Clone)] +struct Context { + pub environment: HashMap, + pub impls: HashMap, + pub current_module: ast::Module, +} + +impl Context { + fn set_variable(&mut self, name: String, value: &Value) { + self.environment + .insert(name.to_string(), NamedEntity::Variable(value.clone())); + } + + fn new_env(&self) -> Context { + return Context::from_module(&self.current_module); + } + + fn from_module(module: &ast::Module) -> Context { + let mut ctx = Context { + environment: create_builtins(), + impls: HashMap::new(), + current_module: module.clone(), + }; + + for item in ctx.current_module.items.iter() { + match item { + ast::ModuleItem::TypeDeclaration(ast::TypeDeclaration::Struct(struct_)) => { + ctx.environment.insert( + struct_.name.name.value.to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct(struct_.clone())), + ); + } + ast::ModuleItem::TypeDeclaration(ast::TypeDeclaration::Alias(alias)) => { + ctx.environment.insert( + alias.name.name.value.to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Alias(alias.clone())), + ); + } + ast::ModuleItem::Function(function) => { + ctx.environment.insert( + function.declaration.name.name.value.to_string(), + NamedEntity::Variable(Value::Function(Function::User(function.clone()))), + ); + } + ast::ModuleItem::Impl(impl_) => { + ctx.impls + .insert(impl_.struct_name.name.value.to_string(), impl_.clone()); + } + _ => {} + } + } + return ctx; + } +} + +fn create_builtins() -> HashMap { + let mut result = HashMap::new(); + result.insert( + "i8".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "i8".to_string(), + }, + )), + ); + result.insert( + "i16".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "i16".to_string(), + }, + )), + ); + result.insert( + "i32".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "i32".to_string(), + }, + )), + ); + result.insert( + "i64".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "i64".to_string(), + }, + )), + ); + result.insert( + "isize".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "isize".to_string(), + }, + )), + ); + + result.insert( + "u8".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "u8".to_string(), + }, + )), + ); + result.insert( + "u16".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "u16".to_string(), + }, + )), + ); + result.insert( + "u32".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "u32".to_string(), + }, + )), + ); + result.insert( + "u64".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "u64".to_string(), + }, + )), + ); + result.insert( + "usize".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "usize".to_string(), + }, + )), + ); + + result.insert( + "f32".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "f32".to_string(), + }, + )), + ); + result.insert( + "f64".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "f64".to_string(), + }, + )), + ); + + result.insert( + "!".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "!".to_string(), + }, + )), + ); + result.insert( + "unit".to_string(), + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Primitive( + ast::PrimitiveTypeDeclaration { + name: "!".to_string(), + }, + )), + ); + + return result; +} + +pub enum ExpressionResult { + Value(Value), + Return(Value), +} + +pub struct TreeWalkInterpreter {} + +impl TreeWalkInterpreter { + pub fn with_module(self: &Self, module: &ast::Module) -> Value { + let mut ctx = Context::from_module(module); + + let main = match &ctx.environment["main"] { + NamedEntity::Variable(Value::Function(Function::User(func))) => func.clone(), + _ => panic!("main should be a user defined function"), + }; + + return self.with_function(&mut ctx, &main); + } + + fn with_function(self: &Self, ctx: &mut Context, function: &ast::Function) -> Value { + let result = self.with_block(ctx, &function.block); + return match result { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => r, + }; + } + + fn with_block(self: &Self, ctx: &mut Context, block: &ast::Block) -> ExpressionResult { + let mut last = ExpressionResult::Value(Value::Unit); + for statement in block.statements.iter() { + let result = self.with_statement(ctx, statement); + match result { + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + ExpressionResult::Value(r) => { + last = ExpressionResult::Value(r); + } + } + } + return last; + } + + fn with_statement( + self: &Self, + ctx: &mut Context, + statement: &ast::Statement, + ) -> ExpressionResult { + match statement { + ast::Statement::Return(return_statement) => { + let result = match self.with_expression(ctx, &return_statement.source) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + return ExpressionResult::Return(result); + } + ast::Statement::Let(let_statement) => { + let result = match self.with_expression(ctx, &let_statement.expression) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + ctx.set_variable(let_statement.variable_name.name.value.to_string(), &result); + return ExpressionResult::Value(Value::Unit); + } + ast::Statement::Assignment(assignment_statement) => { + return self.with_assignment_statement(ctx, assignment_statement); + } + ast::Statement::Expression(expression) => { + return self.with_expression(ctx, expression); + } + } + } + + fn with_assignment_statement( + self: &Self, + ctx: &mut Context, + statement: &ast::AssignmentStatement, + ) -> ExpressionResult { + let result = match self.with_expression(ctx, &statement.expression) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + match &statement.source { + ast::AssignmentTarget::Variable(variable) => { + ctx.set_variable(variable.name.name.value.to_string(), &result); + } + ast::AssignmentTarget::StructAttr(struct_attr) => { + let mut source = match self.with_expression(ctx, &struct_attr.source) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + match &mut source { + Value::Struct(s) => { + let mut struct_ = s.lock().unwrap(); + struct_ + .fields + .insert(struct_attr.attribute.name.value.clone(), result); + } + _ => panic!("set attr on nonstruct, should never happen due to type system"), + } + } + } + return ExpressionResult::Value(Value::Unit); + } + + fn with_expression( + self: &Self, + ctx: &mut Context, + expression: &ast::Expression, + ) -> ExpressionResult { + match &*expression.subexpression { + ast::Subexpression::LiteralInt(literal_int) => { + let value: i64 = literal_int.value.value.parse().unwrap(); + return ExpressionResult::Value(Value::Numeric(NumericValue::I64(value))); + } + ast::Subexpression::LiteralFloat(literal_float) => { + let value: f64 = literal_float.value.value.parse().unwrap(); + return ExpressionResult::Value(Value::Numeric(NumericValue::F64(value))); + } + ast::Subexpression::LiteralStruct(literal_struct) => { + let declaration = match &ctx.environment[&literal_struct.name.name.value] { + NamedEntity::TypeDeclaration(ast::TypeDeclaration::Struct(declaration)) => { + declaration.clone() + } + _ => panic!("not a struct"), + }; + + let mut fields = HashMap::new(); + for field in declaration.fields.iter() { + for (field_name, field_expression) in literal_struct.fields.iter() { + if field.name.name.value == field_name.name.value { + let field_result = match self.with_expression(ctx, field_expression) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + fields.insert(field.name.name.value.to_string(), field_result); + } + } + } + return ExpressionResult::Value(Value::Struct(Arc::new(Mutex::new(StructValue { + source: declaration.clone(), + fields: fields, + })))); + } + ast::Subexpression::FunctionCall(function_call) => { + let source = match self.with_expression(ctx, &function_call.source) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + let mut argument_values = vec![]; + for arg in function_call.arguments.iter() { + let argument_value = match self.with_expression(ctx, arg) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + argument_values.push(argument_value); + } + match &source { + Value::Function(Function::User(user_function)) => { + let mut fn_ctx = ctx.new_env(); + for (i, function_arg) in + user_function.declaration.arguments.iter().enumerate() + { + fn_ctx.set_variable( + function_arg.name.name.value.to_string(), + &argument_values[i].clone(), + ); + } + return ExpressionResult::Value( + self.with_function(&mut fn_ctx, user_function), + ); + } + Value::Function(Function::Builtin(builtin_function)) => { + return ExpressionResult::Value(builtin_function(argument_values)); + } + _ => panic!("type error: function call source must be a function"), + } + } + ast::Subexpression::VariableUsage(variable_usage) => { + let variable_value = match &ctx.environment[&variable_usage.name.name.value] { + NamedEntity::Variable(v) => v.clone(), + _ => panic!("variable lookup of type"), + }; + return ExpressionResult::Value(variable_value); + } + ast::Subexpression::StructGetter(struct_getter) => { + let source = match self.with_expression(ctx, &struct_getter.source) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + match &source { + Value::Struct(struct_) => { + let s = struct_.lock().unwrap(); + return ExpressionResult::Value( + s.fields[&struct_getter.attribute.name.value].clone(), + ); + } + _ => { + panic!("TypeError: struct getter used with non-struct"); + } + } + } + ast::Subexpression::Block(block) => { + return self.with_block(ctx, block); + } + ast::Subexpression::Op(op) => { + let left = match self.with_expression(ctx, &op.left) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + let right = match self.with_expression(ctx, &op.right) { + ExpressionResult::Value(r) => r, + ExpressionResult::Return(r) => { + return ExpressionResult::Return(r); + } + }; + let result = match (&left, &op.op, &right) { + //I + ( + Value::Numeric(NumericValue::I8(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::I8(r)), + ) => Value::Numeric(NumericValue::I8(l + r)), + ( + Value::Numeric(NumericValue::I8(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::I8(r)), + ) => Value::Numeric(NumericValue::I8(l - r)), + ( + Value::Numeric(NumericValue::I8(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::I8(r)), + ) => Value::Numeric(NumericValue::I8(l * r)), + ( + Value::Numeric(NumericValue::I8(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::I8(r)), + ) => Value::Numeric(NumericValue::I8(l / r)), + + ( + Value::Numeric(NumericValue::I16(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::I16(r)), + ) => Value::Numeric(NumericValue::I16(l + r)), + ( + Value::Numeric(NumericValue::I16(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::I16(r)), + ) => Value::Numeric(NumericValue::I16(l - r)), + ( + Value::Numeric(NumericValue::I16(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::I16(r)), + ) => Value::Numeric(NumericValue::I16(l * r)), + ( + Value::Numeric(NumericValue::I16(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::I16(r)), + ) => Value::Numeric(NumericValue::I16(l / r)), + + ( + Value::Numeric(NumericValue::I32(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::I32(r)), + ) => Value::Numeric(NumericValue::I32(l + r)), + ( + Value::Numeric(NumericValue::I32(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::I32(r)), + ) => Value::Numeric(NumericValue::I32(l - r)), + ( + Value::Numeric(NumericValue::I32(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::I32(r)), + ) => Value::Numeric(NumericValue::I32(l * r)), + ( + Value::Numeric(NumericValue::I32(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::I32(r)), + ) => Value::Numeric(NumericValue::I32(l / r)), + + ( + Value::Numeric(NumericValue::I64(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::I64(r)), + ) => Value::Numeric(NumericValue::I64(l + r)), + ( + Value::Numeric(NumericValue::I64(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::I64(r)), + ) => Value::Numeric(NumericValue::I64(l - r)), + ( + Value::Numeric(NumericValue::I64(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::I64(r)), + ) => Value::Numeric(NumericValue::I64(l * r)), + ( + Value::Numeric(NumericValue::I64(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::I64(r)), + ) => Value::Numeric(NumericValue::I64(l / r)), + + //U + ( + Value::Numeric(NumericValue::U8(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::U8(r)), + ) => Value::Numeric(NumericValue::U8(l + r)), + ( + Value::Numeric(NumericValue::U8(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::U8(r)), + ) => Value::Numeric(NumericValue::U8(l - r)), + ( + Value::Numeric(NumericValue::U8(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::U8(r)), + ) => Value::Numeric(NumericValue::U8(l * r)), + ( + Value::Numeric(NumericValue::U8(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::U8(r)), + ) => Value::Numeric(NumericValue::U8(l / r)), + + ( + Value::Numeric(NumericValue::U16(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::U16(r)), + ) => Value::Numeric(NumericValue::U16(l + r)), + ( + Value::Numeric(NumericValue::U16(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::U16(r)), + ) => Value::Numeric(NumericValue::U16(l - r)), + ( + Value::Numeric(NumericValue::U16(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::U16(r)), + ) => Value::Numeric(NumericValue::U16(l * r)), + ( + Value::Numeric(NumericValue::U16(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::U16(r)), + ) => Value::Numeric(NumericValue::U16(l / r)), + + ( + Value::Numeric(NumericValue::U32(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::U32(r)), + ) => Value::Numeric(NumericValue::U32(l + r)), + ( + Value::Numeric(NumericValue::U32(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::U32(r)), + ) => Value::Numeric(NumericValue::U32(l - r)), + ( + Value::Numeric(NumericValue::U32(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::U32(r)), + ) => Value::Numeric(NumericValue::U32(l * r)), + ( + Value::Numeric(NumericValue::U32(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::U32(r)), + ) => Value::Numeric(NumericValue::U32(l / r)), + + ( + Value::Numeric(NumericValue::U64(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::U64(r)), + ) => Value::Numeric(NumericValue::U64(l + r)), + ( + Value::Numeric(NumericValue::U64(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::U64(r)), + ) => Value::Numeric(NumericValue::U64(l - r)), + ( + Value::Numeric(NumericValue::U64(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::U64(r)), + ) => Value::Numeric(NumericValue::U64(l * r)), + ( + Value::Numeric(NumericValue::U64(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::U64(r)), + ) => Value::Numeric(NumericValue::U64(l / r)), + + //F + ( + Value::Numeric(NumericValue::F32(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::F32(r)), + ) => Value::Numeric(NumericValue::F32(l + r)), + ( + Value::Numeric(NumericValue::F32(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::F32(r)), + ) => Value::Numeric(NumericValue::F32(l - r)), + ( + Value::Numeric(NumericValue::F32(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::F32(r)), + ) => Value::Numeric(NumericValue::F32(l * r)), + ( + Value::Numeric(NumericValue::F32(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::F32(r)), + ) => Value::Numeric(NumericValue::F32(l / r)), + + ( + Value::Numeric(NumericValue::F64(l)), + ast::Operator::Plus, + Value::Numeric(NumericValue::F64(r)), + ) => Value::Numeric(NumericValue::F64(l + r)), + ( + Value::Numeric(NumericValue::F64(l)), + ast::Operator::Minus, + Value::Numeric(NumericValue::F64(r)), + ) => Value::Numeric(NumericValue::F64(l - r)), + ( + Value::Numeric(NumericValue::F64(l)), + ast::Operator::Mul, + Value::Numeric(NumericValue::F64(r)), + ) => Value::Numeric(NumericValue::F64(l * r)), + ( + Value::Numeric(NumericValue::F64(l)), + ast::Operator::Div, + Value::Numeric(NumericValue::F64(r)), + ) => Value::Numeric(NumericValue::F64(l / r)), + + //fail + _ => panic!(""), + }; + return ExpressionResult::Value(result); + } + } + } +} diff --git a/src/main.rs b/src/main.rs index ec05ba9..9092c7e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ // mod types; mod ast; mod errors; +mod interpreter; mod type_alias_resolution; mod type_checking; #[macro_use] @@ -64,6 +65,9 @@ fn main() { Ok((checked_ast, subst)) => { println!("checked ast: {:#?}", &checked_ast); println!("substitutions: {:#?}", &subst); + let interpreter = interpreter::TreeWalkInterpreter {}; + let result = interpreter.with_module(&checked_ast); + println!("final result: {:#?}", &result); } Err(err) => { println!("type checking error: {:#?}", &err);