Spaces:
Running
Running
| // | |
| // SPDX-FileCopyrightText: Hadad <[email protected]> | |
| // SPDX-License-Identifier: Apache-2.0 | |
| // | |
| import axios from 'axios'; | |
| import config from '../../config.js'; | |
| import { generateId } from '../utils/idGenerator.js'; | |
| import { | |
| getStorage, | |
| setStorage, | |
| getActiveGeneration, | |
| setActiveGeneration, | |
| deleteActiveGeneration | |
| } from './storageManager.js'; | |
| import { | |
| sendToSession | |
| } from './websocketManager.js'; | |
/**
 * Store a new progress value for a session and push it to the client.
 * Sessions without stored state are ignored (nothing is stored or sent).
 *
 * @param sessionId - Key used for both the storage lookup and the socket send.
 * @param progress - Progress value to persist and broadcast.
 */
const updateProgress = (sessionId, progress) => {
  const sessionState = getStorage(sessionId);
  if (!sessionState) {
    return;
  }

  sessionState.progress = progress;
  setStorage(sessionId, sessionState);

  const message = { type: 'progressUpdate', progress };
  sendToSession(sessionId, message);
};
/**
 * Start a timer that advances a session's progress by a random step in
 * [0, 8) on every tick while the session is generating. The value is capped
 * at `config.generation.maxProgress`.
 *
 * @param sessionId - Session whose progress should be advanced.
 * @returns The interval handle; the caller is responsible for clearing it.
 */
const createProgressUpdater = (sessionId) => {
  const tick = () => {
    const state = getStorage(sessionId);
    if (!state || !state.isGenerating) {
      return;
    }

    const ceiling = config.generation.maxProgress;
    // Written as a negated `<` so undefined/NaN progress is skipped exactly
    // as the plain `progress < ceiling` check would skip it.
    if (!(state.progress < ceiling)) {
      return;
    }

    const step = Math.random() * 8;
    updateProgress(sessionId, Math.min(ceiling, state.progress + step));
  };

  return setInterval(tick, config.generation.progressInterval);
};
/**
 * Build the record stored for a single generated image.
 *
 * @param base64Data - Image payload, stored under the `base64` key.
 * @param prompt - Prompt the image was generated from.
 * @param model - Model that produced the image.
 * @param size - Requested image size.
 * @returns A new object `{ id, base64, prompt, model, size }` with a fresh id.
 */
const createImageObject = (base64Data, prompt, model, size) => {
  const id = generateId();
  return { id, base64: base64Data, prompt, model, size };
};
/**
 * POST a single-image generation request to the configured API endpoint.
 *
 * The request asks for one image (`n: 1`) returned as base64 JSON
 * (`response_format: 'b64_json'`). Authentication uses a bearer token from
 * `config.api.key`.
 *
 * @param prompt - Text prompt sent to the API.
 * @param model - Model identifier sent to the API.
 * @param size - Image size sent to the API.
 * @param signal - AbortSignal that cancels the HTTP request.
 * @returns The axios response; axios errors (including cancellation) propagate.
 */
const callImageApi = async (prompt, model, size, signal) => {
  const payload = {
    model,
    prompt,
    size,
    response_format: 'b64_json',
    n: 1
  };

  const requestOptions = {
    headers: {
      'Authorization': `Bearer ${config.api.key}`,
      'Content-Type': 'application/json'
    },
    signal,
    timeout: config.api.timeout,
    // One configured limit caps both the outgoing body and the response body.
    maxBodyLength: config.limits.maxContentLength,
    maxContentLength: config.limits.maxContentLength
  };

  const response = await axios.post(config.api.baseUrl, payload, requestOptions);
  return response;
};
/**
 * Start an image generation for a session and return immediately.
 *
 * Registers a fresh AbortController as the session's active generation and
 * starts the simulated progress ticker. After `config.generation.startDelay`
 * ms it calls the image API. Outcomes are reported over the session socket,
 * not through the returned promise:
 *   - success: the image (when the response contains one) is prepended to
 *     the session's `images`, and a `generationComplete` message is sent;
 *   - failure other than cancellation: `error` is set and a
 *     `generationError` message is sent;
 *   - `isGenerating`/`progress` are then reset, provided the session state
 *     still exists.
 *
 * A run only deregisters itself while it is still the session's active
 * generation. The cancellation path only resets session state under the
 * same condition. This stops a cancelled run whose delayed callback fires
 * late from wiping out a newer run's controller or its `isGenerating` flag.
 *
 * @param sessionId - Session that owns this generation.
 * @param prompt - Prompt forwarded to the image API.
 * @param model - Model forwarded to the image API; also named in the error message.
 * @param size - Size forwarded to the image API.
 * @returns {Promise<void>} Resolves as soon as the work has been scheduled.
 */
export const generateImage = async (
  sessionId,
  prompt,
  model,
  size
) => {
  const controller = new AbortController();
  setActiveGeneration(sessionId, controller);
  const progressInterval = createProgressUpdater(sessionId);

  // Stop ticking as soon as the run is aborted. Otherwise the ticker keeps
  // running until the delayed request settles.
  controller.signal.addEventListener(
    'abort',
    () => clearInterval(progressInterval),
    { once: true }
  );

  // False once the session's active generation is no longer this controller,
  // e.g. cancelGeneration() deleted it or a later generateImage() call
  // registered its own.
  const isCurrentRun = () => getActiveGeneration(sessionId) === controller;

  setTimeout(async () => {
    try {
      const response = await callImageApi(
        prompt,
        model,
        size,
        controller.signal
      );
      const data = getStorage(sessionId);
      if (!data) return;
      updateProgress(
        sessionId,
        config.generation.maxProgress
      );
      if (response.data?.data?.length > 0) {
        const base64 = response.data.data[0].b64_json;
        const newImage = createImageObject(
          base64,
          prompt,
          model,
          size
        );
        data.images.unshift(newImage);
      }
      data.isGenerating = false;
      data.progress = 0;
      setStorage(sessionId, data);
      sendToSession(sessionId, {
        type: 'generationComplete',
        images: data.images
      });
    } catch (error) {
      const wasCanceled =
        error.name === 'CanceledError' ||
        error.code === 'ERR_CANCELED';
      // A cancelled run that has already been deregistered was reset by
      // whoever cancelled it. A newer run may own the session state now,
      // so leave it untouched.
      if (wasCanceled && !isCurrentRun()) return;
      const data = getStorage(sessionId);
      if (!data) return;
      if (!wasCanceled) {
        data.error =
          `The request to the ${model} model was ` +
          `unsuccessful, possibly due to high ` +
          `server load. Please try again later.`;
        sendToSession(sessionId, {
          type: 'generationError',
          error: data.error
        });
      }
      data.isGenerating = false;
      data.progress = 0;
      setStorage(sessionId, data);
    } finally {
      clearInterval(progressInterval);
      // Only remove our own registration; never a newer run's controller.
      if (isCurrentRun()) {
        deleteActiveGeneration(sessionId);
      }
    }
  }, config.generation.startDelay);
};
/**
 * Abort the session's in-flight generation, if any, and clear its
 * generation flags. Does nothing when the session has no active generation.
 *
 * @param sessionId - Session whose generation should be cancelled.
 */
export const cancelGeneration = (sessionId) => {
  const activeController = getActiveGeneration(sessionId);
  if (!activeController) {
    return;
  }

  activeController.abort();
  deleteActiveGeneration(sessionId);

  const sessionState = getStorage(sessionId);
  if (!sessionState) {
    return;
  }
  Object.assign(sessionState, { isGenerating: false, progress: 0 });
  setStorage(sessionId, sessionState);
};