Spaces:
Sleeping
Sleeping
/**
 * Trigo Tree Agent - AI agent using tree attention for efficient move evaluation.
 *
 * Uses an evaluation-mode ONNX model to score all valid moves in parallel (a single inference call).
 * Organizes the moves as a prefix tree in which branches sharing the same head token are merged.
 */
| import { ModelInferencer } from "./modelInferencer"; | |
| import type { EvaluationInputs } from "./modelInferencer"; | |
| import { TrigoGame, StoneType } from "./trigo/game"; | |
| import type { Move, Stone, Position } from "./trigo/types"; | |
| import { encodeAb0yz } from "./trigo/ab0yz"; | |
/**
 * A candidate move together with the score the model assigned to it.
 */
export interface ScoredMove {
	/** TGN notation of the move (e.g. "ab0"). */
	notation: string;
	/** Log probability of the move's notation under the model. */
	score: number;
	/** The move that was scored. */
	move: Move;
}
/**
 * Move-selection agent that scores every candidate move with a single tree-attention
 * inference call.
 *
 * The notations of all candidate moves are merged into a token prefix tree. The tree is
 * flattened into `evaluatedIds` plus an ancestor attention mask and evaluated after the
 * current game record (the "prefix"). The per-position output distributions are then
 * combined into one log probability per move.
 */
export class TrigoTreeAgent {
	/**
	 * Log-probability contribution used when a token's probability is zero or unavailable.
	 * Using a finite floor instead of `-Infinity` keeps such moves comparable.
	 */
	private static readonly LOG_PROB_FLOOR = -100;

	private readonly inferencer: ModelInferencer;

	constructor(inferencer: ModelInferencer) {
		this.inferencer = inferencer;
	}

	/**
	 * Convert a Stone value to the player string used in `Move` objects.
	 *
	 * @throws Error if the stone is neither black nor white.
	 */
	private stoneToPlayer(stone: Stone): "black" | "white" {
		if (stone === StoneType.BLACK) return "black";
		if (stone === StoneType.WHITE) return "white";
		throw new Error(`Invalid stone type: ${stone}`);
	}

	/**
	 * Encode a board position into its TGN coordinate string via `encodeAb0yz`.
	 */
	private positionToTGN(pos: Position, shape: { x: number; y: number; z: number }): string {
		const posArray = [pos.x, pos.y, pos.z];
		const shapeArray = [shape.x, shape.y, shape.z];
		return encodeAb0yz(posArray, shapeArray);
	}

	/**
	 * Convert a string to token ids, one per UTF-16 code unit.
	 * For ASCII text, each id equals the character's byte value.
	 */
	private stringToTokens(str: string): number[] {
		const tokens: number[] = [];
		for (let i = 0; i < str.length; i++) {
			tokens.push(str.charCodeAt(i));
		}
		return tokens;
	}

	/**
	 * Build a prefix tree from token sequences, merging sequences that share a token at
	 * every level, and flatten it for tree-attention evaluation.
	 *
	 * Algorithm:
	 *  1. Group sequences by their first token.
	 *  2. For each group, create one node for the shared token, strip that token from every
	 *     sequence in the group, and recursively build a subtree from the remainders.
	 *  3. Flatten the forest and build an ancestor attention mask.
	 *
	 * Example for ["aa", "ab", "ba", "bb"]: two root nodes 'a' and 'b', each with children
	 * 'a' and 'b'.
	 *
	 * Nodes are numbered in pre-order (a node before its subtree), so every ancestor of a
	 * node has a smaller position than the node itself.
	 *
	 * @param tokenArrays - One token sequence per move.
	 * @returns
	 *  - `evaluatedIds`: the token of each node, indexed by position (length m).
	 *  - `mask`: row-major m×m matrix; `mask[i*m + j] === 1` iff node j is node i or one of its ancestors.
	 *  - `moveToLeafPos`: the position of the last node of each move's sequence, or -1 for an empty sequence.
	 *  - `parent`: the parent position of each node, or `null` for root nodes.
	 */
	private buildPrefixTree(tokenArrays: number[][]): {
		evaluatedIds: number[];
		mask: number[];
		moveToLeafPos: number[];
		parent: Array<number | null>;
	} {
		type Seq = { moveIndex: number; tokens: number[] };
		interface TreeNode {
			token: number;
			pos: number;
			parent: number | null;
			children: TreeNode[];
			moveEnds: number[];
		}

		let nextPos = 0;

		// --- Build the prefix tree through recursive grouping ---
		const build = (seqs: Seq[], parentPos: number | null): TreeNode[] => {
			// Group sequences by head token; Map keeps first-seen order, which fixes the numbering.
			const groups = new Map<number, Seq[]>();
			for (const seq of seqs) {
				if (seq.tokens.length === 0) continue;
				const head = seq.tokens[0];
				const group = groups.get(head);
				if (group) group.push(seq);
				else groups.set(head, [seq]);
			}

			const levelNodes: TreeNode[] = [];
			for (const [token, group] of groups) {
				// Reserve this node's position before recursing so children get larger ones.
				const pos = nextPos++;
				const moveEnds: number[] = [];
				const residues: Seq[] = [];
				for (const seq of group) {
					if (seq.tokens.length === 1) moveEnds.push(seq.moveIndex);
					else residues.push({ moveIndex: seq.moveIndex, tokens: seq.tokens.slice(1) });
				}
				const children = residues.length > 0 ? build(residues, pos) : [];
				levelNodes.push({ token, pos, parent: parentPos, children, moveEnds });
			}
			return levelNodes;
		};

		const roots = build(
			tokenArrays.map((tokens, moveIndex) => ({ moveIndex, tokens })),
			null
		);
		const total = nextPos;

		// --- Flatten tree ---
		const evaluatedIds = new Array<number>(total);
		const parent = new Array<number | null>(total).fill(null);
		const moveToLeafPos = new Array<number>(tokenArrays.length).fill(-1);
		const visit = (node: TreeNode): void => {
			evaluatedIds[node.pos] = node.token;
			parent[node.pos] = node.parent;
			for (const moveIndex of node.moveEnds) moveToLeafPos[moveIndex] = node.pos;
			node.children.forEach(visit);
		};
		roots.forEach(visit);

		// --- Build ancestor mask: each node attends to itself and all of its ancestors ---
		const mask = new Array<number>(total * total).fill(0);
		for (let i = 0; i < total; i++) {
			for (let p: number | null = i; p !== null; p = parent[p] ?? null) {
				mask[i * total + p] = 1;
			}
		}

		return { evaluatedIds, mask, moveToLeafPos, parent };
	}

	/**
	 * Build the prefix tokens and the flattened move tree for a batch evaluation.
	 *
	 * Only the first `notation.length - 1` tokens of each move are placed in the tree.
	 * The output at the leaf node (or, for one-token notations, at the end of the prefix)
	 * predicts the final token.
	 *
	 * In each `moveData` entry, `leafPos` is the position of the move's last tree node
	 * (-1 if none). `parentPos` is the position of the branch root (the move's first-token
	 * node), or -1 when the leaf is itself the root or there is no leaf.
	 */
	private buildMoveTree(
		game: TrigoGame,
		moves: Move[]
	): {
		prefixTokens: number[];
		evaluatedIds: number[];
		mask: number[];
		moveData: Array<{ move: Move; notation: string; leafPos: number; parentPos: number }>;
	} {
		// The current game record (TGN) is the context that precedes the next move.
		const currentTGN = game.toTGN().trim();
		const lines = currentTGN.split("\n");
		const lastLine = lines[lines.length - 1];

		let prefix: string;
		if (/^\d+\./.test(lastLine)) {
			// The last line is a move line: continue it after a space.
			prefix = currentTGN + " ";
		} else if (lastLine.trim() === "") {
			// currentTGN is trimmed, so this branch is only reached for an empty record.
			const moveMatches = currentTGN.match(/\d+\.\s/g);
			const moveNumber = moveMatches ? moveMatches.length + 1 : 1;
			const isBlackTurn = game.getCurrentPlayer() === StoneType.BLACK;
			prefix = isBlackTurn ? currentTGN + `${moveNumber}. ` : currentTGN + " ";
		} else {
			// Any other last line: separate the next move with a space.
			prefix = currentTGN + " ";
		}
		const prefixTokens = this.stringToTokens(prefix);

		const shape = game.getShape();
		const movesWithTokens = moves.map((move) => {
			let notation: string;
			if (move.isPass) {
				notation = "Pass";
			} else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
				notation = this.positionToTGN({ x: move.x, y: move.y, z: move.z }, shape);
			} else {
				throw new Error("Invalid move: missing coordinates");
			}
			// The last token is scored from the leaf's output, so it is not a tree node.
			const fullTokens = this.stringToTokens(notation);
			const tokens = fullTokens.slice(0, fullTokens.length - 1);
			return { move, notation, tokens };
		});

		const { evaluatedIds, mask, moveToLeafPos, parent } = this.buildPrefixTree(
			movesWithTokens.map((m) => m.tokens)
		);

		const moveData = movesWithTokens.map((m, index) => {
			const leafPos = moveToLeafPos[index];
			// Walk up the parent chain to the branch root (the first-token node).
			let parentPos = -1;
			if (leafPos >= 0) {
				let root = leafPos;
				let up = parent[root];
				while (up != null) {
					root = up;
					up = parent[root];
				}
				if (root !== leafPos) parentPos = root;
			}
			return { move: m.move, notation: m.notation, leafPos, parentPos };
		});

		return { prefixTokens, evaluatedIds, mask, moveData };
	}

	/**
	 * Get the flattened move tree for visualization.
	 * The returned object also carries `prefixTokens` at runtime.
	 */
	getTreeStructure(
		game: TrigoGame,
		moves: Move[]
	): {
		evaluatedIds: number[];
		mask: number[];
		moveData: Array<{ move: Move; notation: string; leafPos: number; parentPos: number }>;
	} {
		return this.buildMoveTree(game, moves);
	}

	/**
	 * Select the highest-scoring move among all valid placements plus Pass, evaluated with
	 * a single tree-attention inference call.
	 *
	 * @returns The best move. Ties go to the earliest candidate; a placement precedes Pass.
	 *   Returns null only if scoring yields no results.
	 * @throws Error if the inferencer is not ready.
	 */
	async selectBestMove(game: TrigoGame): Promise<Move | null> {
		if (!this.inferencer.isReady()) {
			throw new Error("Inferencer not initialized");
		}

		const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());
		const candidates: Move[] = game.validMovePositions().map((pos) => ({
			x: pos.x,
			y: pos.y,
			z: pos.z,
			player: currentPlayer
		}));
		// Passing is always a candidate, so the list is never empty.
		candidates.push({ player: currentPlayer, isPass: true });

		const scoredMoves = await this.scoreMoves(game, candidates);

		let best: ScoredMove | undefined;
		for (const scored of scoredMoves) {
			if (best === undefined || scored.score > best.score) best = scored;
		}
		return best?.move ?? null;
	}

	/**
	 * Score moves by batch evaluation over the move prefix tree.
	 *
	 * Each move's score is the sum of log probabilities of its notation's tokens. Output
	 * position 0 predicts the first token (from the end of the prefix). Output position
	 * `k + 1` predicts the token that follows tree node `k`. A token whose probability is
	 * zero or unavailable, or whose output position exceeds `numEvaluated`, contributes
	 * `LOG_PROB_FLOOR` instead.
	 */
	async scoreMoves(game: TrigoGame, moves: Move[]): Promise<ScoredMove[]> {
		if (moves.length === 0) {
			return [];
		}

		const { prefixTokens, evaluatedIds, mask, moveData } = this.buildMoveTree(game, moves);
		const total = evaluatedIds.length;

		console.debug(`Tree structure: ${evaluatedIds.length} nodes for ${moveData.length} moves`);
		console.debug(`Evaluated IDs:`, evaluatedIds.map((id) => String.fromCharCode(id)).join(""));

		const inputs: EvaluationInputs = {
			prefixIds: prefixTokens,
			evaluatedIds: evaluatedIds,
			evaluatedMask: mask
		};
		const { logits, numEvaluated } = await this.inferencer.runEvaluationInference(inputs);
		console.debug(`Inference output: ${numEvaluated} evaluated positions`);

		// Several moves share output positions (common prefixes), so cache each distribution.
		const softmaxCache = new Map<number, Float32Array>();
		const getSoftmax = (outputPos: number): Float32Array => {
			let probs = softmaxCache.get(outputPos);
			if (probs === undefined) {
				probs = this.inferencer.softmax(logits, outputPos);
				softmaxCache.set(outputPos, probs);
			}
			return probs;
		};

		const scoredMoves: ScoredMove[] = [];
		for (const data of moveData) {
			// Output positions along the path: 0 (end of prefix), then node+1 for each
			// ancestor-or-self of the leaf. Ancestors precede the leaf in pre-order, so scanning
			// 0..leafPos in the leaf's mask row yields them in root-to-leaf order.
			const path = [0];
			for (let pos = 0; pos <= data.leafPos; pos++) {
				if (mask[data.leafPos * total + pos] === 1) path.push(pos + 1);
			}

			let logProb = 0;
			for (const [i, outputPos] of path.entries()) {
				const token = data.notation.charCodeAt(i);
				console.assert(outputPos <= numEvaluated, `Output position ${outputPos} exceeds numEvaluated ${numEvaluated}`);
				if (outputPos <= numEvaluated) {
					// Missing entries (undefined) and NaN both fall through to the floor.
					const prob = getSoftmax(outputPos)[token] ?? 0;
					logProb += prob > 0 ? Math.log(prob) : TrigoTreeAgent.LOG_PROB_FLOOR;
				} else {
					logProb += TrigoTreeAgent.LOG_PROB_FLOOR;
				}
			}

			scoredMoves.push({
				move: data.move,
				score: logProb,
				notation: data.notation
			});
		}

		return scoredMoves;
	}
}