From b2709ffc8203fa773f1715fbbc7d825fe60e2333 Mon Sep 17 00:00:00 2001 From: Andrew Segavac Date: Mon, 25 Aug 2025 21:51:50 -0600 Subject: [PATCH] got type system working --- examples/strings.bl | 11 +- packages/boringlang/src/commands/run.ts | 11 +- packages/boringlang/src/parse/ast.ts | 73 ++++ packages/boringlang/src/parse/semantics.ts | 35 +- packages/boringlang/src/types/builtins.ts | 21 + packages/boringlang/src/types/context.ts | 86 ++++ .../boringlang/src/types/trait_checker.ts | 10 +- .../src/types/type_alias_resolution.ts | 238 ++++++++++ packages/boringlang/src/types/type_checker.ts | 411 +++++++++++++++--- packages/boringlang/src/types/type_system.ts | 53 ++- 10 files changed, 858 insertions(+), 91 deletions(-) create mode 100644 packages/boringlang/src/types/builtins.ts create mode 100644 packages/boringlang/src/types/context.ts create mode 100644 packages/boringlang/src/types/type_alias_resolution.ts diff --git a/examples/strings.bl b/examples/strings.bl index 61e863c..0397bda 100644 --- a/examples/strings.bl +++ b/examples/strings.bl @@ -1,11 +1,6 @@ fn main(): String { - let a = 2; - a; - a = 3; - a = if(true) {"asdf"} else {"fdsa"}; - a.b.c.d(); - a = (b + c.d()); // comment - return a.b() ; + let a = "asdf"; + return a ; } type User struct { @@ -23,7 +18,7 @@ impl TestTrait for User { return Self{id: id}; } fn instance_method(self: Self): i64 { - return self.get_id(); + return self.id; } fn default_impl(self: Self): i64 { return self.instance_method(); diff --git a/packages/boringlang/src/commands/run.ts b/packages/boringlang/src/commands/run.ts index 3c62900..8907841 100644 --- a/packages/boringlang/src/commands/run.ts +++ b/packages/boringlang/src/commands/run.ts @@ -2,6 +2,9 @@ import { defineCommand } from "@bunli/core"; import { boringGrammar } from "../parse/grammar"; import { semantics } from "../parse/semantics"; import TraitChecker from "../types/trait_checker"; +import { TypeAliasResolver } from "../types/type_alias_resolution"; +import { TypeSystem } from "../types/type_system"; +import { TypeChecker } from "../types/type_checker"; export const run = defineCommand({ name: "run", @@ -22,8 +25,14 @@ export const run = defineCommand({ const adapter = semantics(match); const ast = adapter.toAST(); new TraitChecker().withModule(ast); + const aliasResolvedAst = new TypeAliasResolver().withModule(ast); + const typeSystem = new TypeSystem(); + const typeChecker = new TypeChecker(); + typeChecker.withModule(aliasResolvedAst, typeSystem); + typeSystem.solve(); + console.log(JSON.stringify(typeSystem, null, 2)); - console.log(JSON.stringify(ast, null, 2)); + // console.log(JSON.stringify(aliasResolvedAst, null, 2)); } else { console.log(match.message); // console.log(boringGrammar.trace(text, "Module").toString()); diff --git a/packages/boringlang/src/parse/ast.ts b/packages/boringlang/src/parse/ast.ts index a8724ed..a9beb40 100644 --- a/packages/boringlang/src/parse/ast.ts +++ b/packages/boringlang/src/parse/ast.ts @@ -62,6 +62,7 @@ export interface Operation { left: Expression; op: "+" | "-" | "*" | "/"; right: Expression; + type: TypeUsage; } export interface VariableUsage { @@ -116,6 +117,7 @@ export interface AssignmentStatement { export type Statement = ReturnStatement | LetStatement | AssignmentStatement | Expression; export interface Block { + expressionType: "Block"; statements: Statement[]; type: TypeUsage; } @@ -196,3 +198,74 @@ export interface UnknownTypeUsage { } export type TypeUsage = NamedTypeUsage | FunctionTypeUsage | UnknownTypeUsage; + +export const newVoid: () => TypeUsage = () => { + return { + typeUsage: "NamedTypeUsage", + name: { text: "Void", spanStart: 0, spanEnd: 0 }, + }; +}; + +export const newNever: () => TypeUsage = () => { + return { + typeUsage: "NamedTypeUsage", + name: { text: "Never", spanStart: 0, spanEnd: 0 }, + }; +}; + +function containsReturnExpression(expression: Expression) { + if (expression.subExpression.expressionType === "LiteralStruct") { + for (const field of expression.subExpression.fields) { + if (containsReturnExpression(field.expression)) { + return true; + } + } + } + if (expression.subExpression.expressionType === "IfExpression") { + if (containsReturn(expression.subExpression.block)) { + return true; + } + if (expression.subExpression.else && containsReturn(expression.subExpression.else)) { + return true; + } + } + if (expression.subExpression.expressionType === "Block") { + if (containsReturn(expression.subExpression)) { + return true; + } + } + if (expression.subExpression.expressionType === "Operation") { + if (containsReturnExpression(expression.subExpression.left)) { + return true; + } + if (containsReturnExpression(expression.subExpression.right)) { + return true; + } + } + + return false; +} + +export function containsReturn(block: Block) { + for (const statement of block.statements) { + if (statement.statementType === "ReturnStatement") { + return true; + } + if (statement.statementType === "AssignmentStatement") { + if (containsReturnExpression(statement.expression)) { + return true; + } + } + if (statement.statementType === "LetStatement") { + if (containsReturnExpression(statement.expression)) { + return true; + } + } + if (statement.statementType === "Expression") { + if (containsReturnExpression(statement)) { + return true; + } + } + } + return false; +} diff --git a/packages/boringlang/src/parse/semantics.ts b/packages/boringlang/src/parse/semantics.ts index 79d4a13..42987b6 100644 --- a/packages/boringlang/src/parse/semantics.ts +++ b/packages/boringlang/src/parse/semantics.ts @@ -1,6 +1,7 @@ import { AssignmentStatement, Block, + containsReturn, Expression, Function, FunctionArgument, @@ -19,6 +20,7 @@ import { Module, ModuleItem, NamedTypeUsage, + newNever, Operation, ReturnStatement, Statement, @@ -43,8 +45,6 @@ function nextUnknown() { export const semantics = boringGrammar.createSemantics(); semantics.addOperation("toAST", { LiteralInt(a): LiteralInt { - console.log(this); - console.log(a.source.startIdx); return { expressionType: "LiteralInt", value: this.sourceString, @@ -106,11 +106,15 @@ semantics.addOperation("toAST", { }; }, FunctionCall(expression, _2, args, _4): FunctionCall { + const resolvedArgs = args.asIteration().children.map((c) => c.toAST()); return { expressionType: "FunctionCall", source: expression.toAST(), - arguments: args.asIteration().children.map((c) => c.toAST()), - type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, + arguments: resolvedArgs, + type: { + typeUsage: "UnknownTypeUsage", + name: nextUnknown(), + }, }; }, StructGetter(expression, _2, identifier): StructGetter { @@ -148,7 +152,11 @@ semantics.addOperation("toAST", { return factor.toAST(); }, Expression(expression): Expression { - return expression.toAST(); + return { + statementType: "Expression", + subExpression: expression.toAST(), + type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, + }; }, Expression_plus(expression, _2, factor): Operation { return { @@ -156,6 +164,7 @@ semantics.addOperation("toAST", { left: expression.toAST(), op: "+", right: factor.toAST(), + type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, }; }, Expression_minus(expression, _2, factor): Operation { @@ -164,6 +173,7 @@ semantics.addOperation("toAST", { left: expression.toAST(), op: "-", right: factor.toAST(), + type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, }; }, Factor_mult(factor, _2, term): Operation { @@ -172,6 +182,7 @@ semantics.addOperation("toAST", { left: factor.toAST(), op: "*", right: term.toAST(), + type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, }; }, Factor_div(factor, _2, term): Operation { @@ -180,6 +191,7 @@ semantics.addOperation("toAST", { left: factor.toAST(), op: "/", right: term.toAST(), + type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, }; }, Statement(statement): Statement { @@ -217,11 +229,18 @@ semantics.addOperation("toAST", { Block(_1, statements, expression, _4): Block { const lines = statements.asIteration().children.map((c) => c.toAST()); const finalExpression = expression.toAST(); - lines.push(finalExpression.length > 0 ? finalExpression[0] : null); - return { + if (finalExpression.length > 0) { + lines.push(finalExpression[0]); + } + const block: Block = { + expressionType: "Block", statements: lines, - type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, + type: newNever(), }; + if (!containsReturn(block)) { + block.type = { typeUsage: "UnknownTypeUsage", name: nextUnknown() }; + } + return block; }, NamedTypeUsage(name): NamedTypeUsage { return { diff --git a/packages/boringlang/src/types/builtins.ts b/packages/boringlang/src/types/builtins.ts new file mode 100644 index 0000000..2a2133e --- /dev/null +++ b/packages/boringlang/src/types/builtins.ts @@ -0,0 +1,21 @@ +import { Context } from "./context"; + +export function newContext(): Context { + const result: Context = { + currentFunctionReturn: null, + environment: {}, + }; + result.environment["i8"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["i16"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["i32"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["i64"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["f8"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["f16"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["f32"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["f64"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["String"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["Void"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + result.environment["Never"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] }; + + return result; +} diff --git a/packages/boringlang/src/types/context.ts b/packages/boringlang/src/types/context.ts new file mode 100644 index 0000000..fe3a15f --- /dev/null +++ b/packages/boringlang/src/types/context.ts @@ -0,0 +1,86 @@ +import { NamedTypeUsage, TypeUsage } from "../parse/ast"; + +interface EnvImpl { + trait: string | null; + functions: Record; +} + +interface NamedType { + namedEntity: "NamedType"; + isA: "Scalar" | "Trait" | "Struct"; + fields: Record; + impls: EnvImpl[]; +} + +interface Variable { + namedEntity: "Variable"; + type: TypeUsage; +} + +type NamedEntity = NamedType | Variable; + +export interface Context { + currentFunctionReturn: TypeUsage | null; + environment: Record; +} + +export function getAttr(ctx: Context, name: string, field: string) { + const struct = ctx.environment[name]; + if (!struct || struct.namedEntity !== "NamedType") { + throw Error(`Unknown type ${name}`); + } + if (struct.fields[field]) { + return struct.fields[field]; + } + let results: TypeUsage[] = []; + for (const impl of struct.impls) { + if (impl.functions[field]) { + results.push(impl.functions[field]); + } + } + if (results.length === 0) { + console.log(JSON.stringify(struct, null, 2)); + throw Error(`${name} has no attribue ${field}`); + } + if (results.length > 1) { + throw Error(`${name} has multiple attribues ${field}, use universal function call syntax`); + } + return results[0]; +} + +export function typeExists(ctx: Context, type: TypeUsage) { + if (type.typeUsage === "NamedTypeUsage") { + if ( + !ctx.environment[type.name.text] || + ctx.environment[type.name.text].namedEntity !== "NamedType" + ) { + throw Error(`${type.name.text} is not a type.`); + } + } + if (type.typeUsage === "FunctionTypeUsage") { + for (const arg of type.arguments) { + typeExists(ctx, arg); + } + typeExists(ctx, type.returnType); + } +} + +export function replaceType(oldName: string, newType: TypeUsage, inType: TypeUsage) { + if (inType.typeUsage === "NamedTypeUsage") { + if (inType.name.text === oldName) { + return deepCopy(newType); + } + } + if (inType.typeUsage === "FunctionTypeUsage") { + const result = deepCopy(inType); + for (const [i, arg] of inType.arguments.entries()) { + result.arguments[i] = replaceType(oldName, newType, arg); + } + result.returnType = replaceType(oldName, newType, inType.returnType); + } + return deepCopy(inType); +} + +export const deepCopy = (o: T) => { + return JSON.parse(JSON.stringify(o)) as T; +}; diff --git a/packages/boringlang/src/types/trait_checker.ts b/packages/boringlang/src/types/trait_checker.ts index 7692d04..ad5d1aa 100644 --- a/packages/boringlang/src/types/trait_checker.ts +++ b/packages/boringlang/src/types/trait_checker.ts @@ -10,7 +10,7 @@ export default class TraitChecker { let ctx: Context = { environment: {} }; for (const item of module.items) { if (item.moduleItem == "TraitTypeDeclaration") { - ctx.environment[item.name.name] = item; + ctx.environment[item.name.text] = item; } } for (const item of module.items) { @@ -21,20 +21,20 @@ export default class TraitChecker { }; withImpl = (ctx: Context, impl: Impl) => { if (new Set(impl.functions.map((fn) => fn.declaration.name)).size !== impl.functions.length) { - throw Error(`Duplicate functions in ${impl.struct.name.name}`); + throw Error(`Duplicate functions in ${impl.struct.name.text}`); } if (impl.trait == null) { return; } - const trait = ctx.environment[impl.trait.name.name]; + const trait = ctx.environment[impl.trait.name.text]; if (!trait) { throw Error(`No such trait: ${impl.trait.name}`); } if (impl.functions.length !== trait.functions.length) { - throw Error(`Mismatched impl/trait len ${impl.trait.name.name} for ${impl.struct.name.name}`); + throw Error(`Mismatched impl/trait len ${impl.trait.name.text} for ${impl.struct.name.text}`); } for (let i = 0; i < impl.functions.length; i++) { - if (impl.functions[i].declaration.name.name !== trait.functions[i].name.name) { + if (impl.functions[i].declaration.name.text !== trait.functions[i].name.text) { throw Error( `Mismatched impl/trait names ${impl.functions[i].declaration.name} for ${trait.functions[i].name}`, ); diff --git a/packages/boringlang/src/types/type_alias_resolution.ts b/packages/boringlang/src/types/type_alias_resolution.ts new file mode 100644 index 0000000..8a66905 --- /dev/null +++ b/packages/boringlang/src/types/type_alias_resolution.ts @@ -0,0 +1,238 @@ +import { + AssignmentStatement, + Block, + Expression, + Function, + FunctionCall, + FunctionDeclaration, + IfExpression, + Impl, + LetStatement, + LiteralStruct, + Module, + Operation, + ReturnStatement, + StructGetter, + StructTypeDeclaration, + TraitTypeDeclaration, + TypeUsage, + VariableUsage, +} from "../parse/ast"; +import { deepCopy, replaceType } from "./context"; + +interface AliasContext { + environment: Record; +} + +export class TypeAliasResolver { + withModule = (module: Module) => { + const ctx: AliasContext = { environment: {} }; + + const result = deepCopy(module); + for (const [i, item] of module.items.entries()) { + if (item.moduleItem === "TraitTypeDeclaration") { + let traitCtx = deepCopy(ctx); + traitCtx.environment["Self"] = { typeUsage: "NamedTypeUsage", name: item.name }; + result.items[i] = this.withTraitTypeDeclaration(traitCtx, item); + } + if (item.moduleItem === "Impl") { + let implCtx = deepCopy(ctx); + implCtx.environment["Self"] = { typeUsage: "NamedTypeUsage", name: item.struct.name }; + result.items[i] = this.withImpl(implCtx, item); + } + if (item.moduleItem === "StructTypeDeclaration") { + let structCtx = deepCopy(ctx); + structCtx.environment["Self"] = { typeUsage: "NamedTypeUsage", name: item.name }; + result.items[i] = this.withStructDeclaration(structCtx, item); + } + } + return result; + }; + + withTraitTypeDeclaration = (ctx: AliasContext, trait: TraitTypeDeclaration) => { + const result = deepCopy(trait); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + for (const [i, fn] of trait.functions.entries()) { + result.functions[i] = this.withFunctionDeclaration(ctx, fn); + } + } + return result; + }; + + withImpl = (ctx: AliasContext, impl: Impl) => { + const result = deepCopy(impl); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + for (const [i, fn] of impl.functions.entries()) { + result.functions[i] = this.withFunction(ctx, fn); + } + } + return result; + }; + + withStructDeclaration = (ctx: AliasContext, struct: StructTypeDeclaration) => { + const result = deepCopy(struct); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + for (const [i, field] of struct.fields.entries()) { + result.fields[i].type = replaceType(oldName, newType, field.type); + } + } + return result; + }; + + withFunctionDeclaration = (ctx: AliasContext, fn: FunctionDeclaration) => { + const result = deepCopy(fn); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + for (const [i, arg] of fn.arguments.entries()) { + result.arguments[i].type = replaceType(oldName, newType, arg.type); + } + result.returnType = replaceType(oldName, newType, result.returnType); + } + return result; + }; + + withFunction = (ctx: AliasContext, fn: Function) => { + const result = deepCopy(fn); + result.declaration = this.withFunctionDeclaration(ctx, fn.declaration); + result.block = this.withBlock(ctx, fn.block); + return result; + }; + + withBlock = (ctx: AliasContext, block: Block) => { + const result = deepCopy(block); + for (const [i, statement] of block.statements.entries()) { + if (statement.statementType === "AssignmentStatement") { + result.statements[i] = this.withAssignmentStatement(ctx, statement); + } + if (statement.statementType === "LetStatement") { + result.statements[i] = this.withLetStatement(ctx, statement); + } + if (statement.statementType === "Expression") { + result.statements[i] = this.withExpression(ctx, statement); + } + if (statement.statementType === "ReturnStatement") { + result.statements[i] = this.withReturnStatement(ctx, statement); + } + for (const [oldName, newType] of Object.entries(ctx.environment)) { + result.type = replaceType(oldName, newType, block.type); + } + } + return result; + }; + + withAssignmentStatement = (ctx: AliasContext, statement: AssignmentStatement) => { + const result = deepCopy(statement); + if (statement.source.expressionType == "StructGetter") { + result.source = this.withStructGetter(ctx, statement.source); + } + if (statement.source.expressionType == "VariableUsage") { + result.source = this.withVariableUsage(ctx, statement.source); + } + result.expression = this.withExpression(ctx, statement.expression); + return result; + }; + + withLetStatement = (ctx: AliasContext, statement: LetStatement) => { + const result = deepCopy(statement); + result.expression = this.withExpression(ctx, statement.expression); + return result; + }; + + withReturnStatement = (ctx: AliasContext, statement: ReturnStatement) => { + const result = deepCopy(statement); + result.source = this.withExpression(ctx, statement.source); + return result; + }; + + withExpression = (ctx: AliasContext, expression: Expression) => { + const result = deepCopy(expression); + if (expression.subExpression.expressionType === "LiteralStruct") { + result.subExpression = this.withLiteralStruct(ctx, expression.subExpression); + } + if (expression.subExpression.expressionType === "FunctionCall") { + result.subExpression = this.withFunctionCall(ctx, expression.subExpression); + } + if (expression.subExpression.expressionType === "VariableUsage") { + result.subExpression = this.withVariableUsage(ctx, expression.subExpression); + } + if (expression.subExpression.expressionType === "IfExpression") { + result.subExpression = this.withIfExpression(ctx, expression.subExpression); + } + if (expression.subExpression.expressionType === "StructGetter") { + result.subExpression = this.withStructGetter(ctx, expression.subExpression); + } + if (expression.subExpression.expressionType === "Block") { + result.subExpression = this.withBlock(ctx, expression.subExpression); + } + if (expression.subExpression.expressionType === "Operation") { + result.subExpression = this.withOperation(ctx, expression.subExpression); + } + return result; + }; + + withLiteralStruct = (ctx: AliasContext, literalStruct: LiteralStruct) => { + const result = deepCopy(literalStruct); + for (const [i, field] of literalStruct.fields.entries()) { + result.fields[i].expression = this.withExpression(ctx, field.expression); + } + + for (const [oldName, newType] of Object.entries(ctx.environment)) { + result.type = replaceType(oldName, newType, literalStruct.type); + if (result.name.text === oldName && newType.typeUsage === "NamedTypeUsage") { + result.name = newType.name; + } + } + return result; + }; + + withFunctionCall = (ctx: AliasContext, fnCall: FunctionCall) => { + const result = deepCopy(fnCall); + for (const [i, arg] of fnCall.arguments.entries()) { + result.arguments[i] = this.withExpression(ctx, arg); + } + result.source = this.withExpression(ctx, fnCall.source); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + result.type = replaceType(oldName, newType, fnCall.type); + } + return result; + }; + + withVariableUsage = (ctx: AliasContext, variableUsage: VariableUsage) => { + const result = deepCopy(variableUsage); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + result.type = replaceType(oldName, newType, variableUsage.type); + } + return result; + }; + + withIfExpression = (ctx: AliasContext, ifExpression: IfExpression) => { + const result = deepCopy(ifExpression); + result.condition = this.withExpression(ctx, ifExpression.condition); + result.block = this.withBlock(ctx, ifExpression.block); + if (ifExpression.else) { + result.else = this.withBlock(ctx, ifExpression.else); + } + for (const [oldName, newType] of Object.entries(ctx.environment)) { + result.type = replaceType(oldName, newType, ifExpression.type); + } + return result; + }; + + withStructGetter = (ctx: AliasContext, structGetter: StructGetter) => { + const result = deepCopy(structGetter); + result.source = this.withExpression(ctx, structGetter.source); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + result.type = replaceType(oldName, newType, structGetter.type); + } + return result; + }; + + withOperation = (ctx: AliasContext, op: Operation) => { + const result = deepCopy(op); + result.left = this.withExpression(ctx, op.left); + result.right = this.withExpression(ctx, op.right); + for (const [oldName, newType] of Object.entries(ctx.environment)) { + result.type = replaceType(oldName, newType, op.type); + } + return result; + }; +} diff --git a/packages/boringlang/src/types/type_checker.ts b/packages/boringlang/src/types/type_checker.ts index d7b3972..2be5323 100644 --- a/packages/boringlang/src/types/type_checker.ts +++ b/packages/boringlang/src/types/type_checker.ts @@ -1,66 +1,33 @@ import { + AssignmentStatement, + Block, + containsReturn, + Expression, Function, + FunctionCall, FunctionDeclaration, functionToType, + IfExpression, Impl, + LetStatement, + LiteralStruct, Module, + newVoid, + Operation, + ReturnStatement, + StructGetter, StructTypeDeclaration, TraitTypeDeclaration, TypeUsage, + VariableUsage, } from "../parse/ast"; +import { newContext } from "./builtins"; +import { Context, deepCopy, typeExists } from "./context"; import { TypeSystem } from "./type_system"; -interface EnvImpl { - trait: string | null; - functions: Record; -} - -interface NamedType { - namedEntity: "NamedType"; - isA: "Scalar" | "Trait" | "Struct"; - fields: Record; - impls: EnvImpl[]; -} - -interface Variable { - namedEntity: "Variable"; - type: TypeUsage; -} - -type NamedEntity = NamedType | Variable; - -interface Context { - currentFunctionReturn: TypeUsage | null; - environment: Record; -} - -function typeExists(ctx: Context, type: TypeUsage) { - if (type.typeUsage === "NamedTypeUsage") { - if ( - !ctx.environment[type.name.text] || - ctx.environment[type.name.text].namedEntity !== "NamedType" - ) { - throw Error(`${type.name.text} is not a type.`); - } - } - if (type.typeUsage === "FunctionTypeUsage") { - for (const arg of type.arguments) { - typeExists(ctx, arg); - } - typeExists(ctx, type.returnType); - } -} - -const deepCopy = (o: T) => { - return JSON.parse(JSON.stringify(o)) as T; -}; - -class TypeChecker { +export class TypeChecker { withModule = (module: Module, typeSystem: TypeSystem) => { - const ctx: Context = { - currentFunctionReturn: null, - environment: {}, - }; + const ctx: Context = newContext(); // add functions, structs, and traits to the context for (const item of module.items) { if (item.moduleItem === "StructTypeDeclaration") { @@ -70,7 +37,7 @@ class TypeChecker { ctx.environment[item.name.text] = { namedEntity: "NamedType", isA: "Struct", - fields: Object.fromEntries(item.fields.map((field) => [field.name, field.type])), + fields: Object.fromEntries(item.fields.map((field) => [field.name.text, field.type])), impls: [], }; } @@ -78,6 +45,18 @@ class TypeChecker { if (ctx.environment[item.name.text]) { throw Error("Duplicate name of trait"); } + const functions: Record = {}; + for (const fn of item.functions) { + if ( + fn.arguments.length && + fn.arguments[0].type.typeUsage == "NamedTypeUsage" && + fn.arguments[0].type.name.text === item.name.text + ) { + const fnCopy = deepCopy(fn); + fnCopy.arguments = fnCopy.arguments.slice(1); + functions[fn.name.text] = functionToType(fnCopy); + } + } ctx.environment[item.name.text] = { namedEntity: "NamedType", isA: "Trait", @@ -85,11 +64,7 @@ class TypeChecker { impls: [ { trait: item.name.text, - functions: Object.fromEntries( - item.functions.map((fn) => { - return [fn.name, functionToType(fn)]; - }), - ), + functions: functions, }, ], }; @@ -111,16 +86,25 @@ class TypeChecker { if (!struct || struct.namedEntity !== "NamedType" || struct.isA !== "Struct") { throw Error("Impl for non-struct"); } + const functions: Record = {}; + for (const fn of item.functions) { + if ( + fn.declaration.arguments.length && + fn.declaration.arguments[0].type.typeUsage == "NamedTypeUsage" && + fn.declaration.arguments[0].type.name.text === item.struct.name.text + ) { + const fnCopy = deepCopy(fn.declaration); + fnCopy.arguments = fnCopy.arguments.slice(1); + functions[fn.declaration.name.text] = functionToType(fnCopy); + } + } struct.impls.push({ trait: item.trait?.name.text ?? null, - functions: Object.fromEntries( - item.functions.map((fn) => { - return [fn.declaration.name, functionToType(fn.declaration)]; - }), - ), + functions: functions, }); } } + typeSystem.context = deepCopy(ctx); // environment set up, actually recurse. for (const item of module.items) { if (item.moduleItem === "Function") { @@ -130,7 +114,7 @@ class TypeChecker { this.withImpl(ctx, item, typeSystem); } if (item.moduleItem === "StructTypeDeclaration") { - this.withStruct(ctx, item, typeSystem); + this.withStructDeclaration(ctx, item, typeSystem); } if (item.moduleItem === "TraitTypeDeclaration") { this.withTrait(ctx, item, typeSystem); @@ -141,13 +125,14 @@ class TypeChecker { withFunction = (ctx: Context, fn: Function, typeSystem: TypeSystem) => { this.withFunctionDeclaration(ctx, fn.declaration, typeSystem); const fnCtx = deepCopy(ctx); + fnCtx.currentFunctionReturn = fn.declaration.returnType; for (const arg of fn.declaration.arguments) { fnCtx.environment[arg.name.text] = { namedEntity: "Variable", type: arg.type, }; } - // this.withBlock(fnCtx, fn.block, typeSystem); + this.withBlock(fnCtx, fn.block, typeSystem); typeSystem.compare({ left: fn.declaration.returnType, operation: { operation: "equals" }, @@ -162,8 +147,302 @@ class TypeChecker { typeExists(ctx, def.returnType); }; - withImpl = (ctx: Context, impl: Impl, typeSystem: TypeSystem) => {}; + withImpl = (ctx: Context, impl: Impl, typeSystem: TypeSystem) => { + for (const fn of impl.functions) { + this.withFunction(ctx, fn, typeSystem); + } + }; - withStruct = (ctx: Context, struct: StructTypeDeclaration, typeSystem: TypeSystem) => {}; - withTrait = (ctx: Context, trait: TraitTypeDeclaration, typeSystem: TypeSystem) => {}; + withStructDeclaration = (ctx: Context, struct: StructTypeDeclaration, typeSystem: TypeSystem) => { + for (const field of struct.fields) { + typeExists(ctx, field.type); + } + }; + + withTrait = (ctx: Context, trait: TraitTypeDeclaration, typeSystem: TypeSystem) => { + for (const method of trait.functions) { + this.withFunctionDeclaration(ctx, method, typeSystem); + } + }; + + withBlock = (ctx: Context, block: Block, typeSystem: TypeSystem) => { + const blockCtx = deepCopy(ctx); + for (const statement of block.statements) { + if (statement.statementType === "AssignmentStatement") { + this.withAssignmentStatement(blockCtx, statement, typeSystem); + } + if (statement.statementType === "LetStatement") { + this.withLetStatement(blockCtx, statement, typeSystem); + blockCtx.environment[statement.variableName.text] = { + namedEntity: "Variable", + type: statement.type, + }; + } + if (statement.statementType === "Expression") { + this.withExpression(blockCtx, statement, typeSystem); + } + if (statement.statementType === "ReturnStatement") { + this.withReturnStatement(blockCtx, statement, typeSystem); + } + if (!containsReturn(block)) { + const lastStatement = block.statements[block.statements.length - 1] ?? null; + if (lastStatement && lastStatement.statementType == "Expression") { + typeSystem.compare({ + left: block.type, + operation: { operation: "equals" }, + right: lastStatement.type, + }); + } else { + typeSystem.compare({ + left: block.type, + operation: { operation: "equals" }, + right: newVoid(), + }); + } + } + } + }; + + withAssignmentStatement = ( + ctx: Context, + statement: AssignmentStatement, + typeSystem: TypeSystem, + ) => { + if (statement.source.expressionType == "StructGetter") { + this.withStructGetter(ctx, statement.source, typeSystem); + } + if (statement.source.expressionType == "VariableUsage") { + this.withVariableUsage(ctx, statement.source, typeSystem); + } + this.withExpression(ctx, statement.expression, typeSystem); + typeSystem.compare({ + left: statement.source.type, + operation: { operation: "equals" }, + right: statement.expression.type, + }); + }; + + withLetStatement = (ctx: Context, statement: LetStatement, typeSystem: TypeSystem) => { + this.withExpression(ctx, statement.expression, typeSystem); + typeSystem.compare({ + left: statement.type, + operation: { operation: "equals" }, + right: statement.expression.type, + }); + }; + + withReturnStatement = (ctx: Context, statement: ReturnStatement, typeSystem: TypeSystem) => { + this.withExpression(ctx, statement.source, typeSystem); + if (ctx.currentFunctionReturn) { + typeSystem.compare({ + left: ctx.currentFunctionReturn, + operation: { operation: "equals" }, + right: statement.source.type, + }); + } + }; + + withExpression = (ctx: Context, expression: Expression, typeSystem: TypeSystem) => { + if (expression.subExpression.expressionType === "LiteralInt") { + // LiteralInt always has type + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "LiteralFloat") { + // LiteralFloat always has type + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "LiteralString") { + // LiteralString always has type + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "LiteralBool") { + // LiteralBool always has type + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "LiteralStruct") { + this.withLiteralStruct(ctx, expression.subExpression, typeSystem); + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "FunctionCall") { + this.withFunctionCall(ctx, expression.subExpression, typeSystem); + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "VariableUsage") { + this.withVariableUsage(ctx, expression.subExpression, typeSystem); + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "IfExpression") { + this.withIfExpression(ctx, expression.subExpression, typeSystem); + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "StructGetter") { + this.withStructGetter(ctx, expression.subExpression, typeSystem); + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "Block") { + this.withBlock(ctx, expression.subExpression, typeSystem); + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + if (expression.subExpression.expressionType === "Operation") { + this.withOperation(ctx, expression.subExpression, typeSystem); + typeSystem.compare({ + left: expression.type, + operation: { operation: "equals" }, + right: expression.subExpression.type, + }); + } + }; + + withLiteralStruct = (ctx: Context, literal: LiteralStruct, typeSystem: TypeSystem) => { + const definition = ctx.environment[literal.name.text]; + if (!definition || definition.namedEntity !== "NamedType" || !(definition.isA === "Struct")) { + throw new Error(`${literal.name.text} not found.`); + } + if (Object.keys(definition.fields).length !== literal.fields.length) { + throw new Error(`${literal.name.text} has mismatched fields.`); + } + if (new Set(Object.keys(definition.fields)).size !== literal.fields.length) { + throw new Error(`${literal.name.text} has repeated fields.`); + } + for (const field of literal.fields) { + const definitionField = definition.fields[field.name.text]; + if (!definitionField) throw new Error(`Unknown field ${field.name.text}`); + this.withExpression(ctx, field.expression, typeSystem); + typeSystem.compare({ + left: definitionField, + operation: { operation: "equals" }, + right: field.expression.type, + }); + } + typeSystem.compare({ + left: { typeUsage: "NamedTypeUsage", name: literal.name }, + operation: { operation: "equals" }, + right: literal.type, + }); + }; + + withFunctionCall = (ctx: Context, fnCall: FunctionCall, typeSystem: TypeSystem) => { + this.withExpression(ctx, fnCall.source, typeSystem); + typeSystem.compare({ + left: fnCall.source.type, + operation: { operation: "return" }, + right: fnCall.type, + }); + typeSystem.compare({ + left: fnCall.source.type, + operation: { operation: "equals" }, + right: { + typeUsage: "FunctionTypeUsage", + arguments: fnCall.arguments.map((arg) => arg.type), + returnType: fnCall.type, + }, + }); + for (const [i, arg] of fnCall.arguments.entries()) { + this.withExpression(ctx, arg, typeSystem); + typeSystem.compare({ + left: fnCall.source.type, + operation: { operation: "argument", argNum: i }, + right: arg.type, + }); + } + }; + + withVariableUsage = (ctx: Context, usage: VariableUsage, typeSystem: TypeSystem) => { + const variable = ctx.environment[usage.name.text]; + if (!variable || variable.namedEntity === "NamedType") { + throw new Error(`${usage.name.text} not found.`); + } + typeSystem.compare({ + left: variable.type, + operation: { operation: "equals" }, + right: usage.type, + }); + }; + + withIfExpression = (ctx: Context, ifExpression: IfExpression, typeSystem: TypeSystem) => { + this.withExpression(ctx, ifExpression.condition, typeSystem); + typeSystem.compare({ + left: { typeUsage: "NamedTypeUsage", name: { text: "bool", spanStart: 0, spanEnd: 0 } }, + operation: { operation: "equals" }, + right: ifExpression.condition.type, + }); + this.withBlock(ctx, ifExpression.block, typeSystem); + typeSystem.compare({ + left: ifExpression.block.type, + operation: { operation: "equals" }, + right: ifExpression.type, + }); + if (ifExpression.else) { + this.withBlock(ctx, ifExpression.else, typeSystem); + typeSystem.compare({ + left: ifExpression.else.type, + operation: { operation: "equals" }, + right: ifExpression.type, + }); + } + }; + + withStructGetter = (ctx: Context, structGetter: StructGetter, typeSystem: TypeSystem) => { + this.withExpression(ctx, structGetter.source, typeSystem); + typeSystem.compare({ + left: structGetter.source.type, + operation: { operation: "field", name: structGetter.attribute.text }, + right: structGetter.type, + }); + }; + + withOperation = (ctx: Context, op: Operation, typeSystem: TypeSystem) => { + this.withExpression(ctx, op.left, typeSystem); + this.withExpression(ctx, op.right, typeSystem); + typeSystem.compare({ + left: op.left.type, + operation: { operation: "equals" }, + right: op.type, + }); + typeSystem.compare({ + left: op.right.type, + operation: { operation: "equals" }, + right: op.type, + }); + }; } diff --git a/packages/boringlang/src/types/type_system.ts b/packages/boringlang/src/types/type_system.ts index cdf01af..8ce13ed 100644 --- a/packages/boringlang/src/types/type_system.ts +++ b/packages/boringlang/src/types/type_system.ts @@ -1,12 +1,20 @@ -import { TypeUsage } from "../parse/ast"; +import { TypeUsage, UnknownTypeUsage } from "../parse/ast"; +import { newContext } from "./builtins"; +import { Context, getAttr } from "./context"; export const compareTypes = (typeA: TypeUsage, typeB: TypeUsage) => { if (typeA.typeUsage !== typeB.typeUsage) { throw Error(`Mismatched types: ${typeA.typeUsage} ${typeB.typeUsage}`); } if (typeA.typeUsage == "NamedTypeUsage" && typeB.typeUsage == "NamedTypeUsage") { + if (typeB.name.text === "Never") { + // never matches with everything + return; + } if (typeA.name.text !== typeB.name.text) { - throw Error(`Mismatched types: ${typeA.name.text} ${typeB.name.text}`); + throw Error( + `Mismatched types: ${typeA.name.text}:${typeA.name.spanStart}:${typeA.name.spanEnd} ${typeB.name.text}:${typeB.name.spanStart}:${typeB.name.spanEnd}`, + ); } } if (typeA.typeUsage == "FunctionTypeUsage" && typeB.typeUsage == "FunctionTypeUsage") { @@ -33,19 +41,26 @@ interface ReturnValue { operation: "return"; } +interface Field { + operation: "field"; + name: string; +} + interface Comparison { left: TypeUsage; - operation: Equals | Argument | ReturnValue; + operation: Equals | Argument | ReturnValue | Field; right: TypeUsage; } export class TypeSystem { comparisons: Comparison[]; result: Record; + context: Context; constructor() { this.comparisons = []; this.result = {}; + this.context = newContext(); } compare = (comparison: Comparison) => { @@ -56,6 +71,7 @@ export class TypeSystem { let foundUpdate = false; let containsUnknown = false; while (true) { + foundUpdate = false; for (const comparison of this.comparisons) { // if already found, just update comparison.left = this.resolveType(comparison.left); @@ -152,6 +168,37 @@ export class TypeSystem { compareTypes(comparison.left.returnType, comparison.right); } } + // field + if (comparison.operation.operation === "field") { + if (comparison.left.typeUsage === "UnknownTypeUsage") { + // cannot yet be resolved + continue; + } + if (comparison.left.typeUsage !== "NamedTypeUsage") { + throw Error("field on something that isn't a named type."); + } + // cannot be solved left + // solve right + if (comparison.right.typeUsage === "UnknownTypeUsage") { + foundUpdate = true; + const attrType = getAttr( + this.context, + comparison.left.name.text, + comparison.operation.name, + ); + this.result[comparison.right.name] = attrType; + comparison.right = attrType; + } + // check + if (comparison.right.typeUsage !== "UnknownTypeUsage") { + const attrType = getAttr( + this.context, + comparison.left.name.text, + comparison.operation.name, + ); + compareTypes(attrType, comparison.right); + } + } if ( comparison.left.typeUsage === "UnknownTypeUsage" || comparison.right.typeUsage === "UnknownTypeUsage"