/**
 * ONNX Model Inferencer (Frontend/Backend Common)
 *
 * Platform-agnostic inference logic that accepts an ONNX session from
 * platform-specific code. No direct dependency on onnxruntime packages -
 * uses a dependency-injection pattern.
 *
 * Adapted from the Node.js test_inference.js for cross-platform use.
 * Provides causal language model inference using the GPT-2 ONNX model.
 */

/**
 * Minimal ONNX tensor interface (platform-agnostic)
 */
export interface OnnxTensor {
  readonly data: number[] | Float32Array | Int32Array | BigInt64Array | Uint8Array;
  readonly dims: readonly number[];
  readonly type: string;
}

/**
 * Minimal ONNX session interface (platform-agnostic)
 */
export interface OnnxSession {
  readonly inputNames: readonly string[];
  readonly outputNames: readonly string[];
  run(feeds: Record<string, OnnxTensor>): Promise<Record<string, OnnxTensor>>;
}

/**
 * Tensor constructor interface (platform-specific)
 */
export interface TensorConstructor {
  new (
    type: string,
    data: BigInt64Array | Float32Array | Int32Array | Uint8Array,
    dims: number[]
  ): OnnxTensor;
}

/**
 * Configuration for the inferencer
 */
export interface InferencerConfig {
  vocabSize: number;
  seqLen: number;
  modelPath?: string; // Optional, for reference
}

/**
 * Inference result containing generated tokens and metadata
 */
export interface InferenceResult {
  tokens: number[];
  text: string;
  logits: Float32Array;
  inferenceTime: number;
}

/**
 * Evaluation mode inputs for tree attention
 */
export interface EvaluationInputs {
  prefixIds: number[]; // Prefix sequence (causal context)
  evaluatedIds: number[]; // Tokens to evaluate in tree
  evaluatedMask: number[]; // Attention mask [m×m] flattened
}

/**
 * Evaluation mode output
 */
export interface EvaluationOutput {
  logits: Float32Array; // [m+1, vocab_size] flattened
  numEvaluated: number; // m
}

/**
 * Model inferencer for a causal language model.
 * Compatible with both frontend (onnxruntime-web) and backend (onnxruntime-node).
 */
export class ModelInferencer {
  private session: OnnxSession | null = null;
  private config: InferencerConfig;
  private TensorClass: TensorConstructor;

  // TGN tokenizer: byte-level (0-255) + PAD(256) + START(257) + END(258)
  private readonly PAD_TOKEN = 256;
  private readonly START_TOKEN = 257;
  private readonly END_TOKEN = 258;

  constructor(TensorClass: TensorConstructor, config: Partial<InferencerConfig> = {}) {
    this.TensorClass = TensorClass;
    this.config = {
      vocabSize: 259,
      seqLen: 256,
      ...config
    };
  }

  /**
   * Set the inference session (created by platform-specific code)
   */
  setSession(session: OnnxSession): void {
    this.session = session;
    console.log("[ModelInferencer] ✓ Session set successfully");
    this.printModelInfo();
  }
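  // Token layout example (illustrative): a prompt is encoded as raw byte
  // values, so "Hi" becomes [72, 105]. runInference() below prepends START
  // and right-pads with PAD up to seqLen:
  //
  //   [257, 72, 105, 256, 256, ..., 256]   // length === this.config.seqLen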
  /**
   * Run a basic inference test
   */
  async testBasicInference(): Promise<InferenceResult> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }

    console.log("[ModelInferencer] Running basic inference test...");

    const batchSize = 1;
    const seqLen = this.config.seqLen;

    // Create random input
    const inputIds = this.createRandomInput(batchSize, seqLen);
    const inputTensor = new this.TensorClass("int64", inputIds, [batchSize, seqLen]);

    // Run inference
    const startTime = performance.now();
    const results = await this.session.run({ input_ids: inputTensor });
    const inferenceTime = performance.now() - startTime;

    // Get logits
    const logits = results.logits;

    // Validate output
    this.validateOutput(logits, batchSize, seqLen);

    // Get predictions
    const predictions = this.getPredictions(logits.data as Float32Array, batchSize * seqLen);

    // Convert tokens to text
    const text = String.fromCharCode(...predictions.slice(0, 100));

    console.log("[ModelInferencer] Inference completed:");
    console.log(`  Input shape: [${inputTensor.dims.join(", ")}]`);
    console.log(`  Output shape: [${logits.dims.join(", ")}]`);
    console.log(`  Output dtype: ${logits.type}`);
    console.log(`  Inference time: ${inferenceTime.toFixed(2)}ms`);
    console.log(`  Sample predictions: [${predictions.slice(0, 10).join(", ")}]`);

    // Scan for the range in a loop: spreading tens of thousands of logits into
    // Math.min()/Math.max() can exceed the engine's argument limit
    const logitsData = logits.data as Float32Array;
    let minLogit = Infinity;
    let maxLogit = -Infinity;
    for (const v of logitsData) {
      if (v < minLogit) minLogit = v;
      if (v > maxLogit) maxLogit = v;
    }
    console.log(`  Logits range: [${minLogit.toFixed(3)}, ${maxLogit.toFixed(3)}]`);

    return {
      tokens: predictions,
      text,
      logits: logits.data as Float32Array,
      inferenceTime
    };
  }

  /**
   * Generate tokens autoregressively from a prompt
   */
  async generateText(prompt: string, numTokens: number = 10): Promise<InferenceResult> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }

    console.log(`[ModelInferencer] Generating ${numTokens} tokens from prompt: "${prompt}"`);

    // Convert prompt to token IDs (byte values)
    const promptTokens = Array.from(prompt).map((c) => c.charCodeAt(0));
    console.log(`  Prompt tokens (${promptTokens.length}): [${promptTokens.join(", ")}]`);

    // Start with prompt tokens
    const sequence = [...promptTokens];
    const times: number[] = [];

    // Generate tokens
    for (let i = 0; i < numTokens; i++) {
      // Pad sequence to fixed length
      const paddedSequence = this.padSequence(sequence, this.config.seqLen);

      // Create input tensor
      const inputIds = new BigInt64Array(paddedSequence.map((t) => BigInt(t)));
      const inputTensor = new this.TensorClass("int64", inputIds, [1, this.config.seqLen]);

      // Run inference
      const startTime = performance.now();
      const results = await this.session.run({ input_ids: inputTensor });
      times.push(performance.now() - startTime);

      // Get prediction at the last non-padded position; clamp so the read
      // stays in range once the sequence fills the context window
      const logits = results.logits.data as Float32Array;
      const lastPos = Math.min(sequence.length, this.config.seqLen) - 1;
      const offset = lastPos * this.config.vocabSize;

      // Find token with highest logit
      let maxIdx = 0;
      let maxVal = logits[offset];
      for (let j = 1; j < this.config.vocabSize; j++) {
        if (logits[offset + j] > maxVal) {
          maxVal = logits[offset + j];
          maxIdx = j;
        }
      }

      sequence.push(maxIdx);

      // Stop if END token is generated
      if (maxIdx === this.END_TOKEN) {
        console.log("  Generated END token, stopping...");
        break;
      }
    }

    // Convert generated tokens to text (drop special tokens, which have no byte value)
    const generatedText = String.fromCharCode(...sequence.filter((t) => t < 256));
    const avgTime = times.reduce((a, b) => a + b, 0) / times.length;

    console.log(`[ModelInferencer] Generation complete:`);
    console.log(`  Generated text: "${generatedText}"`);
    console.log(`  Token sequence (${sequence.length}): [${sequence.join(", ")}]`);
    console.log(`  Avg inference time: ${avgTime.toFixed(2)}ms`);
    console.log(`  Tokens/sec: ${(1000 / avgTime).toFixed(2)}`);

    return {
      tokens: sequence,
      text: generatedText,
      logits: new Float32Array(), // Not returning full logits for generation
      inferenceTime: avgTime
    };
  }
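  // Usage example (illustrative): greedy decoding of up to 20 tokens.
  //
  //   const result = await inferencer.generateText("Hello", 20);
  //   console.log(result.text);          // prompt plus generated characters
  //   console.log(result.inferenceTime); // average ms per forward pass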
  /**
   * Get model information
   */
  getModelInfo(): { inputs: string[]; outputs: string[] } | null {
    if (!this.session) return null;
    return {
      inputs: [...this.session.inputNames],
      outputs: [...this.session.outputNames]
    };
  }

  /**
   * Get configuration
   */
  getConfig(): InferencerConfig {
    return this.config;
  }

  /**
   * Run inference with token array input.
   * Returns raw logits as a Float32Array.
   */
  async runInference(tokens: number[]): Promise<Float32Array> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }

    const seqLen = this.config.seqLen;

    // Prepend START_TOKEN to input
    const tokensWithStart = [this.START_TOKEN, ...tokens];

    // Pad to fixed length
    const paddedTokens = new BigInt64Array(seqLen);
    for (let i = 0; i < seqLen; i++) {
      paddedTokens[i] =
        i < tokensWithStart.length ? BigInt(tokensWithStart[i]) : BigInt(this.PAD_TOKEN);
    }

    // Create input tensor
    const inputTensor = new this.TensorClass("int64", paddedTokens, [1, seqLen]);

    // Run inference
    const results = await this.session.run({ input_ids: inputTensor });
    return results.logits.data as Float32Array;
  }

  /**
   * Run tree attention inference (evaluation mode).
   * For models exported with the --evaluation flag.
   * @param inputs - Prefix, evaluated tokens, and attention mask
   * @returns Logits for each evaluated position
   */
  async runEvaluationInference(inputs: EvaluationInputs): Promise<EvaluationOutput> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }

    const { prefixIds, evaluatedIds, evaluatedMask } = inputs;
    const batchSize = 1;
    const prefixLen = prefixIds.length;
    const m = evaluatedIds.length;

    // Convert to BigInt64Array for ONNX int64 tensors
    const prefixIdsArray = new BigInt64Array(batchSize * prefixLen);
    for (let i = 0; i < prefixLen; i++) {
      prefixIdsArray[i] = BigInt(prefixIds[i]);
    }

    const evaluatedIdsArray = new BigInt64Array(batchSize * m);
    for (let i = 0; i < m; i++) {
      evaluatedIdsArray[i] = BigInt(evaluatedIds[i]);
    }

    // Mask is Float32Array
    const maskArray = new Float32Array(m * m);
    for (let i = 0; i < m * m; i++) {
      maskArray[i] = evaluatedMask[i];
    }

    // Create ONNX tensors
    const prefixIdsTensor = new this.TensorClass("int64", prefixIdsArray, [
      batchSize,
      prefixLen
    ]);
    const evaluatedIdsTensor = new this.TensorClass("int64", evaluatedIdsArray, [batchSize, m]);
    const evaluatedMaskTensor = new this.TensorClass("float32", maskArray, [1, m, m]);

    // Run inference
    const results = await this.session.run({
      prefix_ids: prefixIdsTensor,
      evaluated_ids: evaluatedIdsTensor,
      evaluated_mask: evaluatedMaskTensor
    });

    // Extract logits. Output shape: [batch, m+1, vocab_size]; we return the
    // flattened array plus numEvaluated for reshaping
    const logits = results.logits.data as Float32Array;
    return {
      logits,
      numEvaluated: m
    };
  }

  /**
   * Compute softmax for a single position's logits
   * @param logits - Full logits array
   * @param position - Which evaluated position (0 = last prefix, 1-m = evaluated tokens)
   * @returns Probability distribution over vocabulary
   */
  softmax(logits: Float32Array, position: number): Float32Array {
    const vocabSize = this.config.vocabSize;
    const offset = position * vocabSize;
    const probs = new Float32Array(vocabSize);

    // Find max for numerical stability
    let maxLogit = -Infinity;
    for (let i = 0; i < vocabSize; i++) {
      maxLogit = Math.max(maxLogit, logits[offset + i]);
    }

    // Compute exp and sum
    let sumExp = 0;
    for (let i = 0; i < vocabSize; i++) {
      probs[i] = Math.exp(logits[offset + i] - maxLogit);
      sumExp += probs[i];
    }

    // Normalize
    for (let i = 0; i < vocabSize; i++) {
      probs[i] /= sumExp;
    }

    return probs;
  }
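  // Evaluation-mode example (illustrative; the mask convention shown, where
  // 1.0 means "row i may attend to column j" and each row covers that token's
  // tree ancestors plus itself, is an assumption about the exported model,
  // not something this module enforces):
  //
  //   // Evaluate the 2-token chain [104, 105] after prefix [257, 72]:
  //   const out = await inferencer.runEvaluationInference({
  //     prefixIds: [257, 72],
  //     evaluatedIds: [104, 105],
  //     evaluatedMask: [1, 0,   // token 0 attends to itself
  //                     1, 1]   // token 1 attends to token 0 and itself
  //   });
  //   const pAfterPrefix = inferencer.softmax(out.logits, 0); // next-token dist after the prefix
  //   const pAfterFirst = inferencer.softmax(out.logits, 1);  // dist after the first evaluated token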
  /**
   * Check if the inferencer is ready
   */
  isReady(): boolean {
    return this.session !== null;
  }

  /**
   * Destroy the session and free resources
   */
  destroy(): void {
    this.session = null;
    console.log("[ModelInferencer] Session destroyed");
  }

  // Private helper methods

  private printModelInfo(): void {
    if (!this.session) return;

    console.log("[ModelInferencer] Model Information:");
    console.log("  Inputs:");
    this.session.inputNames.forEach((name, i) => {
      console.log(`    [${i}] ${name}`);
    });
    console.log("  Outputs:");
    this.session.outputNames.forEach((name, i) => {
      console.log(`    [${i}] ${name}`);
    });
  }

  private createRandomInput(batchSize: number, seqLen: number): BigInt64Array {
    const size = batchSize * seqLen;
    const data = new BigInt64Array(size);
    for (let i = 0; i < size; i++) {
      data[i] = BigInt(Math.floor(Math.random() * this.config.vocabSize));
    }
    return data;
  }

  private padSequence(tokens: number[], targetLen: number): number[] {
    const padded = [...tokens];
    while (padded.length < targetLen) {
      padded.push(this.PAD_TOKEN);
    }
    return padded.slice(0, targetLen); // Truncate if too long
  }

  private validateOutput(logits: OnnxTensor, batchSize: number, seqLen: number): void {
    const expectedShape = [batchSize, seqLen, this.config.vocabSize];

    if (logits.dims.length !== 3) {
      throw new Error(`Expected 3D output, got ${logits.dims.length}D`);
    }

    if (
      logits.dims[0] !== expectedShape[0] ||
      logits.dims[1] !== expectedShape[1] ||
      logits.dims[2] !== expectedShape[2]
    ) {
      throw new Error(
        `Shape mismatch! Expected [${expectedShape.join(", ")}], ` +
          `got [${logits.dims.join(", ")}]`
      );
    }

    if (logits.type !== "float32") {
      throw new Error(`Expected float32 output, got ${logits.type}`);
    }
  }

  private getPredictions(logitsData: Float32Array, numPositions: number): number[] {
    const predictions: number[] = [];
    for (let i = 0; i < numPositions; i++) {
      let maxIdx = 0;
      let maxVal = logitsData[i * this.config.vocabSize];
      for (let j = 1; j < this.config.vocabSize; j++) {
        const val = logitsData[i * this.config.vocabSize + j];
        if (val > maxVal) {
          maxVal = val;
          maxIdx = j;
        }
      }
      predictions.push(maxIdx);
    }
    return predictions;
  }
}
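
// ---------------------------------------------------------------------------
// Illustrative wiring sketch (not part of the production module). It shows the
// dependency-injection contract using an in-memory mock session that returns
// all-zero logits; real platform code would instead inject onnxruntime-web's
// or onnxruntime-node's Tensor class and InferenceSession. The names below
// (MockTensor, mockSession, demoMockInference) are hypothetical and exist
// only for this example.

class MockTensor implements OnnxTensor {
  constructor(
    public readonly type: string,
    public readonly data: BigInt64Array | Float32Array | Int32Array | Uint8Array,
    public readonly dims: number[]
  ) {}
}

const mockSession: OnnxSession = {
  inputNames: ["input_ids"],
  outputNames: ["logits"],
  async run(feeds: Record<string, OnnxTensor>): Promise<Record<string, OnnxTensor>> {
    const [batch, seqLen] = feeds.input_ids.dims;
    const vocabSize = 259;
    // Return all-zero logits shaped [batch, seqLen, vocabSize]
    return {
      logits: new MockTensor(
        "float32",
        new Float32Array(batch * seqLen * vocabSize),
        [batch, seqLen, vocabSize]
      )
    };
  }
};

export async function demoMockInference(): Promise<void> {
  const inferencer = new ModelInferencer(MockTensor, { vocabSize: 259, seqLen: 8 });
  inferencer.setSession(mockSession);
  const logits = await inferencer.runInference([72, 105]); // byte values of "Hi"
  console.log(`demo: received ${logits.length} logits (8 positions × 259 vocab)`);
}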