trigo / trigo-web /inc /trigoAgent.ts
k-l-lambda's picture
updated
502af73
/**
* Trigo AI Agent - Language Model-based Move Selection (Frontend/Backend Common)
*
* Platform-agnostic AI agent that accepts ONNX session from platform-specific code.
* No direct dependency on onnxruntime packages - uses dependency injection pattern.
*
* Uses ONNX language model to score and select moves by:
* 1. Getting all valid moves for current position
* 2. Scoring each move by appending it to TGN and computing token probabilities
* 3. Selecting move with highest probability (argmax)
*/
import { ModelInferencer } from "./modelInferencer";
import { TrigoGame, StoneType } from "./trigo/game";
import type { Move, Stone } from "./trigo/types";
/**
* Configuration for the AI agent
*
* All fields are optional.
* NOTE(review): the TrigoAgent class in this file does not read this interface;
* it takes seqLen and vocabSize from inferencer.getConfig(). Confirm where this
* type is actually consumed.
*/
export interface TrigoAgentConfig {
// Presumably the number of entries in the model vocabulary (logits per position) - TODO confirm.
vocabSize?: number;
// Presumably the maximum token sequence length the model accepts - TODO confirm.
seqLen?: number;
// Presumably a sampling temperature; move selection in this file is a plain argmax and never reads it - TODO confirm.
temperature?: number;
}
/**
* Move score result
*
* One entry is produced per candidate move by TrigoAgent.selectBestMove.
*/
export interface MoveScore {
// The candidate move that was scored.
move: Move;
// Math.exp(logProb): the log score converted back to a probability-like value.
score: number;
// Value returned by TrigoAgent.scoreMove: the summed log-probability of the
// move's tokens, or a large negative sentinel when the move could not be scored.
logProb: number;
}
/**
 * Trigo AI Agent for move generation.
 *
 * Scores every candidate move with a language model (via the injected
 * ModelInferencer) and picks the one whose TGN text is most probable.
 * Compatible with both frontend (onnxruntime-web) and backend (onnxruntime-node).
 */
export class TrigoAgent {
	/** Score returned by scoreMove when a move is malformed or rejected by the game. */
	private static readonly INVALID_MOVE_LOG_PROB = -1000;

	/**
	 * Score returned when no move text can be found at the end of the TGN, and
	 * the penalty added for a single token whose probability is not positive.
	 */
	private static readonly FALLBACK_LOG_PROB = -100;

	private readonly inferencer: ModelInferencer;

	/**
	 * @param inferencer Inferencer used for all model calls; the agent calls
	 *   getConfig(), runInference() and destroy() on it.
	 */
	constructor(inferencer: ModelInferencer) {
		this.inferencer = inferencer;
	}

	/**
	 * Check whether the agent has an inferencer to work with.
	 *
	 * This only verifies that an inferencer reference was supplied (neither
	 * null nor undefined); it does not inspect the inferencer's internal state.
	 */
	isInitialized(): boolean {
		// `!= null` deliberately rejects both null and undefined.
		return this.inferencer != null;
	}

	/**
	 * Convert Stone type to player string.
	 *
	 * @throws Error if the stone is neither StoneType.BLACK nor StoneType.WHITE.
	 */
	private stoneToPlayer(stone: Stone): "black" | "white" {
		if (stone === StoneType.BLACK) return "black";
		if (stone === StoneType.WHITE) return "white";
		throw new Error(`Invalid stone type: ${stone}`);
	}

	/**
	 * Convert a string to token IDs, one per character.
	 *
	 * Each ID is the first UTF-16 code unit of the character, which equals the
	 * byte value only for ASCII text.
	 * NOTE(review): presumably TGN is ASCII-only, so this matches a byte-level
	 * vocabulary - confirm against the model's tokenizer.
	 */
	private stringToTokens(text: string): number[] {
		return Array.from(text).map((char) => char.charCodeAt(0));
	}

	/**
	 * Compute softmax probabilities from logits.
	 *
	 * @param logits Logits for one position; the first vocabSize entries are read.
	 * @param vocabSize Number of entries to normalize over.
	 * @returns A new array of vocabSize probabilities.
	 */
	private softmax(logits: Float32Array, vocabSize: number): Float32Array {
		const probs = new Float32Array(vocabSize);
		let maxLogit = -Infinity;

		// Find max for numerical stability
		for (let i = 0; i < vocabSize; i++) {
			if (logits[i] > maxLogit) {
				maxLogit = logits[i];
			}
		}

		// Compute exp and sum
		let sum = 0;
		for (let i = 0; i < vocabSize; i++) {
			probs[i] = Math.exp(logits[i] - maxLogit);
			sum += probs[i];
		}

		// Normalize
		for (let i = 0; i < vocabSize; i++) {
			probs[i] /= sum;
		}

		return probs;
	}

	/**
	 * Score a candidate move by computing token probabilities.
	 *
	 * Clones the game, applies the move to the clone, renders the clone as TGN
	 * and sums the model's log-probabilities of the characters that make up the
	 * trailing move text. The game passed in is only cloned, never played on.
	 *
	 * @param game Position to evaluate the move in.
	 * @param move Candidate move: either a pass or a move with x, y and z set.
	 * @returns Summed log-probability of the move tokens; -1000 if the move is
	 *   malformed or rejected by the game; -100 if no move text could be found
	 *   at the end of the TGN.
	 */
	async scoreMove(game: TrigoGame, move: Move): Promise<number> {
		const clonedGame = game.clone();

		// Apply the move to the cloned game
		let success: boolean;
		if (move.isPass) {
			success = clonedGame.pass();
		} else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
			success = clonedGame.drop({ x: move.x, y: move.y, z: move.z });
		} else {
			// Neither a pass nor a full coordinate triple
			return TrigoAgent.INVALID_MOVE_LOG_PROB;
		}

		if (!success) {
			// The game rejected the move
			return TrigoAgent.INVALID_MOVE_LOG_PROB;
		}

		// TGN of the position after the move
		const newTGN = clonedGame.toTGN().trim();

		// The move is taken to be the trailing run of move characters in that TGN
		const moveTokens = this.extractMoveTokens(newTGN);
		if (moveTokens.length === 0) {
			return TrigoAgent.FALLBACK_LOG_PROB;
		}

		const tokens = this.stringToTokens(newTGN);

		const config = this.inferencer.getConfig();
		const seqLen = config.seqLen;
		const vocabSize = config.vocabSize;

		// Keep only the last seqLen tokens so the move stays at the end
		if (tokens.length > seqLen) {
			tokens.splice(0, tokens.length - seqLen);
		}

		// Run inference (START_TOKEN will be prepended by inferencer)
		const logits = await this.inferencer.runInference(tokens);

		// The inferencer feeds [START_TOKEN, ...tokens, PAD, ...], so the logits
		// at output position p are the prediction for tokens[p].
		const moveStartInTokens = tokens.length - moveTokens.length;

		let logProb = 0;
		for (let i = 0; i < moveTokens.length; i++) {
			const tokenPos = moveStartInTokens + i;

			// Only reachable when truncation left fewer tokens than the move has
			if (tokenPos < 0 || tokenPos >= tokens.length) continue;

			const offset = tokenPos * vocabSize;
			const tokenLogits = logits.slice(offset, offset + vocabSize);
			const probs = this.softmax(tokenLogits, vocabSize);

			const tokenId = moveTokens[i];
			const prob = probs[tokenId];
			if (prob > 0) {
				logProb += Math.log(prob);
			} else {
				// Zero, NaN or out-of-vocabulary: use a fixed penalty instead of -Infinity
				logProb += TrigoAgent.FALLBACK_LOG_PROB;
			}
		}

		return logProb;
	}

	/**
	 * Extract the tokens of the last move from a TGN string.
	 *
	 * The move is identified as the longest run of characters from the set
	 * `P`, `a`-`z`, `0` at the very end of the string; no comparison with an
	 * earlier TGN is made.
	 *
	 * @returns Token IDs of that trailing run, or an empty array if the string
	 *   does not end with such a character.
	 */
	private extractMoveTokens(tgn: string): number[] {
		const moveCapture = tgn.match(/[Pa-z0]+$/);
		return this.stringToTokens(moveCapture ? moveCapture[0] : "");
	}

	/**
	 * Select the best move using the language model.
	 *
	 * Scores every position returned by game.validMovePositions() plus a pass,
	 * and returns the candidate with the highest log-probability (argmax).
	 *
	 * @returns The highest-scoring move. A pass is always among the candidates,
	 *   so this implementation never resolves to null; the nullable return type
	 *   is kept for compatibility with existing callers.
	 * @throws Error if the agent has no inferencer, or if the current player is
	 *   not a valid stone.
	 */
	async selectBestMove(game: TrigoGame): Promise<Move | null> {
		if (!this.isInitialized()) {
			throw new Error("Agent not initialized. Pass initialized inferencer to constructor.");
		}

		console.log("[TrigoAgent] Selecting move...");

		const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());

		// Candidates: every valid board position, plus a pass
		const validMoves: Move[] = game.validMovePositions().map((pos) => ({
			x: pos.x,
			y: pos.y,
			z: pos.z,
			player: currentPlayer
		}));
		validMoves.push({ player: currentPlayer, isPass: true });

		console.log(`[TrigoAgent] Evaluating ${validMoves.length} valid moves...`);

		// Moves are scored one at a time.
		// NOTE(review): whether the inferencer tolerates concurrent runInference
		// calls is not visible from this file, so the sequential order is kept.
		const scores: MoveScore[] = [];
		for (const move of validMoves) {
			const logProb = await this.scoreMove(game, move);
			scores.push({
				move,
				score: Math.exp(logProb), // Convert log prob to probability
				logProb
			});
		}

		// Best first; the full ordering is also used for the top-5 log below
		scores.sort((a, b) => b.logProb - a.logProb);
		const bestMove = scores[0];

		console.debug("scores:", scores);
		console.log("[TrigoAgent] Best move:", bestMove.move, "score:", bestMove.score.toFixed(6));
		console.log("[TrigoAgent] Top 5 moves:");
		for (let i = 0; i < Math.min(5, scores.length); i++) {
			// Serialize explicitly: interpolating the object would print "[object Object]"
			console.log(` ${i + 1}. ${JSON.stringify(scores[i].move)}: ${scores[i].score.toFixed(6)}`);
		}

		return bestMove.move;
	}

	/**
	 * Clean up resources by destroying the underlying inferencer.
	 */
	destroy(): void {
		this.inferencer.destroy();
		console.log("[TrigoAgent] Destroyed");
	}
}