/**
 * Trigo Tree Agent - AI agent using tree attention for efficient move evaluation
 *
 * Uses evaluation mode ONNX model to score all valid moves in parallel.
 * Organizes moves as a prefix tree where branches with same head token are merged.
 */

import { ModelInferencer } from "./modelInferencer";
import type { EvaluationInputs } from "./modelInferencer";
import { TrigoGame, StoneType } from "./trigo/game";
import type { Move, Stone, Position } from "./trigo/types";
import { encodeAb0yz } from "./trigo/ab0yz";

export interface ScoredMove {
	move: Move;
	score: number; // Log probability summed over every character of `notation`
	notation: string; // TGN notation (e.g., "ab0")
}

export class TrigoTreeAgent {
	private inferencer: ModelInferencer;

	constructor(inferencer: ModelInferencer) {
		this.inferencer = inferencer;
	}

	/**
	 * Convert Stone type to player string.
	 *
	 * @throws Error when the stone is neither BLACK nor WHITE.
	 */
	private stoneToPlayer(stone: Stone): "black" | "white" {
		if (stone === StoneType.BLACK) return "black";
		if (stone === StoneType.WHITE) return "white";
		throw new Error(`Invalid stone type: ${stone}`);
	}

	/**
	 * Encode a position to TGN notation via `encodeAb0yz`
	 * (3 characters for a 5×5×5 board).
	 */
	private positionToTGN(pos: Position, shape: { x: number; y: number; z: number }): string {
		return encodeAb0yz([pos.x, pos.y, pos.z], [shape.x, shape.y, shape.z]);
	}

	/**
	 * Convert a string to byte tokens: one token per UTF-16 code unit
	 * (identical to the ASCII byte for ASCII-only text).
	 */
	private stringToTokens(str: string): number[] {
		const tokens: number[] = [];
		for (let i = 0; i < str.length; i++) {
			tokens.push(str.charCodeAt(i));
		}
		return tokens;
	}

	/**
	 * Build prefix tree from token arrays using recursive merging.
	 * Merges branches with the same token at EVERY level.
	 *
	 * Algorithm:
	 * 1. Group sequences by their first token
	 * 2. For each group:
	 *    - Create one node for the shared first token
	 *    - Extract remaining tokens (residues)
	 *    - Recursively build subtree from residues
	 * 3. Combine all subtrees and build attention mask
	 *
	 * Example for ["aa", "ab", "ba", "bb"]:
	 *   Level 1: Group by first token → 'a': ["a","b"], 'b': ["a","b"]
	 *   Level 2: Within 'a' group, build subtree for ["a","b"]
	 *            Within 'b' group, build subtree for ["a","b"]
	 *   Result: Two branches, each with properly merged second-level nodes
	 *
	 * Node positions are assigned in pre-order (a node before its subtree, a
	 * subtree before later siblings), so every ancestor has a smaller position
	 * than its descendants. Callers rely on that ordering.
	 *
	 * @param tokenArrays - Array of token arrays (one per move)
	 * @returns
	 *   - evaluatedIds: flattened node tokens (length m), indexed by position
	 *   - mask: row-major m×m matrix; mask[i * m + j] = 1 iff j is i or an ancestor of i
	 *   - moveToLeafPos: position of the node where each sequence ends,
	 *     or -1 for an empty sequence (which gets no node)
	 */
	private buildPrefixTree(tokenArrays: number[][]): {
		evaluatedIds: number[];
		mask: number[];
		moveToLeafPos: number[];
	} {
		type Seq = { moveIndex: number; tokens: number[] };
		interface Node {
			token: number;
			pos: number;
			parent: number | null;
			children: Node[];
			moveEnds: number[];
		}

		let nextPos = 0;

		// --- Build prefix tree through recursive grouping ---
		function build(seqs: Seq[], parent: number | null): Node[] {
			// Group by head token. Map keeps insertion order, so siblings are
			// numbered in order of first appearance.
			const groups = new Map<number, Seq[]>();
			for (const s of seqs) {
				const head = s.tokens[0];
				if (head === undefined) continue; // empty sequence: nothing to place
				const group = groups.get(head);
				if (group) group.push(s);
				else groups.set(head, [s]);
			}

			const levelNodes: Node[] = [];
			for (const [token, group] of groups) {
				const pos = nextPos++;
				const node: Node = { token, pos, parent, children: [], moveEnds: [] };

				// Sequences that end here are recorded on this node; the rest
				// continue into the subtree with their head token stripped.
				const residues: Seq[] = [];
				for (const g of group) {
					if (g.tokens.length === 1) node.moveEnds.push(g.moveIndex);
					else residues.push({ moveIndex: g.moveIndex, tokens: g.tokens.slice(1) });
				}

				if (residues.length > 0) {
					node.children = build(residues, pos);
				}
				levelNodes.push(node);
			}
			return levelNodes;
		}

		const roots = build(
			tokenArrays.map((tokens, moveIndex) => ({ moveIndex, tokens })),
			null
		);
		const total = nextPos;

		// --- Flatten tree ---
		const evaluatedIds = new Array<number>(total).fill(0);
		const parents = new Array<number | null>(total).fill(null);
		const moveToLeafPos = new Array<number>(tokenArrays.length).fill(-1);

		function dfs(n: Node): void {
			evaluatedIds[n.pos] = n.token;
			parents[n.pos] = n.parent;
			for (const m of n.moveEnds) moveToLeafPos[m] = n.pos;
			for (const c of n.children) dfs(c);
		}
		for (const r of roots) dfs(r);

		// --- Build ancestor mask (each node attends to itself and its ancestors) ---
		const mask = new Array<number>(total * total).fill(0);
		for (let i = 0; i < total; i++) {
			for (let p: number | null = i; p !== null; p = parents[p] ?? null) {
				mask[i * total + p] = 1;
			}
		}

		return { evaluatedIds, mask, moveToLeafPos };
	}

	/**
	 * Build tree structure for all given moves.
	 * Returns prefix tokens and tree structure for batch evaluation.
	 *
	 * Each move's notation minus its last character is inserted into the tree;
	 * the last character is predicted from the leaf node's output.
	 *
	 * `moveData[i]` describes `moves[i]`:
	 *   - leafPos: tree position of the move's last inserted token (-1 if none)
	 *   - parentPos: tree position of the root (first-token node) of the move's
	 *     path, or -1 when the leaf is itself the root or has no node
	 *
	 * @throws Error when a non-pass move is missing x, y or z.
	 */
	private buildMoveTree(
		game: TrigoGame,
		moves: Move[]
	): {
		prefixTokens: number[];
		evaluatedIds: number[];
		mask: number[];
		moveData: Array<{ move: Move; notation: string; leafPos: number; parentPos: number }>;
	} {
		// Current TGN (trimmed) is the context that precedes the next move.
		const currentTGN = game.toTGN().trim();
		const lines = currentTGN.split("\n");
		const lastLine = lines[lines.length - 1] ?? "";

		let prefix: string;
		if (lastLine.trim() === "") {
			// Only reachable when the trimmed TGN is empty: start a move number for Black.
			const moveMatches = currentTGN.match(/\d+\.\s/g);
			const moveNumber = moveMatches ? moveMatches.length + 1 : 1;
			const isBlackTurn = game.getCurrentPlayer() === StoneType.BLACK;
			prefix = isBlackTurn ? currentTGN + `${moveNumber}. ` : currentTGN + " ";
		} else {
			// Last line is a move line ("3." or "3. ab0 ...") or other content: add a space.
			// NOTE(review): a completed "N. <black> <white>" line also lands here, so no
			// "N+1." is appended; whether that matches training data depends on the
			// layout produced by toTGN() — confirm.
			prefix = currentTGN + " ";
		}

		const prefixTokens = this.stringToTokens(prefix);

		// Encode each move; keep all but the last character as tree tokens.
		const shape = game.getShape();
		const movesWithTokens = moves.map((move) => {
			let notation: string;
			if (move.isPass) {
				notation = "Pass";
			} else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
				notation = this.positionToTGN({ x: move.x, y: move.y, z: move.z }, shape);
			} else {
				throw new Error("Invalid move: missing coordinates");
			}
			const tokens = this.stringToTokens(notation).slice(0, -1);
			return { move, notation, tokens };
		});

		const { evaluatedIds, mask, moveToLeafPos } = this.buildPrefixTree(
			movesWithTokens.map((m) => m.tokens)
		);
		const width = evaluatedIds.length;

		const moveData = movesWithTokens.map((m, index) => {
			const leafPos = moveToLeafPos[index] ?? -1;

			// Positions are pre-order, so the root of the leaf's path is the
			// lowest-numbered position the leaf's mask row can see.
			let parentPos = -1;
			for (let i = 0; i < leafPos; i++) {
				if (mask[leafPos * width + i] === 1) {
					parentPos = i;
					break;
				}
			}

			return { move: m.move, notation: m.notation, leafPos, parentPos };
		});

		return { prefixTokens, evaluatedIds, mask, moveData };
	}

	/**
	 * Get tree structure for visualization (public method).
	 */
	getTreeStructure(
		game: TrigoGame,
		moves: Move[]
	): {
		evaluatedIds: number[];
		mask: number[];
		moveData: Array<{ move: Move; notation: string; leafPos: number; parentPos: number }>;
	} {
		return this.buildMoveTree(game, moves);
	}

	/**
	 * Select best move using tree attention.
	 * Evaluates all valid moves (plus Pass) in a single inference call.
	 *
	 * @returns the highest-scoring move (earliest wins ties), or null if nothing was scored
	 * @throws Error when the inferencer is not ready
	 */
	async selectBestMove(game: TrigoGame): Promise<Move | null> {
		if (!this.inferencer.isReady()) {
			throw new Error("Inferencer not initialized");
		}

		const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());

		const validMoves: Move[] = game.validMovePositions().map((pos) => ({
			x: pos.x,
			y: pos.y,
			z: pos.z,
			player: currentPlayer
		}));
		// Pass is always appended, so the candidate list is never empty.
		validMoves.push({ player: currentPlayer, isPass: true });

		const scoredMoves = await this.scoreMoves(game, validMoves);

		// Linear max; strict ">" keeps the earliest of equal scores.
		let best: ScoredMove | undefined;
		for (const candidate of scoredMoves) {
			if (best === undefined || candidate.score > best.score) {
				best = candidate;
			}
		}
		return best === undefined ? null : best.move;
	}

	/**
	 * Score all moves using tree attention (batch evaluation).
	 *
	 * @returns one ScoredMove per input move, in input order; `score` is the sum
	 *   of log probabilities of every character of the move's notation, with
	 *   -100 substituted for any character whose probability is not positive
	 *   or whose output row is out of range
	 */
	async scoreMoves(game: TrigoGame, moves: Move[]): Promise<ScoredMove[]> {
		if (moves.length === 0) {
			return [];
		}

		const { prefixTokens, evaluatedIds, mask, moveData } = this.buildMoveTree(game, moves);
		const width = evaluatedIds.length;

		console.debug(`Tree structure: ${evaluatedIds.length} nodes for ${moveData.length} moves`);
		console.debug(`Evaluated IDs:`, evaluatedIds.map((id) => String.fromCharCode(id)).join(""));

		const inputs: EvaluationInputs = {
			prefixIds: prefixTokens,
			evaluatedIds: evaluatedIds,
			evaluatedMask: mask
		};

		const { logits, numEvaluated } = await this.inferencer.runEvaluationInference(inputs);
		console.debug(`Inference output: ${numEvaluated} evaluated positions`);

		// Many moves share output rows, so each row's softmax is computed once.
		const softmaxCache = new Map<number, Float32Array>();
		const getSoftmax = (outputPos: number): Float32Array => {
			let probs = softmaxCache.get(outputPos);
			if (probs === undefined) {
				probs = this.inferencer.softmax(logits, outputPos);
				softmaxCache.set(outputPos, probs);
			}
			return probs;
		};

		const scoredMoves: ScoredMove[] = [];
		for (const data of moveData) {
			const { leafPos, notation } = data;

			// Output rows that predict each notation character, root to leaf.
			// Assumes (as this indexing implies) row 0 predicts the first evaluated
			// character and row k + 1 predicts the character after tree node k —
			// confirm against runEvaluationInference. The leaf's mask row marks its
			// ancestors; pre-order numbering makes ascending order root-to-leaf.
			const outputRows: number[] = [0];
			for (let pos = 0; pos <= leafPos; pos++) {
				if (mask[leafPos * width + pos] === 1) {
					outputRows.push(pos + 1);
				}
			}

			let logProb = 0;
			for (let i = 0; i < outputRows.length; i++) {
				const row = outputRows[i];
				const token = notation.charCodeAt(i);

				console.assert(row <= numEvaluated, `Output position ${row} exceeds numEvaluated ${numEvaluated}`);
				if (row <= numEvaluated) {
					const prob = getSoftmax(row)[token];
					logProb += prob > 0 ? Math.log(prob) : -100;
				} else {
					logProb += -100;
				}
			}

			scoredMoves.push({ move: data.move, score: logProb, notation });
		}

		return scoredMoves;
	}
}