Spaces:
Running
Running
/**
 * Trigo AI Agent - Language Model-based Move Selection (Frontend/Backend Common)
 *
 * Platform-agnostic AI agent that accepts an ONNX session from platform-specific code.
 * No direct dependency on onnxruntime packages - uses the dependency injection pattern.
 *
 * Uses an ONNX language model to score and select moves by:
 * 1. Getting all valid moves for the current position
 * 2. Scoring each move by appending it to the TGN and computing token probabilities
 * 3. Selecting the move with the highest probability (argmax)
 */
| import { ModelInferencer } from "./modelInferencer"; | |
| import { TrigoGame, StoneType } from "./trigo/game"; | |
| import type { Move, Stone } from "./trigo/types"; | |
/**
 * Configuration for the AI agent.
 *
 * Every field is optional.
 *
 * NOTE(review): nothing in the visible part of this file reads this interface
 * (the agent takes `seqLen` / `vocabSize` from `inferencer.getConfig()`), so the
 * field descriptions below are inferred from the names only - confirm against
 * the code that consumes this type.
 */
export interface TrigoAgentConfig {
	/** Presumably the number of entries in the model vocabulary. */
	vocabSize?: number;
	/** Presumably the maximum token sequence length fed to the model. */
	seqLen?: number;
	/** Presumably a sampling temperature - TODO confirm; unused in this file. */
	temperature?: number;
}
/**
 * Result of scoring a single candidate move.
 */
export interface MoveScore {
	/** The candidate move that was evaluated. */
	move: Move;
	/** Linear-scale score; `TrigoAgent.selectBestMove` stores `Math.exp(logProb)` here. */
	score: number;
	/** Value returned by `TrigoAgent.scoreMove` for this move (used as the sort key). */
	logProb: number;
}
/**
 * Trigo AI agent for move generation.
 *
 * The agent receives a ready-made `ModelInferencer` through its constructor
 * (dependency injection) and never imports an ONNX runtime itself. Each
 * candidate move is scored by applying it to a clone of the game, serialising
 * the clone to TGN, running the inferencer on the TGN characters, and summing
 * the log-probabilities the model assigns to the trailing move characters.
 */
export class TrigoAgent {
	/** Score returned when a move is malformed or rejected by the game. */
	private static readonly INVALID_MOVE_LOG_PROB = -1000;

	/**
	 * Penalty used when a move cannot be located in the TGN, and for every
	 * single token whose log-probability cannot be computed.
	 */
	private static readonly UNSCORABLE_LOG_PROB = -100;

	/**
	 * Matches the trailing run of characters drawn from `P`, `a`-`z` and `0`.
	 * Presumably this covers coordinate letters and the "Pass" notation of
	 * TGN - confirm against the TGN writer.
	 */
	private static readonly MOVE_SUFFIX_PATTERN = /[Pa-z0]+$/;

	private readonly inferencer: ModelInferencer;

	/**
	 * @param inferencer Inference backend created by platform-specific code.
	 */
	constructor(inferencer: ModelInferencer) {
		this.inferencer = inferencer;
	}

	/**
	 * Report whether an inferencer object was supplied.
	 *
	 * This only checks that the reference is present; it does not inspect
	 * whether the inferencer itself holds a loaded session.
	 *
	 * @returns `false` when the inferencer is `null` or `undefined`.
	 */
	isInitialized(): boolean {
		// Loose `!= null` on purpose: a strict `!== null` treats `undefined` as initialized.
		return this.inferencer != null;
	}

	/**
	 * Convert a stone value to the player string used in `Move.player`.
	 *
	 * @throws Error when the stone is neither black nor white.
	 */
	private stoneToPlayer(stone: Stone): "black" | "white" {
		if (stone === StoneType.BLACK) return "black";
		if (stone === StoneType.WHITE) return "white";
		throw new Error(`Invalid stone type: ${stone}`);
	}

	/**
	 * Convert a string to token IDs, one token per character.
	 *
	 * Each token is the first UTF-16 code unit of the character, which equals
	 * the byte value for ASCII text.
	 */
	private stringToTokens(text: string): number[] {
		return Array.from(text, (char) => char.charCodeAt(0));
	}

	/**
	 * Log-probability of `tokenId` under the softmax of one logits row.
	 *
	 * Computed as `logit - max - log(sum(exp(logit_i - max)))` so that very
	 * unlikely tokens keep their real log-probability instead of underflowing
	 * to zero first, and without allocating a probability array per token.
	 *
	 * @param logits Flat logits buffer containing one row per sequence position.
	 * @param offset Index of the first element of the row to read.
	 * @param vocabSize Number of elements in the row.
	 * @param tokenId Index within the row of the token being scored.
	 * @returns The log-probability, or `UNSCORABLE_LOG_PROB` when the row is
	 *          incomplete, the token is outside the row, or the result is not finite.
	 */
	private tokenLogProb(logits: Float32Array, offset: number, vocabSize: number, tokenId: number): number {
		// `subarray` is a view on the same buffer - nothing is copied.
		const row = logits.subarray(offset, offset + vocabSize);
		const targetLogit = row[tokenId];
		if (row.length < vocabSize || targetLogit === undefined) {
			return TrigoAgent.UNSCORABLE_LOG_PROB;
		}

		// Subtract the row maximum before exponentiating for numerical stability.
		let maxLogit = -Infinity;
		for (const value of row) {
			if (value > maxLogit) maxLogit = value;
		}

		let sumExp = 0;
		for (const value of row) {
			sumExp += Math.exp(value - maxLogit);
		}

		const logProb = targetLogit - maxLogit - Math.log(sumExp);
		return Number.isFinite(logProb) ? logProb : TrigoAgent.UNSCORABLE_LOG_PROB;
	}

	/**
	 * Human-readable description of a move for log output.
	 */
	private formatMove(move: Move): string {
		return move.isPass
			? `${move.player} pass`
			: `${move.player} (${move.x}, ${move.y}, ${move.z})`;
	}

	/**
	 * Score a candidate move by computing the log-probability of its tokens.
	 *
	 * The game is cloned, the move is applied to the clone, the clone is
	 * serialised to TGN, and the log-probabilities of the trailing move
	 * characters are summed. The game passed in is not modified by this method.
	 *
	 * @param game Position the move would be played in.
	 * @param move Candidate move (a pass, or a move with `x`, `y` and `z` set).
	 * @returns Summed log-probability of the move tokens; `-1000` when the move
	 *          is malformed or rejected by the game; `-100` when no move text
	 *          can be found at the end of the TGN.
	 */
	async scoreMove(game: TrigoGame, move: Move): Promise<number> {
		const clonedGame = game.clone();

		let applied: boolean;
		if (move.isPass) {
			applied = clonedGame.pass();
		} else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
			applied = clonedGame.drop({ x: move.x, y: move.y, z: move.z });
		} else {
			// Neither a pass nor a complete coordinate triple.
			return TrigoAgent.INVALID_MOVE_LOG_PROB;
		}
		if (!applied) {
			return TrigoAgent.INVALID_MOVE_LOG_PROB;
		}

		// The move just played is taken from the end of the clone's TGN.
		const newTGN = clonedGame.toTGN().trim();
		const moveTokens = this.extractMoveTokens(newTGN);
		if (moveTokens.length === 0) {
			return TrigoAgent.UNSCORABLE_LOG_PROB;
		}

		const { seqLen, vocabSize } = this.inferencer.getConfig();

		// Keep only the most recent `seqLen` tokens so the move stays at the end.
		// NOTE(review): if the inferencer adds a START_TOKEN on top of these,
		// the limit may need to be `seqLen - 1` - confirm against ModelInferencer.
		const allTokens = this.stringToTokens(newTGN);
		const tokens = allTokens.length > seqLen ? allTokens.slice(allTokens.length - seqLen) : allTokens;

		const logits = await this.inferencer.runInference(tokens);

		// Per the original implementation's note (not verifiable from this file),
		// the inferencer prepends a START_TOKEN, so the logits row at index p is
		// the prediction for tokens[p].
		const moveStart = tokens.length - moveTokens.length;
		let logProb = 0;
		for (const [i, tokenId] of moveTokens.entries()) {
			const tokenPos = moveStart + i;
			// Negative positions were cut off by the truncation above.
			if (tokenPos < 0) continue;
			logProb += this.tokenLogProb(logits, tokenPos * vocabSize, vocabSize, tokenId);
		}
		return logProb;
	}

	/**
	 * Extract the tokens of the last move from a TGN string.
	 *
	 * @param tgn TGN text whose final characters are the move of interest.
	 * @returns Tokens of the trailing run matched by `MOVE_SUFFIX_PATTERN`,
	 *          or an empty array when the TGN does not end in such a run.
	 */
	private extractMoveTokens(tgn: string): number[] {
		const moveCapture = tgn.match(TrigoAgent.MOVE_SUFFIX_PATTERN);
		return this.stringToTokens(moveCapture ? moveCapture[0] : "");
	}

	/**
	 * Select the best move using the language model.
	 *
	 * Scores every position returned by `game.validMovePositions()` plus a pass
	 * move, and returns the candidate with the highest log-probability. On a tie
	 * the candidate that was scored first wins.
	 *
	 * @param game Current position.
	 * @returns The best-scoring move. `null` is returned only if no candidate
	 *          was scored, which cannot happen while a pass is always added.
	 * @throws Error when no inferencer was supplied to the constructor.
	 */
	async selectBestMove(game: TrigoGame): Promise<Move | null> {
		if (!this.isInitialized()) {
			throw new Error("Agent not initialized. Pass initialized inferencer to constructor.");
		}
		console.log("[TrigoAgent] Selecting move...");

		const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());

		const validMoves: Move[] = game.validMovePositions().map((pos) => ({
			x: pos.x,
			y: pos.y,
			z: pos.z,
			player: currentPlayer
		}));
		// A pass is always a candidate, so the list is never empty.
		validMoves.push({ player: currentPlayer, isPass: true });

		console.log(`[TrigoAgent] Evaluating ${validMoves.length} valid moves...`);

		// Moves are scored one at a time; each call awaits the inferencer.
		const scores: MoveScore[] = [];
		for (const move of validMoves) {
			const logProb = await this.scoreMove(game, move);
			scores.push({
				move,
				score: Math.exp(logProb), // Convert log prob to probability
				logProb
			});
		}

		// Argmax via a descending sort, which also yields the top-5 listing.
		scores.sort((a, b) => b.logProb - a.logProb);
		const bestMove = scores[0];
		if (bestMove === undefined) {
			console.log("[TrigoAgent] No valid moves available");
			return null;
		}

		console.debug("scores:", scores);
		console.log("[TrigoAgent] Best move:", bestMove.move, "score:", bestMove.score.toFixed(6));
		console.log("[TrigoAgent] Top 5 moves:");
		scores.slice(0, 5).forEach((entry, index) => {
			// Format the move explicitly; interpolating the object prints "[object Object]".
			console.log(` ${index + 1}. ${this.formatMove(entry.move)}: ${entry.score.toFixed(6)}`);
		});

		return bestMove.move;
	}

	/**
	 * Clean up resources by destroying the injected inferencer.
	 */
	destroy(): void {
		this.inferencer.destroy();
		console.log("[TrigoAgent] Destroyed");
	}
}