/**
* ONNX Model Inferencer (Frontend/Backend Common)
*
* Platform-agnostic inference logic that accepts ONNX session from platform-specific code.
* No direct dependency on onnxruntime packages - uses dependency injection pattern.
*
 * Adapted from the Node.js test_inference.js for cross-platform use.
 * Provides causal language-model inference using a GPT-2 ONNX model.
*/
/**
* Minimal ONNX Tensor interface (platform-agnostic)
*/
export interface OnnxTensor {
readonly data: number[] | Float32Array | Int32Array | BigInt64Array | Uint8Array;
readonly dims: readonly number[];
readonly type: string;
}
/**
* Minimal ONNX Session interface (platform-agnostic)
*/
export interface OnnxSession {
readonly inputNames: readonly string[];
readonly outputNames: readonly string[];
run(feeds: Record<string, OnnxTensor>): Promise<Record<string, OnnxTensor>>;
}
/**
* Tensor constructor interface (platform-specific)
*/
export interface TensorConstructor {
new (
type: string,
data: BigInt64Array | Float32Array | Int32Array | Uint8Array,
dims: number[]
): OnnxTensor;
}
/**
* Configuration for the inferencer
*/
export interface InferencerConfig {
vocabSize: number;
seqLen: number;
modelPath?: string; // Optional, for reference
}
/**
* Inference result containing generated tokens and metadata
*/
export interface InferenceResult {
tokens: number[];
text: string;
logits: Float32Array;
inferenceTime: number;
}
/**
* Evaluation mode inputs for tree attention
*/
export interface EvaluationInputs {
prefixIds: number[]; // Prefix sequence (causal context)
evaluatedIds: number[]; // Tokens to evaluate in tree
evaluatedMask: number[]; // Attention mask [m×m] flattened
}
/**
* Evaluation mode output
*/
export interface EvaluationOutput {
logits: Float32Array; // [m+1, vocab_size] flattened
numEvaluated: number; // m
}
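/**
 * Illustrative sketch (not part of the original module): building the flattened
 * [m×m] mask expected by EvaluationInputs.evaluatedMask. This assumes the
 * convention 1 = may attend, 0 = blocked; a real tree mask would encode each
 * node's ancestor chain instead of this plain causal lower triangle.
 */
export function buildCausalEvaluatedMask(m: number): number[] {
	const mask = new Array<number>(m * m).fill(0);
	for (let i = 0; i < m; i++) {
		for (let j = 0; j <= i; j++) {
			mask[i * m + j] = 1; // row i attends to evaluated positions 0..i
		}
	}
	return mask;
}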
/**
* Model Inferencer for Causal Language Model
* Compatible with both frontend (onnxruntime-web) and backend (onnxruntime-node)
*/
export class ModelInferencer {
private session: OnnxSession | null = null;
private config: InferencerConfig;
private TensorClass: TensorConstructor;
// TGN tokenizer: byte-level (0-255) + PAD(256) + START(257) + END(258)
private readonly PAD_TOKEN = 256;
private readonly START_TOKEN = 257;
private readonly END_TOKEN = 258;
constructor(TensorClass: TensorConstructor, config: Partial<InferencerConfig> = {}) {
this.TensorClass = TensorClass;
this.config = {
vocabSize: 259,
seqLen: 256,
...config
};
}
/**
* Set the inference session (created by platform-specific code)
*/
setSession(session: OnnxSession): void {
this.session = session;
console.log("[ModelInferencer] ✓ Session set successfully");
this.printModelInfo();
}
/**
* Run basic inference test
*/
async testBasicInference(): Promise<InferenceResult> {
if (!this.session) {
throw new Error("Inferencer not initialized. Call setSession() first.");
}
console.log("[ModelInferencer] Running basic inference test...");
const batchSize = 1;
const seqLen = this.config.seqLen;
// Create random input
const inputIds = this.createRandomInput(batchSize, seqLen);
const inputTensor = new this.TensorClass("int64", inputIds, [batchSize, seqLen]);
// Run inference
const startTime = performance.now();
const results = await this.session.run({ input_ids: inputTensor });
const inferenceTime = performance.now() - startTime;
// Get logits
const logits = results.logits;
// Validate output
this.validateOutput(logits, batchSize, seqLen);
// Get predictions
const predictions = this.getPredictions(logits.data as Float32Array, batchSize * seqLen);
		// Convert the first 100 predictions to text (special tokens >= 256 are not bytes)
		const text = String.fromCharCode(...predictions.slice(0, 100).filter((t) => t < 256));
console.log("[ModelInferencer] Inference completed:");
console.log(` Input shape: [${inputTensor.dims.join(", ")}]`);
console.log(` Output shape: [${logits.dims.join(", ")}]`);
console.log(` Output dtype: ${logits.type}`);
console.log(` Inference time: ${inferenceTime.toFixed(2)}ms`);
console.log(` Sample predictions: [${predictions.slice(0, 10).join(", ")}]`);
		// Scan for min/max instead of spreading ~66k values into Math.min/Math.max,
		// which can exceed the engine's argument-count limit
		const logitsData = logits.data as Float32Array;
		let minLogit = Infinity;
		let maxLogit = -Infinity;
		for (const v of logitsData) {
			if (v < minLogit) minLogit = v;
			if (v > maxLogit) maxLogit = v;
		}
		console.log(`  Logits range: [${minLogit.toFixed(3)}, ${maxLogit.toFixed(3)}]`);
return {
tokens: predictions,
text,
logits: logits.data as Float32Array,
inferenceTime
};
}
/**
* Generate tokens autoregressively from a prompt
*/
async generateText(prompt: string, numTokens: number = 10): Promise<InferenceResult> {
if (!this.session) {
throw new Error("Inferencer not initialized. Call setSession() first.");
}
console.log(`[ModelInferencer] Generating ${numTokens} tokens from prompt: "${prompt}"`);
		// Convert prompt to token IDs (UTF-8 bytes, matching the byte-level tokenizer;
		// charCodeAt would emit out-of-range values for non-ASCII input)
		const promptTokens = Array.from(new TextEncoder().encode(prompt));
console.log(` Prompt tokens (${promptTokens.length}): [${promptTokens.join(", ")}]`);
// Start with prompt tokens
const sequence = [...promptTokens];
const times: number[] = [];
// Generate tokens
for (let i = 0; i < numTokens; i++) {
// Pad sequence to fixed length
const paddedSequence = this.padSequence(sequence, this.config.seqLen);
// Create input tensor
const inputIds = new BigInt64Array(paddedSequence.map((t) => BigInt(t)));
const inputTensor = new this.TensorClass("int64", inputIds, [1, this.config.seqLen]);
// Run inference
const startTime = performance.now();
const results = await this.session.run({ input_ids: inputTensor });
times.push(performance.now() - startTime);
// Get prediction at the last non-padded position
const logits = results.logits.data as Float32Array;
			const lastPos = Math.min(sequence.length, this.config.seqLen) - 1; // Last real position (clamped: padSequence truncates to seqLen)
const offset = lastPos * this.config.vocabSize;
// Find token with highest logit
let maxIdx = 0;
let maxVal = logits[offset];
for (let j = 1; j < this.config.vocabSize; j++) {
if (logits[offset + j] > maxVal) {
maxVal = logits[offset + j];
maxIdx = j;
}
}
sequence.push(maxIdx);
// Stop if END token is generated
if (maxIdx === this.END_TOKEN) {
console.log(" Generated END token, stopping...");
break;
}
}
		// Convert generated tokens to text (UTF-8 decode; special tokens >= 256 are not bytes)
		const generatedText = new TextDecoder().decode(new Uint8Array(sequence.filter((t) => t < 256)));
const avgTime = times.reduce((a, b) => a + b, 0) / times.length;
console.log(`[ModelInferencer] Generation complete:`);
console.log(` Generated text: "${generatedText}"`);
console.log(` Token sequence (${sequence.length}): [${sequence.join(", ")}]`);
console.log(` Avg inference time: ${avgTime.toFixed(2)}ms`);
console.log(` Tokens/sec: ${(1000 / avgTime).toFixed(2)}`);
return {
tokens: sequence,
text: generatedText,
logits: new Float32Array(), // Not returning full logits for generation
inferenceTime: avgTime
};
}
/**
* Get model information
*/
getModelInfo(): { inputs: string[]; outputs: string[] } | null {
if (!this.session) return null;
return {
inputs: [...this.session.inputNames],
outputs: [...this.session.outputNames]
};
}
/**
* Get configuration
*/
getConfig(): InferencerConfig {
		return { ...this.config }; // Copy, so callers cannot mutate internal state
}
/**
* Run inference with token array input
* Returns raw logits as Float32Array
*/
async runInference(tokens: number[]): Promise<Float32Array> {
if (!this.session) {
throw new Error("Inferencer not initialized. Call setSession() first.");
}
const seqLen = this.config.seqLen;
// Prepend START_TOKEN to input
const tokensWithStart = [this.START_TOKEN, ...tokens];
// Pad to fixed length
const paddedTokens = new BigInt64Array(seqLen);
for (let i = 0; i < seqLen; i++) {
paddedTokens[i] =
i < tokensWithStart.length ? BigInt(tokensWithStart[i]) : BigInt(this.PAD_TOKEN);
}
// Create input tensor
const inputTensor = new this.TensorClass("int64", paddedTokens, [1, seqLen]);
// Run inference
const results = await this.session.run({ input_ids: inputTensor });
return results.logits.data as Float32Array;
}
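	// Usage sketch for runInference (illustrative, not part of the original API):
	// with START_TOKEN prepended, the last real token sits at index tokens.length,
	// so its next-token distribution is that row of the flattened logits.
	//
	//   const logits = await inferencer.runInference(tokens);
	//   const { vocabSize } = inferencer.getConfig();
	//   const offset = tokens.length * vocabSize;
	//   const nextTokenLogits = logits.subarray(offset, offset + vocabSize);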
/**
* Run tree attention inference (evaluation mode)
* For models exported with --evaluation flag
* @param inputs - Prefix, evaluated tokens, and attention mask
* @returns Logits for each evaluated position
*/
async runEvaluationInference(inputs: EvaluationInputs): Promise<EvaluationOutput> {
if (!this.session) {
throw new Error("Inferencer not initialized. Call setSession() first.");
}
const { prefixIds, evaluatedIds, evaluatedMask } = inputs;
const batchSize = 1;
const prefixLen = prefixIds.length;
const m = evaluatedIds.length;
		// Convert to BigInt64Array for ONNX int64 tensors
		const prefixIdsArray = BigInt64Array.from(prefixIds, (v) => BigInt(v));
		const evaluatedIdsArray = BigInt64Array.from(evaluatedIds, (v) => BigInt(v));
		// Mask is Float32Array
		const maskArray = Float32Array.from(evaluatedMask);
// Create ONNX tensors
const prefixIdsTensor = new this.TensorClass("int64", prefixIdsArray, [
batchSize,
prefixLen
]);
const evaluatedIdsTensor = new this.TensorClass("int64", evaluatedIdsArray, [batchSize, m]);
		const evaluatedMaskTensor = new this.TensorClass("float32", maskArray, [batchSize, m, m]);
// Run inference
const results = await this.session.run({
prefix_ids: prefixIdsTensor,
evaluated_ids: evaluatedIdsTensor,
evaluated_mask: evaluatedMaskTensor
});
// Extract logits
const logits = results.logits.data as Float32Array;
// Output shape: [batch, m+1, vocab_size]
// We return flattened array and num_evaluated for reshaping
return {
logits,
numEvaluated: m
};
}
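	// Usage sketch tying the evaluation inputs together (illustrative;
	// buildCausalEvaluatedMask is the sketch helper defined above, not part of
	// the original API):
	//
	//   const evaluatedMask = buildCausalEvaluatedMask(evaluatedIds.length);
	//   const out = await inferencer.runEvaluationInference({ prefixIds, evaluatedIds, evaluatedMask });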
/**
* Compute softmax for a single position's logits
* @param logits - Full logits array
* @param position - Which evaluated position (0 = last prefix, 1-m = evaluated tokens)
* @returns Probability distribution over vocabulary
*/
softmax(logits: Float32Array, position: number): Float32Array {
const vocabSize = this.config.vocabSize;
const offset = position * vocabSize;
const probs = new Float32Array(vocabSize);
// Find max for numerical stability
let maxLogit = -Infinity;
for (let i = 0; i < vocabSize; i++) {
maxLogit = Math.max(maxLogit, logits[offset + i]);
}
// Compute exp and sum
let sumExp = 0;
for (let i = 0; i < vocabSize; i++) {
probs[i] = Math.exp(logits[offset + i] - maxLogit);
sumExp += probs[i];
}
// Normalize
for (let i = 0; i < vocabSize; i++) {
probs[i] /= sumExp;
}
return probs;
}
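	// Usage sketch for softmax over evaluation-mode logits (illustrative):
	// position 0 holds the prediction following the last prefix token,
	// positions 1..m follow each evaluated token.
	//
	//   const { logits } = await inferencer.runEvaluationInference(inputs);
	//   const probsAfterPrefix = inferencer.softmax(logits, 0);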
/**
* Check if inferencer is ready
*/
isReady(): boolean {
return this.session !== null;
}
/**
* Destroy the session and free resources
*/
destroy(): void {
this.session = null;
console.log("[ModelInferencer] Session destroyed");
}
// Private helper methods
private printModelInfo(): void {
if (!this.session) return;
console.log("[ModelInferencer] Model Information:");
console.log(" Inputs:");
this.session.inputNames.forEach((name, i) => {
console.log(` [${i}] ${name}`);
});
console.log(" Outputs:");
this.session.outputNames.forEach((name, i) => {
console.log(` [${i}] ${name}`);
});
}
private createRandomInput(batchSize: number, seqLen: number): BigInt64Array {
const size = batchSize * seqLen;
const data = new BigInt64Array(size);
for (let i = 0; i < size; i++) {
data[i] = BigInt(Math.floor(Math.random() * this.config.vocabSize));
}
return data;
}
private padSequence(tokens: number[], targetLen: number): number[] {
const padded = [...tokens];
while (padded.length < targetLen) {
padded.push(this.PAD_TOKEN);
}
return padded.slice(0, targetLen); // Truncate if too long
}
private validateOutput(logits: OnnxTensor, batchSize: number, seqLen: number): void {
const expectedShape = [batchSize, seqLen, this.config.vocabSize];
if (logits.dims.length !== 3) {
throw new Error(`Expected 3D output, got ${logits.dims.length}D`);
}
if (
logits.dims[0] !== expectedShape[0] ||
logits.dims[1] !== expectedShape[1] ||
logits.dims[2] !== expectedShape[2]
) {
throw new Error(
`Shape mismatch! Expected [${expectedShape.join(", ")}], ` +
`got [${logits.dims.join(", ")}]`
);
}
if (logits.type !== "float32") {
throw new Error(`Expected float32 output, got ${logits.type}`);
}
}
private getPredictions(logitsData: Float32Array, numPositions: number): number[] {
const predictions: number[] = [];
for (let i = 0; i < numPositions; i++) {
let maxIdx = 0;
let maxVal = logitsData[i * this.config.vocabSize];
for (let j = 1; j < this.config.vocabSize; j++) {
const val = logitsData[i * this.config.vocabSize + j];
if (val > maxVal) {
maxVal = val;
maxIdx = j;
}
}
predictions.push(maxIdx);
}
return predictions;
}
}
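
/*
 * End-to-end usage sketch. Assumes onnxruntime-web; the model path, prompt,
 * and the structural casts are illustrative, not part of this module:
 *
 *   import * as ort from "onnxruntime-web";
 *
 *   const inferencer = new ModelInferencer(
 *     ort.Tensor as unknown as TensorConstructor,
 *     { vocabSize: 259, seqLen: 256 }
 *   );
 *   const session = await ort.InferenceSession.create("/models/model.onnx");
 *   inferencer.setSession(session as unknown as OnnxSession);
 *
 *   const result = await inferencer.generateText("theorem ", 32);
 *   console.log(result.text, `${result.inferenceTime.toFixed(1)}ms/token`);
 */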