got type system working

This commit is contained in:
2025-08-25 21:51:50 -06:00
parent 0a315c5615
commit b2709ffc82
10 changed files with 858 additions and 91 deletions

View File

@@ -1,11 +1,6 @@
fn main(): String { fn main(): String {
let a = 2; let a = "asdf";
a; return a ;
a = 3;
a = if(true) {"asdf"} else {"fdsa"};
a.b.c.d();
a = (b + c.d()); // comment
return a.b() ;
} }
type User struct { type User struct {
@@ -23,7 +18,7 @@ impl TestTrait for User {
return Self{id: id}; return Self{id: id};
} }
fn instance_method(self: Self): i64 { fn instance_method(self: Self): i64 {
return self.get_id(); return self.id;
} }
fn default_impl(self: Self): i64 { fn default_impl(self: Self): i64 {
return self.instance_method(); return self.instance_method();

View File

@@ -2,6 +2,9 @@ import { defineCommand } from "@bunli/core";
import { boringGrammar } from "../parse/grammar"; import { boringGrammar } from "../parse/grammar";
import { semantics } from "../parse/semantics"; import { semantics } from "../parse/semantics";
import TraitChecker from "../types/trait_checker"; import TraitChecker from "../types/trait_checker";
import { TypeAliasResolver } from "../types/type_alias_resolution";
import { TypeSystem } from "../types/type_system";
import { TypeChecker } from "../types/type_checker";
export const run = defineCommand({ export const run = defineCommand({
name: "run", name: "run",
@@ -22,8 +25,14 @@ export const run = defineCommand({
const adapter = semantics(match); const adapter = semantics(match);
const ast = adapter.toAST(); const ast = adapter.toAST();
new TraitChecker().withModule(ast); new TraitChecker().withModule(ast);
const aliasResolvedAst = new TypeAliasResolver().withModule(ast);
const typeSystem = new TypeSystem();
const typeChecker = new TypeChecker();
typeChecker.withModule(aliasResolvedAst, typeSystem);
typeSystem.solve();
console.log(JSON.stringify(typeSystem, null, 2));
console.log(JSON.stringify(ast, null, 2)); // console.log(JSON.stringify(aliasResolvedAst, null, 2));
} else { } else {
console.log(match.message); console.log(match.message);
// console.log(boringGrammar.trace(text, "Module").toString()); // console.log(boringGrammar.trace(text, "Module").toString());

View File

@@ -62,6 +62,7 @@ export interface Operation {
left: Expression; left: Expression;
op: "+" | "-" | "*" | "/"; op: "+" | "-" | "*" | "/";
right: Expression; right: Expression;
type: TypeUsage;
} }
export interface VariableUsage { export interface VariableUsage {
@@ -116,6 +117,7 @@ export interface AssignmentStatement {
export type Statement = ReturnStatement | LetStatement | AssignmentStatement | Expression; export type Statement = ReturnStatement | LetStatement | AssignmentStatement | Expression;
export interface Block { export interface Block {
expressionType: "Block";
statements: Statement[]; statements: Statement[];
type: TypeUsage; type: TypeUsage;
} }
@@ -196,3 +198,74 @@ export interface UnknownTypeUsage {
} }
export type TypeUsage = NamedTypeUsage | FunctionTypeUsage | UnknownTypeUsage; export type TypeUsage = NamedTypeUsage | FunctionTypeUsage | UnknownTypeUsage;
export const newVoid: () => TypeUsage = () => {
return {
typeUsage: "NamedTypeUsage",
name: { text: "Void", spanStart: 0, spanEnd: 0 },
};
};
export const newNever: () => TypeUsage = () => {
return {
typeUsage: "NamedTypeUsage",
name: { text: "Never", spanStart: 0, spanEnd: 0 },
};
};
function containsReturnExpression(expression: Expression) {
if (expression.subExpression.expressionType === "LiteralStruct") {
for (const field of expression.subExpression.fields) {
if (containsReturnExpression(field.expression)) {
return true;
}
}
}
if (expression.subExpression.expressionType === "IfExpression") {
if (containsReturn(expression.subExpression.block)) {
return true;
}
if (expression.subExpression.else && containsReturn(expression.subExpression.else)) {
return true;
}
}
if (expression.subExpression.expressionType === "Block") {
if (containsReturn(expression.subExpression)) {
return true;
}
}
if (expression.subExpression.expressionType === "Operation") {
if (containsReturnExpression(expression.subExpression.left)) {
return true;
}
if (containsReturnExpression(expression.subExpression.right)) {
return true;
}
}
return false;
}
export function containsReturn(block: Block) {
for (const statement of block.statements) {
if (statement.statementType === "ReturnStatement") {
return true;
}
if (statement.statementType === "AssignmentStatement") {
if (containsReturnExpression(statement.expression)) {
return true;
}
}
if (statement.statementType === "LetStatement") {
if (containsReturnExpression(statement.expression)) {
return true;
}
}
if (statement.statementType === "Expression") {
if (containsReturnExpression(statement)) {
return true;
}
}
}
return false;
}

View File

@@ -1,6 +1,7 @@
import { import {
AssignmentStatement, AssignmentStatement,
Block, Block,
containsReturn,
Expression, Expression,
Function, Function,
FunctionArgument, FunctionArgument,
@@ -19,6 +20,7 @@ import {
Module, Module,
ModuleItem, ModuleItem,
NamedTypeUsage, NamedTypeUsage,
newNever,
Operation, Operation,
ReturnStatement, ReturnStatement,
Statement, Statement,
@@ -43,8 +45,6 @@ function nextUnknown() {
export const semantics = boringGrammar.createSemantics(); export const semantics = boringGrammar.createSemantics();
semantics.addOperation<any>("toAST", { semantics.addOperation<any>("toAST", {
LiteralInt(a): LiteralInt { LiteralInt(a): LiteralInt {
console.log(this);
console.log(a.source.startIdx);
return { return {
expressionType: "LiteralInt", expressionType: "LiteralInt",
value: this.sourceString, value: this.sourceString,
@@ -106,11 +106,15 @@ semantics.addOperation<any>("toAST", {
}; };
}, },
FunctionCall(expression, _2, args, _4): FunctionCall { FunctionCall(expression, _2, args, _4): FunctionCall {
const resolvedArgs = args.asIteration().children.map((c) => c.toAST());
return { return {
expressionType: "FunctionCall", expressionType: "FunctionCall",
source: expression.toAST(), source: expression.toAST(),
arguments: args.asIteration().children.map((c) => c.toAST()), arguments: resolvedArgs,
type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, type: {
typeUsage: "UnknownTypeUsage",
name: nextUnknown(),
},
}; };
}, },
StructGetter(expression, _2, identifier): StructGetter { StructGetter(expression, _2, identifier): StructGetter {
@@ -148,7 +152,11 @@ semantics.addOperation<any>("toAST", {
return factor.toAST(); return factor.toAST();
}, },
Expression(expression): Expression { Expression(expression): Expression {
return expression.toAST(); return {
statementType: "Expression",
subExpression: expression.toAST(),
type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() },
};
}, },
Expression_plus(expression, _2, factor): Operation { Expression_plus(expression, _2, factor): Operation {
return { return {
@@ -156,6 +164,7 @@ semantics.addOperation<any>("toAST", {
left: expression.toAST(), left: expression.toAST(),
op: "+", op: "+",
right: factor.toAST(), right: factor.toAST(),
type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() },
}; };
}, },
Expression_minus(expression, _2, factor): Operation { Expression_minus(expression, _2, factor): Operation {
@@ -164,6 +173,7 @@ semantics.addOperation<any>("toAST", {
left: expression.toAST(), left: expression.toAST(),
op: "-", op: "-",
right: factor.toAST(), right: factor.toAST(),
type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() },
}; };
}, },
Factor_mult(factor, _2, term): Operation { Factor_mult(factor, _2, term): Operation {
@@ -172,6 +182,7 @@ semantics.addOperation<any>("toAST", {
left: factor.toAST(), left: factor.toAST(),
op: "*", op: "*",
right: term.toAST(), right: term.toAST(),
type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() },
}; };
}, },
Factor_div(factor, _2, term): Operation { Factor_div(factor, _2, term): Operation {
@@ -180,6 +191,7 @@ semantics.addOperation<any>("toAST", {
left: factor.toAST(), left: factor.toAST(),
op: "/", op: "/",
right: term.toAST(), right: term.toAST(),
type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() },
}; };
}, },
Statement(statement): Statement { Statement(statement): Statement {
@@ -217,11 +229,18 @@ semantics.addOperation<any>("toAST", {
Block(_1, statements, expression, _4): Block { Block(_1, statements, expression, _4): Block {
const lines = statements.asIteration().children.map((c) => c.toAST()); const lines = statements.asIteration().children.map((c) => c.toAST());
const finalExpression = expression.toAST(); const finalExpression = expression.toAST();
lines.push(finalExpression.length > 0 ? finalExpression[0] : null); if (finalExpression.length > 0) {
return { lines.push(finalExpression[0]);
}
const block: Block = {
expressionType: "Block",
statements: lines, statements: lines,
type: { typeUsage: "UnknownTypeUsage", name: nextUnknown() }, type: newNever(),
}; };
if (!containsReturn(block)) {
block.type = { typeUsage: "UnknownTypeUsage", name: nextUnknown() };
}
return block;
}, },
NamedTypeUsage(name): NamedTypeUsage { NamedTypeUsage(name): NamedTypeUsage {
return { return {

View File

@@ -0,0 +1,21 @@
import { Context } from "./context";
export function newContext(): Context {
const result: Context = {
currentFunctionReturn: null,
environment: {},
};
result.environment["i8"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["i16"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["i32"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["i64"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["f8"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["f16"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["f32"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["f64"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["String"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["Void"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
result.environment["Never"] = { namedEntity: "NamedType", isA: "Scalar", fields: {}, impls: [] };
return result;
}

View File

@@ -0,0 +1,86 @@
import { NamedTypeUsage, TypeUsage } from "../parse/ast";
interface EnvImpl {
trait: string | null;
functions: Record<string, TypeUsage>;
}
interface NamedType {
namedEntity: "NamedType";
isA: "Scalar" | "Trait" | "Struct";
fields: Record<string, TypeUsage>;
impls: EnvImpl[];
}
interface Variable {
namedEntity: "Variable";
type: TypeUsage;
}
type NamedEntity = NamedType | Variable;
export interface Context {
currentFunctionReturn: TypeUsage | null;
environment: Record<string, NamedEntity>;
}
export function getAttr(ctx: Context, name: string, field: string) {
const struct = ctx.environment[name];
if (!struct || struct.namedEntity !== "NamedType") {
throw Error(`Unknown type ${name}`);
}
if (struct.fields[field]) {
return struct.fields[field];
}
let results: TypeUsage[] = [];
for (const impl of struct.impls) {
if (impl.functions[field]) {
results.push(impl.functions[field]);
}
}
if (results.length === 0) {
console.log(JSON.stringify(struct, null, 2));
throw Error(`${name} has no attribue ${field}`);
}
if (results.length > 1) {
throw Error(`${name} has multiple attribues ${field}, use universal function call syntax`);
}
return results[0];
}
export function typeExists(ctx: Context, type: TypeUsage) {
if (type.typeUsage === "NamedTypeUsage") {
if (
!ctx.environment[type.name.text] ||
ctx.environment[type.name.text].namedEntity !== "NamedType"
) {
throw Error(`${type.name.text} is not a type.`);
}
}
if (type.typeUsage === "FunctionTypeUsage") {
for (const arg of type.arguments) {
typeExists(ctx, arg);
}
typeExists(ctx, type.returnType);
}
}
export function replaceType(oldName: string, newType: TypeUsage, inType: TypeUsage) {
if (inType.typeUsage === "NamedTypeUsage") {
if (inType.name.text === oldName) {
return deepCopy(newType);
}
}
if (inType.typeUsage === "FunctionTypeUsage") {
const result = deepCopy(inType);
for (const [i, arg] of inType.arguments.entries()) {
result.arguments[i] = replaceType(oldName, newType, arg);
}
result.returnType = replaceType(oldName, newType, inType.returnType);
}
return deepCopy(inType);
}
export const deepCopy = <T>(o: T) => {
return JSON.parse(JSON.stringify(o)) as T;
};

View File

@@ -10,7 +10,7 @@ export default class TraitChecker {
let ctx: Context = { environment: {} }; let ctx: Context = { environment: {} };
for (const item of module.items) { for (const item of module.items) {
if (item.moduleItem == "TraitTypeDeclaration") { if (item.moduleItem == "TraitTypeDeclaration") {
ctx.environment[item.name.name] = item; ctx.environment[item.name.text] = item;
} }
} }
for (const item of module.items) { for (const item of module.items) {
@@ -21,20 +21,20 @@ export default class TraitChecker {
}; };
withImpl = (ctx: Context, impl: Impl) => { withImpl = (ctx: Context, impl: Impl) => {
if (new Set(impl.functions.map((fn) => fn.declaration.name)).size !== impl.functions.length) { if (new Set(impl.functions.map((fn) => fn.declaration.name)).size !== impl.functions.length) {
throw Error(`Duplicate functions in ${impl.struct.name.name}`); throw Error(`Duplicate functions in ${impl.struct.name.text}`);
} }
if (impl.trait == null) { if (impl.trait == null) {
return; return;
} }
const trait = ctx.environment[impl.trait.name.name]; const trait = ctx.environment[impl.trait.name.text];
if (!trait) { if (!trait) {
throw Error(`No such trait: ${impl.trait.name}`); throw Error(`No such trait: ${impl.trait.name}`);
} }
if (impl.functions.length !== trait.functions.length) { if (impl.functions.length !== trait.functions.length) {
throw Error(`Mismatched impl/trait len ${impl.trait.name.name} for ${impl.struct.name.name}`); throw Error(`Mismatched impl/trait len ${impl.trait.name.text} for ${impl.struct.name.text}`);
} }
for (let i = 0; i < impl.functions.length; i++) { for (let i = 0; i < impl.functions.length; i++) {
if (impl.functions[i].declaration.name.name !== trait.functions[i].name.name) { if (impl.functions[i].declaration.name.text !== trait.functions[i].name.text) {
throw Error( throw Error(
`Mismatched impl/trait names ${impl.functions[i].declaration.name} for ${trait.functions[i].name}`, `Mismatched impl/trait names ${impl.functions[i].declaration.name} for ${trait.functions[i].name}`,
); );

View File

@@ -0,0 +1,238 @@
import {
AssignmentStatement,
Block,
Expression,
Function,
FunctionCall,
FunctionDeclaration,
IfExpression,
Impl,
LetStatement,
LiteralStruct,
Module,
Operation,
ReturnStatement,
StructGetter,
StructTypeDeclaration,
TraitTypeDeclaration,
TypeUsage,
VariableUsage,
} from "../parse/ast";
import { deepCopy, replaceType } from "./context";
interface AliasContext {
environment: Record<string, TypeUsage>;
}
export class TypeAliasResolver {
withModule = (module: Module) => {
const ctx: AliasContext = { environment: {} };
const result = deepCopy(module);
for (const [i, item] of module.items.entries()) {
if (item.moduleItem === "TraitTypeDeclaration") {
let traitCtx = deepCopy(ctx);
traitCtx.environment["Self"] = { typeUsage: "NamedTypeUsage", name: item.name };
result.items[i] = this.withTraitTypeDeclaration(traitCtx, item);
}
if (item.moduleItem === "Impl") {
let implCtx = deepCopy(ctx);
implCtx.environment["Self"] = { typeUsage: "NamedTypeUsage", name: item.struct.name };
result.items[i] = this.withImpl(implCtx, item);
}
if (item.moduleItem === "StructTypeDeclaration") {
let structCtx = deepCopy(ctx);
structCtx.environment["Self"] = { typeUsage: "NamedTypeUsage", name: item.name };
result.items[i] = this.withStructDeclaration(structCtx, item);
}
}
return result;
};
withTraitTypeDeclaration = (ctx: AliasContext, trait: TraitTypeDeclaration) => {
const result = deepCopy(trait);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
for (const [i, fn] of trait.functions.entries()) {
result.functions[i] = this.withFunctionDeclaration(ctx, fn);
}
}
return result;
};
withImpl = (ctx: AliasContext, impl: Impl) => {
const result = deepCopy(impl);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
for (const [i, fn] of impl.functions.entries()) {
result.functions[i] = this.withFunction(ctx, fn);
}
}
return result;
};
withStructDeclaration = (ctx: AliasContext, struct: StructTypeDeclaration) => {
const result = deepCopy(struct);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
for (const [i, field] of struct.fields.entries()) {
result.fields[i].type = replaceType(oldName, newType, field.type);
}
}
return result;
};
withFunctionDeclaration = (ctx: AliasContext, fn: FunctionDeclaration) => {
const result = deepCopy(fn);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
for (const [i, arg] of fn.arguments.entries()) {
result.arguments[i].type = replaceType(oldName, newType, arg.type);
}
result.returnType = replaceType(oldName, newType, result.returnType);
}
return result;
};
withFunction = (ctx: AliasContext, fn: Function) => {
const result = deepCopy(fn);
result.declaration = this.withFunctionDeclaration(ctx, fn.declaration);
result.block = this.withBlock(ctx, fn.block);
return result;
};
withBlock = (ctx: AliasContext, block: Block) => {
const result = deepCopy(block);
for (const [i, statement] of block.statements.entries()) {
if (statement.statementType === "AssignmentStatement") {
result.statements[i] = this.withAssignmentStatement(ctx, statement);
}
if (statement.statementType === "LetStatement") {
result.statements[i] = this.withLetStatement(ctx, statement);
}
if (statement.statementType === "Expression") {
result.statements[i] = this.withExpression(ctx, statement);
}
if (statement.statementType === "ReturnStatement") {
result.statements[i] = this.withReturnStatement(ctx, statement);
}
for (const [oldName, newType] of Object.entries(ctx.environment)) {
result.type = replaceType(oldName, newType, block.type);
}
}
return result;
};
withAssignmentStatement = (ctx: AliasContext, statement: AssignmentStatement) => {
const result = deepCopy(statement);
if (statement.source.expressionType == "StructGetter") {
result.source = this.withStructGetter(ctx, statement.source);
}
if (statement.source.expressionType == "VariableUsage") {
result.source = this.withVariableUsage(ctx, statement.source);
}
result.expression = this.withExpression(ctx, statement.expression);
return result;
};
withLetStatement = (ctx: AliasContext, statement: LetStatement) => {
const result = deepCopy(statement);
result.expression = this.withExpression(ctx, statement.expression);
return result;
};
withReturnStatement = (ctx: AliasContext, statement: ReturnStatement) => {
const result = deepCopy(statement);
result.source = this.withExpression(ctx, statement.source);
return result;
};
withExpression = (ctx: AliasContext, expression: Expression) => {
const result = deepCopy(expression);
if (expression.subExpression.expressionType === "LiteralStruct") {
result.subExpression = this.withLiteralStruct(ctx, expression.subExpression);
}
if (expression.subExpression.expressionType === "FunctionCall") {
result.subExpression = this.withFunctionCall(ctx, expression.subExpression);
}
if (expression.subExpression.expressionType === "VariableUsage") {
result.subExpression = this.withVariableUsage(ctx, expression.subExpression);
}
if (expression.subExpression.expressionType === "IfExpression") {
result.subExpression = this.withIfExpression(ctx, expression.subExpression);
}
if (expression.subExpression.expressionType === "StructGetter") {
result.subExpression = this.withStructGetter(ctx, expression.subExpression);
}
if (expression.subExpression.expressionType === "Block") {
result.subExpression = this.withBlock(ctx, expression.subExpression);
}
if (expression.subExpression.expressionType === "Operation") {
result.subExpression = this.withOperation(ctx, expression.subExpression);
}
return result;
};
withLiteralStruct = (ctx: AliasContext, literalStruct: LiteralStruct) => {
const result = deepCopy(literalStruct);
for (const [i, field] of literalStruct.fields.entries()) {
result.fields[i].expression = this.withExpression(ctx, field.expression);
}
for (const [oldName, newType] of Object.entries(ctx.environment)) {
result.type = replaceType(oldName, newType, literalStruct.type);
if (result.name.text === oldName && newType.typeUsage === "NamedTypeUsage") {
result.name = newType.name;
}
}
return result;
};
withFunctionCall = (ctx: AliasContext, fnCall: FunctionCall) => {
const result = deepCopy(fnCall);
for (const [i, arg] of fnCall.arguments.entries()) {
result.arguments[i] = this.withExpression(ctx, arg);
}
result.source = this.withExpression(ctx, fnCall.source);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
result.type = replaceType(oldName, newType, fnCall.type);
}
return result;
};
withVariableUsage = (ctx: AliasContext, variableUsage: VariableUsage) => {
const result = deepCopy(variableUsage);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
result.type = replaceType(oldName, newType, variableUsage.type);
}
return result;
};
withIfExpression = (ctx: AliasContext, ifExpression: IfExpression) => {
const result = deepCopy(ifExpression);
result.condition = this.withExpression(ctx, ifExpression.condition);
result.block = this.withBlock(ctx, ifExpression.block);
if (ifExpression.else) {
result.else = this.withBlock(ctx, ifExpression.else);
}
for (const [oldName, newType] of Object.entries(ctx.environment)) {
result.type = replaceType(oldName, newType, ifExpression.type);
}
return result;
};
withStructGetter = (ctx: AliasContext, structGetter: StructGetter) => {
const result = deepCopy(structGetter);
result.source = this.withExpression(ctx, structGetter.source);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
result.type = replaceType(oldName, newType, structGetter.type);
}
return result;
};
withOperation = (ctx: AliasContext, op: Operation) => {
const result = deepCopy(op);
result.left = this.withExpression(ctx, op.left);
result.right = this.withExpression(ctx, op.right);
for (const [oldName, newType] of Object.entries(ctx.environment)) {
result.type = replaceType(oldName, newType, op.type);
}
return result;
};
}

View File

@@ -1,66 +1,33 @@
import { import {
AssignmentStatement,
Block,
containsReturn,
Expression,
Function, Function,
FunctionCall,
FunctionDeclaration, FunctionDeclaration,
functionToType, functionToType,
IfExpression,
Impl, Impl,
LetStatement,
LiteralStruct,
Module, Module,
newVoid,
Operation,
ReturnStatement,
StructGetter,
StructTypeDeclaration, StructTypeDeclaration,
TraitTypeDeclaration, TraitTypeDeclaration,
TypeUsage, TypeUsage,
VariableUsage,
} from "../parse/ast"; } from "../parse/ast";
import { newContext } from "./builtins";
import { Context, deepCopy, typeExists } from "./context";
import { TypeSystem } from "./type_system"; import { TypeSystem } from "./type_system";
interface EnvImpl { export class TypeChecker {
trait: string | null;
functions: Record<string, TypeUsage>;
}
interface NamedType {
namedEntity: "NamedType";
isA: "Scalar" | "Trait" | "Struct";
fields: Record<string, TypeUsage>;
impls: EnvImpl[];
}
interface Variable {
namedEntity: "Variable";
type: TypeUsage;
}
type NamedEntity = NamedType | Variable;
interface Context {
currentFunctionReturn: TypeUsage | null;
environment: Record<string, NamedEntity>;
}
function typeExists(ctx: Context, type: TypeUsage) {
if (type.typeUsage === "NamedTypeUsage") {
if (
!ctx.environment[type.name.text] ||
ctx.environment[type.name.text].namedEntity !== "NamedType"
) {
throw Error(`${type.name.text} is not a type.`);
}
}
if (type.typeUsage === "FunctionTypeUsage") {
for (const arg of type.arguments) {
typeExists(ctx, arg);
}
typeExists(ctx, type.returnType);
}
}
const deepCopy = <T>(o: T) => {
return JSON.parse(JSON.stringify(o)) as T;
};
class TypeChecker {
withModule = (module: Module, typeSystem: TypeSystem) => { withModule = (module: Module, typeSystem: TypeSystem) => {
const ctx: Context = { const ctx: Context = newContext();
currentFunctionReturn: null,
environment: {},
};
// add functions, structs, and traits to the context // add functions, structs, and traits to the context
for (const item of module.items) { for (const item of module.items) {
if (item.moduleItem === "StructTypeDeclaration") { if (item.moduleItem === "StructTypeDeclaration") {
@@ -70,7 +37,7 @@ class TypeChecker {
ctx.environment[item.name.text] = { ctx.environment[item.name.text] = {
namedEntity: "NamedType", namedEntity: "NamedType",
isA: "Struct", isA: "Struct",
fields: Object.fromEntries(item.fields.map((field) => [field.name, field.type])), fields: Object.fromEntries(item.fields.map((field) => [field.name.text, field.type])),
impls: [], impls: [],
}; };
} }
@@ -78,6 +45,18 @@ class TypeChecker {
if (ctx.environment[item.name.text]) { if (ctx.environment[item.name.text]) {
throw Error("Duplicate name of trait"); throw Error("Duplicate name of trait");
} }
const functions: Record<string, TypeUsage> = {};
for (const fn of item.functions) {
if (
fn.arguments.length &&
fn.arguments[0].type.typeUsage == "NamedTypeUsage" &&
fn.arguments[0].type.name.text === item.name.text
) {
const fnCopy = deepCopy(fn);
fnCopy.arguments = fnCopy.arguments.slice(1);
functions[fn.name.text] = functionToType(fnCopy);
}
}
ctx.environment[item.name.text] = { ctx.environment[item.name.text] = {
namedEntity: "NamedType", namedEntity: "NamedType",
isA: "Trait", isA: "Trait",
@@ -85,11 +64,7 @@ class TypeChecker {
impls: [ impls: [
{ {
trait: item.name.text, trait: item.name.text,
functions: Object.fromEntries( functions: functions,
item.functions.map((fn) => {
return [fn.name, functionToType(fn)];
}),
),
}, },
], ],
}; };
@@ -111,16 +86,25 @@ class TypeChecker {
if (!struct || struct.namedEntity !== "NamedType" || struct.isA !== "Struct") { if (!struct || struct.namedEntity !== "NamedType" || struct.isA !== "Struct") {
throw Error("Impl for non-struct"); throw Error("Impl for non-struct");
} }
const functions: Record<string, TypeUsage> = {};
for (const fn of item.functions) {
if (
fn.declaration.arguments.length &&
fn.declaration.arguments[0].type.typeUsage == "NamedTypeUsage" &&
fn.declaration.arguments[0].type.name.text === item.struct.name.text
) {
const fnCopy = deepCopy(fn.declaration);
fnCopy.arguments = fnCopy.arguments.slice(1);
functions[fn.declaration.name.text] = functionToType(fnCopy);
}
}
struct.impls.push({ struct.impls.push({
trait: item.trait?.name.text ?? null, trait: item.trait?.name.text ?? null,
functions: Object.fromEntries( functions: functions,
item.functions.map((fn) => {
return [fn.declaration.name, functionToType(fn.declaration)];
}),
),
}); });
} }
} }
typeSystem.context = deepCopy(ctx);
// environment set up, actually recurse. // environment set up, actually recurse.
for (const item of module.items) { for (const item of module.items) {
if (item.moduleItem === "Function") { if (item.moduleItem === "Function") {
@@ -130,7 +114,7 @@ class TypeChecker {
this.withImpl(ctx, item, typeSystem); this.withImpl(ctx, item, typeSystem);
} }
if (item.moduleItem === "StructTypeDeclaration") { if (item.moduleItem === "StructTypeDeclaration") {
this.withStruct(ctx, item, typeSystem); this.withStructDeclaration(ctx, item, typeSystem);
} }
if (item.moduleItem === "TraitTypeDeclaration") { if (item.moduleItem === "TraitTypeDeclaration") {
this.withTrait(ctx, item, typeSystem); this.withTrait(ctx, item, typeSystem);
@@ -141,13 +125,14 @@ class TypeChecker {
withFunction = (ctx: Context, fn: Function, typeSystem: TypeSystem) => { withFunction = (ctx: Context, fn: Function, typeSystem: TypeSystem) => {
this.withFunctionDeclaration(ctx, fn.declaration, typeSystem); this.withFunctionDeclaration(ctx, fn.declaration, typeSystem);
const fnCtx = deepCopy(ctx); const fnCtx = deepCopy(ctx);
fnCtx.currentFunctionReturn = fn.declaration.returnType;
for (const arg of fn.declaration.arguments) { for (const arg of fn.declaration.arguments) {
fnCtx.environment[arg.name.text] = { fnCtx.environment[arg.name.text] = {
namedEntity: "Variable", namedEntity: "Variable",
type: arg.type, type: arg.type,
}; };
} }
// this.withBlock(fnCtx, fn.block, typeSystem); this.withBlock(fnCtx, fn.block, typeSystem);
typeSystem.compare({ typeSystem.compare({
left: fn.declaration.returnType, left: fn.declaration.returnType,
operation: { operation: "equals" }, operation: { operation: "equals" },
@@ -162,8 +147,302 @@ class TypeChecker {
typeExists(ctx, def.returnType); typeExists(ctx, def.returnType);
}; };
withImpl = (ctx: Context, impl: Impl, typeSystem: TypeSystem) => {}; withImpl = (ctx: Context, impl: Impl, typeSystem: TypeSystem) => {
for (const fn of impl.functions) {
this.withFunction(ctx, fn, typeSystem);
}
};
withStruct = (ctx: Context, struct: StructTypeDeclaration, typeSystem: TypeSystem) => {}; withStructDeclaration = (ctx: Context, struct: StructTypeDeclaration, typeSystem: TypeSystem) => {
withTrait = (ctx: Context, trait: TraitTypeDeclaration, typeSystem: TypeSystem) => {}; for (const field of struct.fields) {
typeExists(ctx, field.type);
}
};
withTrait = (ctx: Context, trait: TraitTypeDeclaration, typeSystem: TypeSystem) => {
for (const method of trait.functions) {
this.withFunctionDeclaration(ctx, method, typeSystem);
}
};
withBlock = (ctx: Context, block: Block, typeSystem: TypeSystem) => {
const blockCtx = deepCopy(ctx);
for (const statement of block.statements) {
if (statement.statementType === "AssignmentStatement") {
this.withAssignmentStatement(blockCtx, statement, typeSystem);
}
if (statement.statementType === "LetStatement") {
this.withLetStatement(blockCtx, statement, typeSystem);
blockCtx.environment[statement.variableName.text] = {
namedEntity: "Variable",
type: statement.type,
};
}
if (statement.statementType === "Expression") {
this.withExpression(blockCtx, statement, typeSystem);
}
if (statement.statementType === "ReturnStatement") {
this.withReturnStatement(blockCtx, statement, typeSystem);
}
if (!containsReturn(block)) {
const lastStatement = block.statements[block.statements.length - 1] ?? null;
if (lastStatement && lastStatement.statementType == "Expression") {
typeSystem.compare({
left: block.type,
operation: { operation: "equals" },
right: lastStatement.type,
});
} else {
typeSystem.compare({
left: block.type,
operation: { operation: "equals" },
right: newVoid(),
});
}
}
}
};
withAssignmentStatement = (
ctx: Context,
statement: AssignmentStatement,
typeSystem: TypeSystem,
) => {
if (statement.source.expressionType == "StructGetter") {
this.withStructGetter(ctx, statement.source, typeSystem);
}
if (statement.source.expressionType == "VariableUsage") {
this.withVariableUsage(ctx, statement.source, typeSystem);
}
this.withExpression(ctx, statement.expression, typeSystem);
typeSystem.compare({
left: statement.source.type,
operation: { operation: "equals" },
right: statement.expression.type,
});
};
withLetStatement = (ctx: Context, statement: LetStatement, typeSystem: TypeSystem) => {
this.withExpression(ctx, statement.expression, typeSystem);
typeSystem.compare({
left: statement.type,
operation: { operation: "equals" },
right: statement.expression.type,
});
};
withReturnStatement = (ctx: Context, statement: ReturnStatement, typeSystem: TypeSystem) => {
this.withExpression(ctx, statement.source, typeSystem);
if (ctx.currentFunctionReturn) {
typeSystem.compare({
left: ctx.currentFunctionReturn,
operation: { operation: "equals" },
right: statement.source.type,
});
}
};
withExpression = (ctx: Context, expression: Expression, typeSystem: TypeSystem) => {
if (expression.subExpression.expressionType === "LiteralInt") {
// LiteralInt always has type
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "LiteralFloat") {
// LiteralFloat always has type
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "LiteralString") {
// LiteralString always has type
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "LiteralBool") {
// LiteralBool always has type
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "LiteralStruct") {
this.withLiteralStruct(ctx, expression.subExpression, typeSystem);
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "FunctionCall") {
this.withFunctionCall(ctx, expression.subExpression, typeSystem);
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "VariableUsage") {
this.withVariableUsage(ctx, expression.subExpression, typeSystem);
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "IfExpression") {
this.withIfExpression(ctx, expression.subExpression, typeSystem);
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "StructGetter") {
this.withStructGetter(ctx, expression.subExpression, typeSystem);
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "Block") {
this.withBlock(ctx, expression.subExpression, typeSystem);
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
if (expression.subExpression.expressionType === "Operation") {
this.withOperation(ctx, expression.subExpression, typeSystem);
typeSystem.compare({
left: expression.type,
operation: { operation: "equals" },
right: expression.subExpression.type,
});
}
};
withLiteralStruct = (ctx: Context, literal: LiteralStruct, typeSystem: TypeSystem) => {
const definition = ctx.environment[literal.name.text];
if (!definition || definition.namedEntity !== "NamedType" || !(definition.isA === "Struct")) {
throw new Error(`${literal.name.text} not found.`);
}
if (Object.keys(definition.fields).length !== literal.fields.length) {
throw new Error(`${literal.name.text} has mismatched fields.`);
}
if (new Set(Object.keys(definition.fields)).size !== literal.fields.length) {
throw new Error(`${literal.name.text} has repeated fields.`);
}
for (const field of literal.fields) {
const definitionField = definition.fields[field.name.text];
if (!definitionField) throw new Error(`Unknown field ${field.name.text}`);
this.withExpression(ctx, field.expression, typeSystem);
typeSystem.compare({
left: definitionField,
operation: { operation: "equals" },
right: field.expression.type,
});
}
typeSystem.compare({
left: { typeUsage: "NamedTypeUsage", name: literal.name },
operation: { operation: "equals" },
right: literal.type,
});
};
withFunctionCall = (ctx: Context, fnCall: FunctionCall, typeSystem: TypeSystem) => {
this.withExpression(ctx, fnCall.source, typeSystem);
typeSystem.compare({
left: fnCall.source.type,
operation: { operation: "return" },
right: fnCall.type,
});
typeSystem.compare({
left: fnCall.source.type,
operation: { operation: "equals" },
right: {
typeUsage: "FunctionTypeUsage",
arguments: fnCall.arguments.map((arg) => arg.type),
returnType: fnCall.type,
},
});
for (const [i, arg] of fnCall.arguments.entries()) {
this.withExpression(ctx, arg, typeSystem);
typeSystem.compare({
left: fnCall.source.type,
operation: { operation: "argument", argNum: i },
right: arg.type,
});
}
};
withVariableUsage = (ctx: Context, usage: VariableUsage, typeSystem: TypeSystem) => {
const variable = ctx.environment[usage.name.text];
if (!variable || variable.namedEntity === "NamedType") {
throw new Error(`${usage.name.text} not found.`);
}
typeSystem.compare({
left: variable.type,
operation: { operation: "equals" },
right: usage.type,
});
};
withIfExpression = (ctx: Context, ifExpression: IfExpression, typeSystem: TypeSystem) => {
this.withExpression(ctx, ifExpression.condition, typeSystem);
typeSystem.compare({
left: { typeUsage: "NamedTypeUsage", name: { text: "bool", spanStart: 0, spanEnd: 0 } },
operation: { operation: "equals" },
right: ifExpression.condition.type,
});
this.withBlock(ctx, ifExpression.block, typeSystem);
typeSystem.compare({
left: ifExpression.block.type,
operation: { operation: "equals" },
right: ifExpression.type,
});
if (ifExpression.else) {
this.withBlock(ctx, ifExpression.else, typeSystem);
typeSystem.compare({
left: ifExpression.else.type,
operation: { operation: "equals" },
right: ifExpression.type,
});
}
};
withStructGetter = (ctx: Context, structGetter: StructGetter, typeSystem: TypeSystem) => {
this.withExpression(ctx, structGetter.source, typeSystem);
typeSystem.compare({
left: structGetter.source.type,
operation: { operation: "field", name: structGetter.attribute.text },
right: structGetter.type,
});
};
withOperation = (ctx: Context, op: Operation, typeSystem: TypeSystem) => {
this.withExpression(ctx, op.left, typeSystem);
this.withExpression(ctx, op.right, typeSystem);
typeSystem.compare({
left: op.left.type,
operation: { operation: "equals" },
right: op.type,
});
typeSystem.compare({
left: op.right.type,
operation: { operation: "equals" },
right: op.type,
});
};
} }

View File

@@ -1,12 +1,20 @@
import { TypeUsage } from "../parse/ast"; import { TypeUsage, UnknownTypeUsage } from "../parse/ast";
import { newContext } from "./builtins";
import { Context, getAttr } from "./context";
export const compareTypes = (typeA: TypeUsage, typeB: TypeUsage) => { export const compareTypes = (typeA: TypeUsage, typeB: TypeUsage) => {
if (typeA.typeUsage !== typeB.typeUsage) { if (typeA.typeUsage !== typeB.typeUsage) {
throw Error(`Mismatched types: ${typeA.typeUsage} ${typeB.typeUsage}`); throw Error(`Mismatched types: ${typeA.typeUsage} ${typeB.typeUsage}`);
} }
if (typeA.typeUsage == "NamedTypeUsage" && typeB.typeUsage == "NamedTypeUsage") { if (typeA.typeUsage == "NamedTypeUsage" && typeB.typeUsage == "NamedTypeUsage") {
if (typeB.name.text === "Never") {
// never matches with everything
return;
}
if (typeA.name.text !== typeB.name.text) { if (typeA.name.text !== typeB.name.text) {
throw Error(`Mismatched types: ${typeA.name.text} ${typeB.name.text}`); throw Error(
`Mismatched types: ${typeA.name.text}:${typeA.name.spanStart}:${typeA.name.spanEnd} ${typeB.name.text}:${typeB.name.spanStart}:${typeB.name.spanEnd}`,
);
} }
} }
if (typeA.typeUsage == "FunctionTypeUsage" && typeB.typeUsage == "FunctionTypeUsage") { if (typeA.typeUsage == "FunctionTypeUsage" && typeB.typeUsage == "FunctionTypeUsage") {
@@ -33,19 +41,26 @@ interface ReturnValue {
operation: "return"; operation: "return";
} }
interface Field {
operation: "field";
name: string;
}
interface Comparison { interface Comparison {
left: TypeUsage; left: TypeUsage;
operation: Equals | Argument | ReturnValue; operation: Equals | Argument | ReturnValue | Field;
right: TypeUsage; right: TypeUsage;
} }
export class TypeSystem { export class TypeSystem {
comparisons: Comparison[]; comparisons: Comparison[];
result: Record<string, TypeUsage>; result: Record<string, TypeUsage>;
context: Context;
constructor() { constructor() {
this.comparisons = []; this.comparisons = [];
this.result = {}; this.result = {};
this.context = newContext();
} }
compare = (comparison: Comparison) => { compare = (comparison: Comparison) => {
@@ -56,6 +71,7 @@ export class TypeSystem {
let foundUpdate = false; let foundUpdate = false;
let containsUnknown = false; let containsUnknown = false;
while (true) { while (true) {
foundUpdate = false;
for (const comparison of this.comparisons) { for (const comparison of this.comparisons) {
// if already found, just update // if already found, just update
comparison.left = this.resolveType(comparison.left); comparison.left = this.resolveType(comparison.left);
@@ -152,6 +168,37 @@ export class TypeSystem {
compareTypes(comparison.left.returnType, comparison.right); compareTypes(comparison.left.returnType, comparison.right);
} }
} }
// field
if (comparison.operation.operation === "field") {
if (comparison.left.typeUsage === "UnknownTypeUsage") {
// cannot yet be resolved
continue;
}
if (comparison.left.typeUsage !== "NamedTypeUsage") {
throw Error("field on something that isn't a named type.");
}
// cannot be solved left
// solve right
if (comparison.right.typeUsage === "UnknownTypeUsage") {
foundUpdate = true;
const attrType = getAttr(
this.context,
comparison.left.name.text,
comparison.operation.name,
);
this.result[comparison.right.name] = attrType;
comparison.right = attrType;
}
// check
if (comparison.right.typeUsage !== "UnknownTypeUsage") {
const attrType = getAttr(
this.context,
comparison.left.name.text,
comparison.operation.name,
);
compareTypes(attrType, comparison.right);
}
}
if ( if (
comparison.left.typeUsage === "UnknownTypeUsage" || comparison.left.typeUsage === "UnknownTypeUsage" ||
comparison.right.typeUsage === "UnknownTypeUsage" comparison.right.typeUsage === "UnknownTypeUsage"