/**
 * ONNX Model Inferencer (Frontend/Backend Common)
 *
 * Platform-agnostic inference logic that accepts an ONNX session from platform-specific code.
 * No direct dependency on onnxruntime packages - uses a dependency-injection pattern.
 *
 * Adapted from the Node.js test_inference.js for cross-platform use.
 * Provides causal language model inference using a GPT-2 ONNX model.
 */

/**
 * Minimal ONNX tensor interface (platform-agnostic)
 */
export interface OnnxTensor {
  readonly data: number[] | Float32Array | Int32Array | BigInt64Array | Uint8Array;
  readonly dims: readonly number[];
  readonly type: string;
}

/**
 * Minimal ONNX session interface (platform-agnostic)
 */
export interface OnnxSession {
  readonly inputNames: readonly string[];
  readonly outputNames: readonly string[];
  run(feeds: Record<string, OnnxTensor>): Promise<Record<string, OnnxTensor>>;
}

/**
 * Tensor constructor interface (platform-specific)
 */
export interface TensorConstructor {
  new (
    type: string,
    data: BigInt64Array | Float32Array | Int32Array | Uint8Array,
    dims: number[]
  ): OnnxTensor;
}
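
/*
 * Example wiring (a sketch, not part of this module): both onnxruntime-web and
 * onnxruntime-node export a `Tensor` class and an `InferenceSession` whose shapes
 * match these interfaces, so platform code can inject them like this. The
 * "model.onnx" path is a placeholder, and the casts paper over minor type
 * differences between the ort typings and these minimal interfaces.
 *
 *   import { Tensor, InferenceSession } from "onnxruntime-web";
 *
 *   const inferencer = new ModelInferencer(Tensor as unknown as TensorConstructor);
 *   const session = await InferenceSession.create("model.onnx");
 *   inferencer.setSession(session as unknown as OnnxSession);
 */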

/**
 * Configuration for the inferencer
 */
export interface InferencerConfig {
  vocabSize: number;
  seqLen: number;
  modelPath?: string; // Optional, kept for reference only
}

/**
 * Inference result containing generated tokens and metadata
 */
export interface InferenceResult {
  tokens: number[];
  text: string;
  logits: Float32Array;
  inferenceTime: number;
}

/**
 * Evaluation-mode inputs for tree attention
 */
export interface EvaluationInputs {
  prefixIds: number[]; // Prefix sequence (causal context)
  evaluatedIds: number[]; // Tokens to evaluate in the tree
  evaluatedMask: number[]; // Attention mask [m x m], flattened row-major
}

/**
 * Evaluation-mode output
 */
export interface EvaluationOutput {
  logits: Float32Array; // [m + 1, vocab_size], flattened
  numEvaluated: number; // m
}
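
/*
 * Mask layout sketch (an assumption about the export's convention: 1.0 means
 * "token i may attend to token j", flattened row-major). For a linear chain of
 * m = 3 evaluated tokens, a causal mask would look like:
 *
 *   const evaluatedMask = [
 *     1, 0, 0,
 *     1, 1, 0,
 *     1, 1, 1
 *   ];
 *
 * For a tree, row i instead marks only token i's ancestors (plus itself).
 */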

/**
 * Model Inferencer for a causal language model.
 * Compatible with both the frontend (onnxruntime-web) and backend (onnxruntime-node).
 */
export class ModelInferencer {
  private session: OnnxSession | null = null;
  private config: InferencerConfig;
  private TensorClass: TensorConstructor;

  // TGN tokenizer: byte-level (0-255) + PAD (256) + START (257) + END (258)
  private readonly PAD_TOKEN = 256;
  private readonly START_TOKEN = 257;
  private readonly END_TOKEN = 258;

  constructor(TensorClass: TensorConstructor, config: Partial<InferencerConfig> = {}) {
    this.TensorClass = TensorClass;
    this.config = {
      vocabSize: 259,
      seqLen: 256,
      ...config
    };
  }

  /**
   * Set the inference session (created by platform-specific code)
   */
  setSession(session: OnnxSession): void {
    this.session = session;
    console.log("[ModelInferencer] ✓ Session set successfully");
    this.printModelInfo();
  }

  /**
   * Run a basic inference smoke test with random input
   */
  async testBasicInference(): Promise<InferenceResult> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }
    console.log("[ModelInferencer] Running basic inference test...");

    const batchSize = 1;
    const seqLen = this.config.seqLen;

    // Create random input
    const inputIds = this.createRandomInput(batchSize, seqLen);
    const inputTensor = new this.TensorClass("int64", inputIds, [batchSize, seqLen]);

    // Run inference
    const startTime = performance.now();
    const results = await this.session.run({ input_ids: inputTensor });
    const inferenceTime = performance.now() - startTime;

    // Get logits and validate output shape/dtype
    const logits = results.logits;
    this.validateOutput(logits, batchSize, seqLen);

    // Greedy (argmax) predictions per position
    const predictions = this.getPredictions(logits.data as Float32Array, batchSize * seqLen);

    // Convert the first 100 tokens to text (byte-level vocabulary)
    const text = String.fromCharCode(...predictions.slice(0, 100));

    console.log("[ModelInferencer] Inference completed:");
    console.log(`  Input shape: [${inputTensor.dims.join(", ")}]`);
    console.log(`  Output shape: [${logits.dims.join(", ")}]`);
    console.log(`  Output dtype: ${logits.type}`);
    console.log(`  Inference time: ${inferenceTime.toFixed(2)}ms`);
    console.log(`  Sample predictions: [${predictions.slice(0, 10).join(", ")}]`);
    // Compute the logits range with a loop: spreading a seqLen * vocabSize
    // array into Math.min/Math.max can exceed the engine's argument limit.
    let minLogit = Infinity, maxLogit = -Infinity;
    for (const v of logits.data as Float32Array) {
      if (v < minLogit) minLogit = v;
      if (v > maxLogit) maxLogit = v;
    }
    console.log(`  Logits range: [${minLogit.toFixed(3)}, ${maxLogit.toFixed(3)}]`);

    return {
      tokens: predictions,
      text,
      logits: logits.data as Float32Array,
      inferenceTime
    };
  }

  /**
   * Generate tokens autoregressively from a prompt
   */
  async generateText(prompt: string, numTokens: number = 10): Promise<InferenceResult> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }
    console.log(`[ModelInferencer] Generating ${numTokens} tokens from prompt: "${prompt}"`);

    // Convert prompt to token IDs (byte values)
    const promptTokens = Array.from(prompt).map((c) => c.charCodeAt(0));
    console.log(`  Prompt tokens (${promptTokens.length}): [${promptTokens.join(", ")}]`);

    // Start with prompt tokens
    const sequence = [...promptTokens];
    const times: number[] = [];

    // Generate tokens
    for (let i = 0; i < numTokens; i++) {
      // Pad sequence to fixed length
      const paddedSequence = this.padSequence(sequence, this.config.seqLen);

      // Create input tensor
      const inputIds = new BigInt64Array(paddedSequence.map((t) => BigInt(t)));
      const inputTensor = new this.TensorClass("int64", inputIds, [1, this.config.seqLen]);

      // Run inference
      const startTime = performance.now();
      const results = await this.session.run({ input_ids: inputTensor });
      times.push(performance.now() - startTime);

      // Get prediction at the last non-padded position
      const logits = results.logits.data as Float32Array;
      const lastPos = sequence.length - 1; // Position before padding
      const offset = lastPos * this.config.vocabSize;

      // Find token with highest logit
      let maxIdx = 0;
      let maxVal = logits[offset];
      for (let j = 1; j < this.config.vocabSize; j++) {
        if (logits[offset + j] > maxVal) {
          maxVal = logits[offset + j];
          maxIdx = j;
        }
      }
      sequence.push(maxIdx);

      // Stop if END token is generated
      if (maxIdx === this.END_TOKEN) {
        console.log("  Generated END token, stopping...");
        break;
      }
    }

    // Convert generated tokens to text
    const generatedText = String.fromCharCode(...sequence);
    const avgTime = times.reduce((a, b) => a + b, 0) / times.length;

    console.log("[ModelInferencer] Generation complete:");
    console.log(`  Generated text: "${generatedText}"`);
    console.log(`  Token sequence (${sequence.length}): [${sequence.join(", ")}]`);
    console.log(`  Avg inference time: ${avgTime.toFixed(2)}ms`);
    console.log(`  Tokens/sec: ${(1000 / avgTime).toFixed(2)}`);

    return {
      tokens: sequence,
      text: generatedText,
      logits: new Float32Array(), // Full logits are not returned for generation
      inferenceTime: avgTime
    };
  }
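
  /*
   * Usage sketch (prompt and token count are illustrative):
   *
   *   const result = await inferencer.generateText("Hello", 20);
   *   console.log(result.text, result.tokens);
   */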

  /**
   * Get model input/output names
   */
  getModelInfo(): { inputs: string[]; outputs: string[] } | null {
    if (!this.session) return null;
    return {
      inputs: [...this.session.inputNames],
      outputs: [...this.session.outputNames]
    };
  }

  /**
   * Get the current configuration
   */
  getConfig(): InferencerConfig {
    // Return a shallow copy so callers cannot mutate internal state
    return { ...this.config };
  }

  /**
   * Run inference on a token array.
   * Prepends START_TOKEN, pads with PAD_TOKEN to seqLen, and returns the raw
   * logits as a Float32Array.
   */
  async runInference(tokens: number[]): Promise<Float32Array> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }
    const seqLen = this.config.seqLen;

    // Prepend START_TOKEN to the input
    const tokensWithStart = [this.START_TOKEN, ...tokens];

    // Pad to the fixed length
    const paddedTokens = new BigInt64Array(seqLen);
    for (let i = 0; i < seqLen; i++) {
      paddedTokens[i] =
        i < tokensWithStart.length ? BigInt(tokensWithStart[i]) : BigInt(this.PAD_TOKEN);
    }

    // Create the input tensor and run inference
    const inputTensor = new this.TensorClass("int64", paddedTokens, [1, seqLen]);
    const results = await this.session.run({ input_ids: inputTensor });
    return results.logits.data as Float32Array;
  }
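
  /*
   * Readout sketch (assumes the logit row predicting the token after `tokens`
   * sits at index tokens.length, since the prepended START shifts positions by one):
   *
   *   const logits = await inferencer.runInference(tokens);
   *   const probs = inferencer.softmax(logits, tokens.length);
   *   const next = probs.indexOf(Math.max(...probs)); // greedy next token
   */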

  /**
   * Run tree-attention inference (evaluation mode).
   * For models exported with the --evaluation flag.
   * @param inputs - Prefix, evaluated tokens, and attention mask
   * @returns Logits for each evaluated position
   */
  async runEvaluationInference(inputs: EvaluationInputs): Promise<EvaluationOutput> {
    if (!this.session) {
      throw new Error("Inferencer not initialized. Call setSession() first.");
    }
    const { prefixIds, evaluatedIds, evaluatedMask } = inputs;
    const batchSize = 1;
    const prefixLen = prefixIds.length;
    const m = evaluatedIds.length;

    // Convert to BigInt64Array for ONNX int64 tensors
    const prefixIdsArray = new BigInt64Array(batchSize * prefixLen);
    for (let i = 0; i < prefixLen; i++) {
      prefixIdsArray[i] = BigInt(prefixIds[i]);
    }
    const evaluatedIdsArray = new BigInt64Array(batchSize * m);
    for (let i = 0; i < m; i++) {
      evaluatedIdsArray[i] = BigInt(evaluatedIds[i]);
    }

    // The mask is a float32 tensor
    const maskArray = new Float32Array(m * m);
    for (let i = 0; i < m * m; i++) {
      maskArray[i] = evaluatedMask[i];
    }

    // Create ONNX tensors
    const prefixIdsTensor = new this.TensorClass("int64", prefixIdsArray, [batchSize, prefixLen]);
    const evaluatedIdsTensor = new this.TensorClass("int64", evaluatedIdsArray, [batchSize, m]);
    const evaluatedMaskTensor = new this.TensorClass("float32", maskArray, [1, m, m]);

    // Run inference
    const results = await this.session.run({
      prefix_ids: prefixIdsTensor,
      evaluated_ids: evaluatedIdsTensor,
      evaluated_mask: evaluatedMaskTensor
    });

    // Extract logits. Output shape is [batch, m + 1, vocab_size]; we return the
    // flattened array plus numEvaluated so callers can reshape.
    const logits = results.logits.data as Float32Array;
    return {
      logits,
      numEvaluated: m
    };
  }
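
  /*
   * Evaluation sketch (token IDs are illustrative; mask follows the chain
   * convention sketched above). Position 0 of the output is the last-prefix
   * prediction, i.e. it scores the first evaluated token:
   *
   *   const out = await inferencer.runEvaluationInference({
   *     prefixIds: [257, 104, 105],   // START + "hi"
   *     evaluatedIds: [33, 33],       // "!!"
   *     evaluatedMask: [1, 0, 1, 1]   // 2x2 causal chain
   *   });
   *   const p0 = inferencer.softmax(out.logits, 0);
   */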

  /**
   * Compute softmax for a single position's logits.
   * @param logits - Full logits array
   * @param position - Which evaluated position (0 = last prefix, 1..m = evaluated tokens)
   * @returns Probability distribution over the vocabulary
   */
  softmax(logits: Float32Array, position: number): Float32Array {
    const vocabSize = this.config.vocabSize;
    const offset = position * vocabSize;
    const probs = new Float32Array(vocabSize);

    // Find the max logit for numerical stability
    let maxLogit = -Infinity;
    for (let i = 0; i < vocabSize; i++) {
      maxLogit = Math.max(maxLogit, logits[offset + i]);
    }

    // Compute exponentials and their sum
    let sumExp = 0;
    for (let i = 0; i < vocabSize; i++) {
      probs[i] = Math.exp(logits[offset + i] - maxLogit);
      sumExp += probs[i];
    }

    // Normalize
    for (let i = 0; i < vocabSize; i++) {
      probs[i] /= sumExp;
    }
    return probs;
  }

  /**
   * Check whether the inferencer is ready
   */
  isReady(): boolean {
    return this.session !== null;
  }

  /**
   * Release the session reference (actual resource disposal is handled by
   * the platform-specific code that created the session)
   */
  destroy(): void {
    this.session = null;
    console.log("[ModelInferencer] Session destroyed");
  }

  // Private helper methods

  private printModelInfo(): void {
    if (!this.session) return;
    console.log("[ModelInferencer] Model Information:");
    console.log("  Inputs:");
    this.session.inputNames.forEach((name, i) => {
      console.log(`    [${i}] ${name}`);
    });
    console.log("  Outputs:");
    this.session.outputNames.forEach((name, i) => {
      console.log(`    [${i}] ${name}`);
    });
  }

  private createRandomInput(batchSize: number, seqLen: number): BigInt64Array {
    const size = batchSize * seqLen;
    const data = new BigInt64Array(size);
    for (let i = 0; i < size; i++) {
      data[i] = BigInt(Math.floor(Math.random() * this.config.vocabSize));
    }
    return data;
  }

  private padSequence(tokens: number[], targetLen: number): number[] {
    const padded = [...tokens];
    while (padded.length < targetLen) {
      padded.push(this.PAD_TOKEN);
    }
    return padded.slice(0, targetLen); // Truncate if too long
  }

  private validateOutput(logits: OnnxTensor, batchSize: number, seqLen: number): void {
    const expectedShape = [batchSize, seqLen, this.config.vocabSize];
    if (logits.dims.length !== 3) {
      throw new Error(`Expected 3D output, got ${logits.dims.length}D`);
    }
    if (
      logits.dims[0] !== expectedShape[0] ||
      logits.dims[1] !== expectedShape[1] ||
      logits.dims[2] !== expectedShape[2]
    ) {
      throw new Error(
        `Shape mismatch! Expected [${expectedShape.join(", ")}], got [${logits.dims.join(", ")}]`
      );
    }
    if (logits.type !== "float32") {
      throw new Error(`Expected float32 output, got ${logits.type}`);
    }
  }

  private getPredictions(logitsData: Float32Array, numPositions: number): number[] {
    const predictions: number[] = [];
    for (let i = 0; i < numPositions; i++) {
      let maxIdx = 0;
      let maxVal = logitsData[i * this.config.vocabSize];
      for (let j = 1; j < this.config.vocabSize; j++) {
        const val = logitsData[i * this.config.vocabSize + j];
        if (val > maxVal) {
          maxVal = val;
          maxIdx = j;
        }
      }
      predictions.push(maxIdx);
    }
    return predictions;
  }
}
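
/*
 * End-to-end sketch (Node side; "model.onnx" is a placeholder path, and the
 * onnxruntime-node wiring mirrors the onnxruntime-web example above):
 *
 *   import { Tensor, InferenceSession } from "onnxruntime-node";
 *
 *   const inferencer = new ModelInferencer(Tensor as unknown as TensorConstructor, { seqLen: 256 });
 *   inferencer.setSession((await InferenceSession.create("model.onnx")) as unknown as OnnxSession);
 *   await inferencer.testBasicInference();
 *   const { text } = await inferencer.generateText("hello ", 16);
 *   inferencer.destroy();
 */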