From 66c7864df0bcc9a8f8b95ccefc9865fe7aa4305b Mon Sep 17 00:00:00 2001 From: Andrew Segavac Date: Fri, 29 Aug 2025 22:36:17 -0600 Subject: [PATCH] removed return/arg type comparisons --- .../boringlang/src/types/trait_checker.ts | 11 ++ packages/boringlang/src/types/type_checker.ts | 10 -- packages/boringlang/src/types/type_system.ts | 157 +++++++----------- 3 files changed, 71 insertions(+), 107 deletions(-) diff --git a/packages/boringlang/src/types/trait_checker.ts b/packages/boringlang/src/types/trait_checker.ts index ad5d1aa..20819f0 100644 --- a/packages/boringlang/src/types/trait_checker.ts +++ b/packages/boringlang/src/types/trait_checker.ts @@ -19,6 +19,17 @@ export default class TraitChecker { } } }; + withTrait = (trait: TraitTypeDeclaration) => { + for (const fn of trait.functions) { + if (fn.arguments.length === 0) { + throw new Error("First argument of trait method must be Self"); + } + const firstArg = fn.arguments[0]; + if (firstArg.type.typeUsage !== "NamedTypeUsage" || firstArg.type.name.text !== "Self") { + throw new Error("First argument of trait method must be Self"); + } + } + }; withImpl = (ctx: Context, impl: Impl) => { if (new Set(impl.functions.map((fn) => fn.declaration.name)).size !== impl.functions.length) { throw Error(`Duplicate functions in ${impl.struct.name.text}`); diff --git a/packages/boringlang/src/types/type_checker.ts b/packages/boringlang/src/types/type_checker.ts index 2be5323..fbdf1d6 100644 --- a/packages/boringlang/src/types/type_checker.ts +++ b/packages/boringlang/src/types/type_checker.ts @@ -363,11 +363,6 @@ export class TypeChecker { withFunctionCall = (ctx: Context, fnCall: FunctionCall, typeSystem: TypeSystem) => { this.withExpression(ctx, fnCall.source, typeSystem); - typeSystem.compare({ - left: fnCall.source.type, - operation: { operation: "return" }, - right: fnCall.type, - }); typeSystem.compare({ left: fnCall.source.type, operation: { operation: "equals" }, @@ -379,11 +374,6 @@ export class TypeChecker { }); for (const [i, arg] of fnCall.arguments.entries()) { this.withExpression(ctx, arg, typeSystem); - typeSystem.compare({ - left: fnCall.source.type, - operation: { operation: "argument", argNum: i }, - right: arg.type, - }); } }; diff --git a/packages/boringlang/src/types/type_system.ts b/packages/boringlang/src/types/type_system.ts index 8ce13ed..d366fe4 100644 --- a/packages/boringlang/src/types/type_system.ts +++ b/packages/boringlang/src/types/type_system.ts @@ -1,6 +1,6 @@ -import { TypeUsage, UnknownTypeUsage } from "../parse/ast"; +import { FunctionTypeUsage, TypeUsage, UnknownTypeUsage } from "../parse/ast"; import { newContext } from "./builtins"; -import { Context, getAttr } from "./context"; +import { Context, deepCopy, getAttr } from "./context"; export const compareTypes = (typeA: TypeUsage, typeB: TypeUsage) => { if (typeA.typeUsage !== typeB.typeUsage) { @@ -32,15 +32,6 @@ interface Equals { operation: "equals"; } -interface Argument { - operation: "argument"; - argNum: number; -} - -interface ReturnValue { - operation: "return"; -} - interface Field { operation: "field"; name: string; @@ -48,7 +39,7 @@ interface Field { interface Comparison { left: TypeUsage; - operation: Equals | Argument | ReturnValue | Field; + operation: Equals | Field; right: TypeUsage; } @@ -69,7 +60,6 @@ export class TypeSystem { solve = () => { let foundUpdate = false; - let containsUnknown = false; while (true) { foundUpdate = false; for (const comparison of this.comparisons) { @@ -78,23 +68,11 @@ export class TypeSystem { comparison.right = this.resolveType(comparison.right); // equals if (comparison.operation.operation === "equals") { - // solve left - if ( - comparison.left.typeUsage === "UnknownTypeUsage" && - comparison.right.typeUsage !== "UnknownTypeUsage" - ) { + const [result, found] = this.equateTypes(comparison.left, comparison.right); + if (found) { + comparison.left = result; + comparison.right = result; foundUpdate = true; - this.result[comparison.left.name] = comparison.right; - comparison.left = comparison.right; - } - // solve right - if ( - comparison.left.typeUsage !== "UnknownTypeUsage" && - comparison.right.typeUsage === "UnknownTypeUsage" - ) { - foundUpdate = true; - this.result[comparison.right.name] = comparison.left; - comparison.right = comparison.left; } // check if ( @@ -104,70 +82,6 @@ export class TypeSystem { compareTypes(comparison.left, comparison.right); } } - // argument - if (comparison.operation.operation === "argument") { - if (comparison.left.typeUsage !== "FunctionTypeUsage") { - throw Error("Argument for something that isn't a function"); - } - // solve left - const argument = comparison.left.arguments[comparison.operation.argNum]; - if ( - argument.typeUsage === "UnknownTypeUsage" && - comparison.right.typeUsage !== "UnknownTypeUsage" - ) { - foundUpdate = true; - this.result[argument.name] = comparison.right; - comparison.left.arguments[comparison.operation.argNum] = comparison.right; - } - // solve right - if ( - argument.typeUsage !== "UnknownTypeUsage" && - comparison.right.typeUsage === "UnknownTypeUsage" - ) { - foundUpdate = true; - this.result[comparison.right.name] = - comparison.left.arguments[comparison.operation.argNum]; - comparison.right = comparison.left.arguments[comparison.operation.argNum]; - } - // check - if ( - argument.typeUsage !== "UnknownTypeUsage" && - comparison.right.typeUsage !== "UnknownTypeUsage" - ) { - compareTypes(argument, comparison.right); - } - } - // return type - if (comparison.operation.operation === "return") { - if (comparison.left.typeUsage !== "FunctionTypeUsage") { - throw Error("return type for something that isn't a function"); - } - // solve left - if ( - comparison.left.returnType.typeUsage === "UnknownTypeUsage" && - comparison.right.typeUsage !== "UnknownTypeUsage" - ) { - foundUpdate = true; - this.result[comparison.left.returnType.name] = comparison.right; - comparison.left.returnType = comparison.right; - } - // solve right - if ( - comparison.left.returnType.typeUsage !== "UnknownTypeUsage" && - comparison.right.typeUsage === "UnknownTypeUsage" - ) { - foundUpdate = true; - this.result[comparison.right.name] = comparison.left.returnType; - comparison.right = comparison.left.returnType; - } - // check - if ( - comparison.left.returnType.typeUsage !== "UnknownTypeUsage" && - comparison.right.typeUsage !== "UnknownTypeUsage" - ) { - compareTypes(comparison.left.returnType, comparison.right); - } - } // field if (comparison.operation.operation === "field") { if (comparison.left.typeUsage === "UnknownTypeUsage") { @@ -199,10 +113,10 @@ export class TypeSystem { compareTypes(attrType, comparison.right); } } - if ( - comparison.left.typeUsage === "UnknownTypeUsage" || - comparison.right.typeUsage === "UnknownTypeUsage" - ) { + } + let containsUnknown = false; + for (const comparison of this.comparisons) { + if (this.containsUnknown(comparison.left) || this.containsUnknown(comparison.right)) { containsUnknown = true; } } @@ -216,6 +130,38 @@ export class TypeSystem { return this.result; }; + equateTypes = (left: TypeUsage, right: TypeUsage): [TypeUsage, boolean] => { + if (left.typeUsage === "UnknownTypeUsage" && right.typeUsage !== "UnknownTypeUsage") { + this.result[left.name] = right; + return [right, true]; + } + if (left.typeUsage !== "UnknownTypeUsage" && right.typeUsage === "UnknownTypeUsage") { + this.result[right.name] = left; + return [left, true]; + } + if (left.typeUsage === "FunctionTypeUsage" && right.typeUsage === "FunctionTypeUsage") { + if (left.arguments.length !== right.arguments.length) { + throw Error(`Mismatched arg lengths: ${left.arguments.length} ${right.arguments.length}`); + } + let found = false; + let fnResult: FunctionTypeUsage = deepCopy(left); + for (let i = 0; i < left.arguments.length; i++) { + const [result, wasFound] = this.equateTypes(left.arguments[i], right.arguments[i]); + if (wasFound) { + found = true; + } + fnResult.arguments[i] = result; + } + const [result, wasFound] = this.equateTypes(left.returnType, right.returnType); + if (wasFound) { + found = true; + } + fnResult.returnType = result; + return [fnResult, found]; + } + return [left, false]; + }; + resolveType = (type: TypeUsage): TypeUsage => { if (type.typeUsage === "UnknownTypeUsage") { if (this.result[type.name]) { @@ -232,4 +178,21 @@ export class TypeSystem { returnType: this.resolveType(type.returnType), }; }; + + containsUnknown = (type: TypeUsage) => { + if (type.typeUsage === "UnknownTypeUsage") { + return true; + } + if (type.typeUsage === "FunctionTypeUsage") { + for (const arg of type.arguments) { + if (this.containsUnknown(arg)) { + return true; + } + } + if (this.containsUnknown(type.returnType)) { + return true; + } + } + return false; + }; }