import json

import pandas as pd

from graphgen.bases import BaseLLMWrapper, BaseOperator
from graphgen.common import init_llm
from graphgen.models.extractor import SchemaGuidedExtractor
from graphgen.utils import logger, run_concurrent


class ExtractService(BaseOperator):
    """Operator that runs information extraction over batches of items."""

    def __init__(self, working_dir: str = "cache", **extract_kwargs):
        super().__init__(working_dir=working_dir, op_name="extract_service")
        self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
        self.extract_kwargs = extract_kwargs
        self.method = self.extract_kwargs.get("method")
        if self.method == "schema_guided":
            # Load the extraction schema from the configured JSON file.
            schema_file = self.extract_kwargs.get("schema_path")
            with open(schema_file, "r", encoding="utf-8") as f:
                schema = json.load(f)
            self.extractor = SchemaGuidedExtractor(self.llm_client, schema)
        else:
            raise ValueError(f"Unsupported extraction method: {self.method}")

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Convert a DataFrame batch to records, run extraction, and return a DataFrame."""
        items = batch.to_dict(orient="records")
        return pd.DataFrame(self.extract(items))

    def extract(self, items: list[dict]) -> list[dict]:
        logger.info("Start extracting information from %d items", len(items))

        # Run the extractor concurrently over all items, then merge the
        # per-item results into a single mapping of extraction id -> data.
        results = run_concurrent(
            self.extractor.extract,
            items,
            desc="Extracting information",
            unit="item",
        )
        results = self.extractor.merge_extractions(results)

        # Flatten the merged mapping into a list of records.
        results = [
            {"_extract_id": key, "extracted_data": value}
            for key, value in results.items()
        ]
        return results
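

# --- Illustrative usage (a sketch, not part of the original module) ---
# Assumptions: the schema path below is hypothetical, and the input batch is
# assumed to carry whatever text fields SchemaGuidedExtractor expects.
if __name__ == "__main__":
    service = ExtractService(
        working_dir="cache",
        method="schema_guided",
        schema_path="schemas/example_schema.json",  # hypothetical schema path
    )
    demo_batch = pd.DataFrame(
        [{"content": "Example document text to extract information from."}]
    )
    extracted_df = service.process(demo_batch)
    print(extracted_df.head())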