---
library_name: diffusers
tags:
- fp8
- safetensors
- precision-recovery
- diffusion
- converted-by-gradio
---

# FP8 Model with Precision Recovery

- **Source**: `https://huggingface.co/LifuWang/DistillT5`
- **File**: `model.safetensors`
- **FP8 Format**: `E5M2`
- **Architecture**: all
- **Precision Recovery Type**: LoRA
- **Precision Recovery File**: `model-lora-r64-all.safetensors` (if available)
- **FP8 File**: `model-fp8-e5m2.safetensors`
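As a sketch of what the rank-64 LoRA recovery factors look like (the tensor names and layout here are assumptions inferred from the file naming above, and the dimensions are illustrative):

```python
import torch

# Illustrative shapes for a rank-64 LoRA error-correction pair.
# For a weight of shape (out, in), the correction is the low-rank product B @ A.
out_features, in_features, rank = 1024, 2048, 64
A = torch.randn(rank, in_features) * 0.01   # "lora_A" factor: (r, in)
B = torch.randn(out_features, rank) * 0.01  # "lora_B" factor: (out, r)
correction = B @ A                          # (out, in) error-correction term
print(correction.shape)  # torch.Size([1024, 2048])
```

Storing two small factors instead of a full-size correction tensor is what keeps the recovery file compact relative to the original weights.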

## Usage (Inference)

```python
import os

import torch
from safetensors.torch import load_file

# Load the FP8 model weights
fp8_state = load_file("model-fp8-e5m2.safetensors")

# Load the precision recovery file if it exists on disk
recovery_state = {}
if os.path.exists("model-lora-r64-all.safetensors"):
    recovery_state = load_file("model-lora-r64-all.safetensors")

# Reconstruct high-precision weights
reconstructed = {}
for key in fp8_state:
    # Dequantize FP8 to the target precision
    fp_weight = fp8_state[key].to(torch.float32)

    if recovery_state:
        # LoRA approach: add the low-rank error correction B @ A
        if f"lora_A.{key}" in recovery_state and f"lora_B.{key}" in recovery_state:
            A = recovery_state[f"lora_A.{key}"].to(torch.float32)
            B = recovery_state[f"lora_B.{key}"].to(torch.float32)
            reconstructed[key] = fp_weight + B @ A
        # Correction-factor approach: add a stored per-tensor correction
        elif f"correction.{key}" in recovery_state:
            correction = recovery_state[f"correction.{key}"].to(torch.float32)
            reconstructed[key] = fp_weight + correction
        else:
            reconstructed[key] = fp_weight
    else:
        reconstructed[key] = fp_weight

print("Model reconstructed with FP8 error recovery")
```

> **Note**: This precision recovery targets FP8 quantization errors.
> Average quantization error: 0.052733