// File: trigo/trigo-web/inc/trigoTreeAgent.ts
// Last commit: 502af73 ("updated") by k-l-lambda
/**
* Trigo Tree Agent - AI agent using tree attention for efficient move evaluation
*
* Uses evaluation mode ONNX model to score all valid moves in parallel.
* Organizes moves as a prefix tree where branches with same head token are merged.
*/
import { ModelInferencer } from "./modelInferencer";
import type { EvaluationInputs } from "./modelInferencer";
import { TrigoGame, StoneType } from "./trigo/game";
import type { Move, Stone, Position } from "./trigo/types";
import { encodeAb0yz } from "./trigo/ab0yz";
/**
 * A candidate move together with the score the tree agent assigned to it.
 */
export interface ScoredMove {
/** The move that was scored. */
move: Move;
/** Summed log probability of the move's notation characters (higher is better). */
score: number;
/** TGN notation of the move, e.g. "ab0", or "Pass" for a pass move. */
notation: string;
}
/**
 * AI agent that scores every candidate move with a single tree-attention inference call.
 *
 * Candidate notations are merged into a prefix tree (shared leading characters become shared
 * nodes). The model evaluates the whole tree at once. Each move's score is the summed log
 * probability of the characters of its notation.
 */
export class TrigoTreeAgent {
	private inferencer: ModelInferencer;

	constructor(inferencer: ModelInferencer) {
		this.inferencer = inferencer;
	}

	/**
	 * Convert a Stone value to the player string stored in `Move.player`.
	 *
	 * @throws Error when the stone is neither `StoneType.BLACK` nor `StoneType.WHITE`.
	 */
	private stoneToPlayer(stone: Stone): "black" | "white" {
		if (stone === StoneType.BLACK) return "black";
		if (stone === StoneType.WHITE) return "white";
		throw new Error(`Invalid stone type: ${stone}`);
	}

	/**
	 * Encode a board position to TGN (ab0yz) notation, e.g. "ab0".
	 */
	private positionToTGN(pos: Position, shape: { x: number; y: number; z: number }): string {
		return encodeAb0yz([pos.x, pos.y, pos.z], [shape.x, shape.y, shape.z]);
	}

	/**
	 * Encode a move as TGN text: "Pass" for a pass move, otherwise the ab0yz position code.
	 *
	 * @throws Error if a non-pass move is missing any of its x/y/z coordinates.
	 */
	private moveToNotation(move: Move, shape: { x: number; y: number; z: number }): string {
		if (move.isPass) return "Pass";
		if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
			return this.positionToTGN({ x: move.x, y: move.y, z: move.z }, shape);
		}
		throw new Error("Invalid move: missing coordinates");
	}

	/**
	 * Convert a string to one token per UTF-16 code unit.
	 *
	 * For TGN text these are the ASCII codes.
	 */
	private stringToTokens(str: string): number[] {
		const tokens: number[] = [];
		for (let i = 0; i < str.length; i++) {
			tokens.push(str.charCodeAt(i));
		}
		return tokens;
	}

	/**
	 * Build the text the model is conditioned on: the current TGN plus the separator that
	 * precedes the next move.
	 *
	 * Because the TGN is trimmed, its last line is empty only when the whole TGN is empty.
	 * - In that case a move number ("N. ") is appended on Black's turn, and a single space
	 *   otherwise.
	 * - Any non-empty last line (a bare move number such as "3." or a line of moves) just gets
	 *   a trailing space.
	 *
	 * NOTE(review): the empty-line branch looks intended for a blank line after TGN headers,
	 * which trim() removes. Confirm against the toTGN() output format.
	 */
	private buildTGNPrefix(game: TrigoGame): string {
		const currentTGN = game.toTGN().trim();
		const lines = currentTGN.split("\n");
		const lastLine = lines[lines.length - 1] ?? "";

		if (lastLine.trim() !== "") {
			return currentTGN + " ";
		}

		const moveMatches = currentTGN.match(/\d+\.\s/g);
		const moveNumber = moveMatches ? moveMatches.length + 1 : 1;
		const isBlackTurn = game.getCurrentPlayer() === StoneType.BLACK;
		return isBlackTurn ? currentTGN + `${moveNumber}. ` : currentTGN + " ";
	}

	/**
	 * Build a prefix tree from token arrays, merging equal tokens at every level.
	 *
	 * Sequences are grouped by their first token; each group gets one node. The remaining tokens
	 * of the group are built recursively as that node's children. Positions are assigned in
	 * depth-first pre-order, so every ancestor has a smaller position than its descendants.
	 *
	 * Example for ["aa", "ab", "ba", "bb"]: nodes a(0) → a(1), b(2); b(3) → a(4), b(5).
	 *
	 * @param tokenArrays - One token array per move.
	 * @returns
	 * - `evaluatedIds`: token of each node (length m).
	 * - `mask`: m×m row-major matrix where `mask[i*m + j] === 1` iff node j is node i itself or
	 *   one of its ancestors.
	 * - `moveToLeafPos`: the node where each move's tokens end, or -1 for an empty array.
	 * - `parents`: the parent position of each node (`null` for roots).
	 */
	private buildPrefixTree(tokenArrays: number[][]): {
		evaluatedIds: number[];
		mask: number[];
		moveToLeafPos: number[];
		parents: Array<number | null>;
	} {
		type Seq = { moveIndex: number; tokens: number[] };
		interface TreeNode {
			token: number;
			pos: number;
			parent: number | null;
			children: TreeNode[];
			moveEnds: number[];
		}

		let nextPos = 0;

		const build = (seqs: Seq[], parent: number | null): TreeNode[] => {
			// Group by head token. Map preserves insertion order, so node order is deterministic.
			const groups = new Map<number, Seq[]>();
			for (const s of seqs) {
				// Only top-level sequences can be empty (residues always keep >= 1 token).
				// They get no node, so their leaf position stays -1.
				if (s.tokens.length === 0) continue;
				const head = s.tokens[0];
				let bucket = groups.get(head);
				if (!bucket) {
					bucket = [];
					groups.set(head, bucket);
				}
				bucket.push(s);
			}

			const levelNodes: TreeNode[] = [];
			for (const [token, group] of groups) {
				// Claim the position before recursing so the parent precedes its children.
				const pos = nextPos++;
				const node: TreeNode = { token, pos, parent, children: [], moveEnds: [] };

				const residues: Seq[] = [];
				for (const g of group) {
					if (g.tokens.length === 1) node.moveEnds.push(g.moveIndex);
					else residues.push({ moveIndex: g.moveIndex, tokens: g.tokens.slice(1) });
				}
				if (residues.length > 0) {
					node.children = build(residues, pos);
				}
				levelNodes.push(node);
			}
			return levelNodes;
		};

		const roots = build(
			tokenArrays.map((tokens, moveIndex) => ({ moveIndex, tokens })),
			null
		);
		const total = nextPos;

		// Flatten the tree into position-indexed arrays.
		const evaluatedIds = new Array<number>(total);
		const parents = new Array<number | null>(total).fill(null);
		const moveToLeafPos = new Array<number>(tokenArrays.length).fill(-1);
		const visit = (node: TreeNode): void => {
			evaluatedIds[node.pos] = node.token;
			parents[node.pos] = node.parent;
			for (const m of node.moveEnds) moveToLeafPos[m] = node.pos;
			for (const child of node.children) visit(child);
		};
		for (const root of roots) visit(root);

		// Each node may attend to itself and its ancestors.
		const mask = new Array<number>(total * total).fill(0);
		for (let i = 0; i < total; i++) {
			for (let p: number | null = i; p !== null; p = parents[p] ?? null) {
				mask[i * total + p] = 1;
			}
		}

		return { evaluatedIds, mask, moveToLeafPos, parents };
	}

	/**
	 * Locate the node reported as a move's `parentPos`.
	 *
	 * Returns the deepest strict ancestor of `leafPos` whose token equals the move's first token.
	 * For two-token branches this is the branch root. Returns -1 when the leaf is itself a root
	 * or the move has no node in the tree.
	 */
	private findBranchAnchor(
		leafPos: number,
		firstToken: number | undefined,
		evaluatedIds: number[],
		parents: Array<number | null>
	): number {
		if (leafPos < 0) return -1;
		// Ancestor positions decrease while walking up, so the first match is the deepest one.
		for (let p: number | null = parents[leafPos] ?? null; p !== null; p = parents[p] ?? null) {
			if (evaluatedIds[p] === firstToken) return p;
		}
		return -1;
	}

	/**
	 * Build the prompt tokens and the prefix tree for a batch of candidate moves.
	 *
	 * Every character of each move's notation except the last becomes a tree node. The last
	 * character is predicted by the output at the move's leaf node, so it needs no node of its
	 * own. For example, "ab0" contributes "ab" and "Pass" contributes "Pas".
	 *
	 * @throws Error if a non-pass move is missing coordinates.
	 */
	private buildMoveTree(
		game: TrigoGame,
		moves: Move[]
	): {
		prefixTokens: number[];
		evaluatedIds: number[];
		mask: number[];
		parents: Array<number | null>;
		moveData: Array<{ move: Move; notation: string; leafPos: number; parentPos: number }>;
	} {
		const prefixTokens = this.stringToTokens(this.buildTGNPrefix(game));

		const shape = game.getShape();
		const movesWithTokens = moves.map((move) => {
			const notation = this.moveToNotation(move, shape);
			const tokens = this.stringToTokens(notation).slice(0, -1);
			return { move, notation, tokens };
		});

		const { evaluatedIds, mask, moveToLeafPos, parents } = this.buildPrefixTree(
			movesWithTokens.map((m) => m.tokens)
		);

		const moveData = movesWithTokens.map((m, index) => {
			const leafPos = moveToLeafPos[index];
			return {
				move: m.move,
				notation: m.notation,
				leafPos,
				parentPos: this.findBranchAnchor(leafPos, m.tokens[0], evaluatedIds, parents)
			};
		});

		return { prefixTokens, evaluatedIds, mask, parents, moveData };
	}

	/**
	 * Get the tree structure built for `moves` (public, for visualization).
	 */
	getTreeStructure(
		game: TrigoGame,
		moves: Move[]
	): {
		evaluatedIds: number[];
		mask: number[];
		moveData: Array<{ move: Move; notation: string; leafPos: number; parentPos: number }>;
	} {
		return this.buildMoveTree(game, moves);
	}

	/**
	 * Select the highest-scoring move for the current player.
	 *
	 * The candidates are every position from `game.validMovePositions()` plus a pass move.
	 * They are all evaluated in a single inference call.
	 *
	 * @returns The best move; ties go to the earliest candidate. Returns null only if scoring
	 *   produced no results.
	 * @throws Error if the inferencer is not ready.
	 */
	async selectBestMove(game: TrigoGame): Promise<Move | null> {
		if (!this.inferencer.isReady()) {
			throw new Error("Inferencer not initialized");
		}

		const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());

		// Pass is always appended, so the candidate list is never empty.
		const candidates: Move[] = game.validMovePositions().map((pos) => ({
			x: pos.x,
			y: pos.y,
			z: pos.z,
			player: currentPlayer
		}));
		candidates.push({ player: currentPlayer, isPass: true });

		const scoredMoves = await this.scoreMoves(game, candidates);
		if (scoredMoves.length === 0) {
			return null;
		}

		// Keep the first maximum (equivalent to taking the head of a stable descending sort).
		const best = scoredMoves.reduce((top, m) => (m.score > top.score ? m : top));
		return best.move;
	}

	/**
	 * List the output positions whose distributions predict each character of a move.
	 *
	 * This agent treats output 0 as predicting the first character from the prompt prefix.
	 * Output `k + 1` predicts the character that follows tree node `k`. The result is therefore
	 * `[0, a1 + 1, …, leaf + 1]` for the root-to-leaf chain `a1 … leaf`. It is just `[0]` when
	 * the move has no tree node (`leafPos === -1`).
	 */
	private outputPath(leafPos: number, parents: Array<number | null>): number[] {
		const chain: number[] = [];
		for (let p: number | null = leafPos >= 0 ? leafPos : null; p !== null; p = parents[p] ?? null) {
			chain.push(p + 1);
		}
		chain.reverse();
		return [0, ...chain];
	}

	/**
	 * Score moves with one batched tree-attention inference call.
	 *
	 * A move's score is the sum over its notation characters of log P(char | context). A value of
	 * -100 is used whenever the probability is zero, missing, or its output position is out of
	 * range.
	 *
	 * @returns One entry per input move, in input order. Returns [] when `moves` is empty.
	 */
	async scoreMoves(game: TrigoGame, moves: Move[]): Promise<ScoredMove[]> {
		if (moves.length === 0) {
			return [];
		}

		const { prefixTokens, evaluatedIds, mask, parents, moveData } = this.buildMoveTree(game, moves);
		console.debug(`Tree structure: ${evaluatedIds.length} nodes for ${moveData.length} moves`);
		console.debug(`Evaluated IDs:`, evaluatedIds.map((id) => String.fromCharCode(id)).join(""));

		const inputs: EvaluationInputs = {
			prefixIds: prefixTokens,
			evaluatedIds: evaluatedIds,
			evaluatedMask: mask
		};

		const { logits, numEvaluated } = await this.inferencer.runEvaluationInference(inputs);
		console.debug(`Inference output: ${numEvaluated} evaluated positions`);

		// Shared tree nodes are visited by many moves; compute each softmax only once.
		const softmaxCache = new Map<number, Float32Array>();
		const getSoftmax = (outputPos: number): Float32Array => {
			let probs = softmaxCache.get(outputPos);
			if (!probs) {
				probs = this.inferencer.softmax(logits, outputPos);
				softmaxCache.set(outputPos, probs);
			}
			return probs;
		};

		const MISSING_LOG_PROB = -100;
		const scoredMoves: ScoredMove[] = [];
		for (const data of moveData) {
			const path = this.outputPath(data.leafPos, parents);
			let logProb = 0;
			for (let i = 0; i < path.length; i++) {
				const outputPos = path[i];
				const token = data.notation.charCodeAt(i);
				console.assert(outputPos <= numEvaluated, `Output position ${outputPos} exceeds numEvaluated ${numEvaluated}`);
				const prob = outputPos <= numEvaluated ? getSoftmax(outputPos)[token] ?? 0 : 0;
				logProb += prob > 0 ? Math.log(prob) : MISSING_LOG_PROB;
			}
			scoredMoves.push({ move: data.move, score: logProb, notation: data.notation });
		}
		return scoredMoves;
	}
}