removed return/arg type comparisons

This commit is contained in:
2025-08-29 22:36:17 -06:00
parent b2709ffc82
commit 66c7864df0
3 changed files with 71 additions and 107 deletions

View File

@@ -19,6 +19,17 @@ export default class TraitChecker {
}
}
};
withTrait = (trait: TraitTypeDeclaration) => {
for (const fn of trait.functions) {
if (fn.arguments.length === 0) {
throw new Error("First argument of trait method must be Self");
}
const firstArg = fn.arguments[0];
if (firstArg.type.typeUsage !== "NamedTypeUsage" || firstArg.type.name.text !== "Self") {
throw new Error("First argument of trait method must be Self");
}
}
};
withImpl = (ctx: Context, impl: Impl) => {
if (new Set(impl.functions.map((fn) => fn.declaration.name)).size !== impl.functions.length) {
throw Error(`Duplicate functions in ${impl.struct.name.text}`);

View File

@@ -363,11 +363,6 @@ export class TypeChecker {
withFunctionCall = (ctx: Context, fnCall: FunctionCall, typeSystem: TypeSystem) => {
this.withExpression(ctx, fnCall.source, typeSystem);
typeSystem.compare({
left: fnCall.source.type,
operation: { operation: "return" },
right: fnCall.type,
});
typeSystem.compare({
left: fnCall.source.type,
operation: { operation: "equals" },
@@ -379,11 +374,6 @@ export class TypeChecker {
});
for (const [i, arg] of fnCall.arguments.entries()) {
this.withExpression(ctx, arg, typeSystem);
typeSystem.compare({
left: fnCall.source.type,
operation: { operation: "argument", argNum: i },
right: arg.type,
});
}
};

View File

@@ -1,6 +1,6 @@
import { TypeUsage, UnknownTypeUsage } from "../parse/ast";
import { FunctionTypeUsage, TypeUsage, UnknownTypeUsage } from "../parse/ast";
import { newContext } from "./builtins";
import { Context, getAttr } from "./context";
import { Context, deepCopy, getAttr } from "./context";
export const compareTypes = (typeA: TypeUsage, typeB: TypeUsage) => {
if (typeA.typeUsage !== typeB.typeUsage) {
@@ -32,15 +32,6 @@ interface Equals {
operation: "equals";
}
interface Argument {
operation: "argument";
argNum: number;
}
interface ReturnValue {
operation: "return";
}
interface Field {
operation: "field";
name: string;
@@ -48,7 +39,7 @@ interface Field {
interface Comparison {
left: TypeUsage;
operation: Equals | Argument | ReturnValue | Field;
operation: Equals | Field;
right: TypeUsage;
}
@@ -69,7 +60,6 @@ export class TypeSystem {
solve = () => {
let foundUpdate = false;
let containsUnknown = false;
while (true) {
foundUpdate = false;
for (const comparison of this.comparisons) {
@@ -78,23 +68,11 @@ export class TypeSystem {
comparison.right = this.resolveType(comparison.right);
// equals
if (comparison.operation.operation === "equals") {
// solve left
if (
comparison.left.typeUsage === "UnknownTypeUsage" &&
comparison.right.typeUsage !== "UnknownTypeUsage"
) {
const [result, found] = this.equateTypes(comparison.left, comparison.right);
if (found) {
comparison.left = result;
comparison.right = result;
foundUpdate = true;
this.result[comparison.left.name] = comparison.right;
comparison.left = comparison.right;
}
// solve right
if (
comparison.left.typeUsage !== "UnknownTypeUsage" &&
comparison.right.typeUsage === "UnknownTypeUsage"
) {
foundUpdate = true;
this.result[comparison.right.name] = comparison.left;
comparison.right = comparison.left;
}
// check
if (
@@ -104,70 +82,6 @@ export class TypeSystem {
compareTypes(comparison.left, comparison.right);
}
}
// argument
if (comparison.operation.operation === "argument") {
if (comparison.left.typeUsage !== "FunctionTypeUsage") {
throw Error("Argument for something that isn't a function");
}
// solve left
const argument = comparison.left.arguments[comparison.operation.argNum];
if (
argument.typeUsage === "UnknownTypeUsage" &&
comparison.right.typeUsage !== "UnknownTypeUsage"
) {
foundUpdate = true;
this.result[argument.name] = comparison.right;
comparison.left.arguments[comparison.operation.argNum] = comparison.right;
}
// solve right
if (
argument.typeUsage !== "UnknownTypeUsage" &&
comparison.right.typeUsage === "UnknownTypeUsage"
) {
foundUpdate = true;
this.result[comparison.right.name] =
comparison.left.arguments[comparison.operation.argNum];
comparison.right = comparison.left.arguments[comparison.operation.argNum];
}
// check
if (
argument.typeUsage !== "UnknownTypeUsage" &&
comparison.right.typeUsage !== "UnknownTypeUsage"
) {
compareTypes(argument, comparison.right);
}
}
// return type
if (comparison.operation.operation === "return") {
if (comparison.left.typeUsage !== "FunctionTypeUsage") {
throw Error("return type for something that isn't a function");
}
// solve left
if (
comparison.left.returnType.typeUsage === "UnknownTypeUsage" &&
comparison.right.typeUsage !== "UnknownTypeUsage"
) {
foundUpdate = true;
this.result[comparison.left.returnType.name] = comparison.right;
comparison.left.returnType = comparison.right;
}
// solve right
if (
comparison.left.returnType.typeUsage !== "UnknownTypeUsage" &&
comparison.right.typeUsage === "UnknownTypeUsage"
) {
foundUpdate = true;
this.result[comparison.right.name] = comparison.left.returnType;
comparison.right = comparison.left.returnType;
}
// check
if (
comparison.left.returnType.typeUsage !== "UnknownTypeUsage" &&
comparison.right.typeUsage !== "UnknownTypeUsage"
) {
compareTypes(comparison.left.returnType, comparison.right);
}
}
// field
if (comparison.operation.operation === "field") {
if (comparison.left.typeUsage === "UnknownTypeUsage") {
@@ -199,10 +113,10 @@ export class TypeSystem {
compareTypes(attrType, comparison.right);
}
}
if (
comparison.left.typeUsage === "UnknownTypeUsage" ||
comparison.right.typeUsage === "UnknownTypeUsage"
) {
}
let containsUnknown = false;
for (const comparison of this.comparisons) {
if (this.containsUnknown(comparison.left) || this.containsUnknown(comparison.right)) {
containsUnknown = true;
}
}
@@ -216,6 +130,38 @@ export class TypeSystem {
return this.result;
};
equateTypes = (left: TypeUsage, right: TypeUsage): [TypeUsage, boolean] => {
if (left.typeUsage === "UnknownTypeUsage" && right.typeUsage !== "UnknownTypeUsage") {
this.result[left.name] = right;
return [right, true];
}
if (left.typeUsage !== "UnknownTypeUsage" && right.typeUsage === "UnknownTypeUsage") {
this.result[right.name] = left;
return [left, true];
}
if (left.typeUsage === "FunctionTypeUsage" && right.typeUsage === "FunctionTypeUsage") {
if (left.arguments.length !== right.arguments.length) {
throw Error(`Mismatched arg lengths: ${left.arguments.length} ${right.arguments.length}`);
}
let found = false;
let fnResult: FunctionTypeUsage = deepCopy(left);
for (let i = 0; i < left.arguments.length; i++) {
const [result, wasFound] = this.equateTypes(left.arguments[i], right.arguments[i]);
if (wasFound) {
found = true;
}
fnResult.arguments[i] = result;
}
const [result, wasFound] = this.equateTypes(left.returnType, right.returnType);
if (wasFound) {
found = true;
}
fnResult.returnType = result;
return [fnResult, found];
}
return [left, false];
};
resolveType = (type: TypeUsage): TypeUsage => {
if (type.typeUsage === "UnknownTypeUsage") {
if (this.result[type.name]) {
@@ -232,4 +178,21 @@ export class TypeSystem {
returnType: this.resolveType(type.returnType),
};
};
containsUnknown = (type: TypeUsage) => {
if (type.typeUsage === "UnknownTypeUsage") {
return true;
}
if (type.typeUsage === "FunctionTypeUsage") {
for (const arg of type.arguments) {
if (this.containsUnknown(arg)) {
return true;
}
}
if (this.containsUnknown(type.returnType)) {
return true;
}
}
return false;
};
}