File size: 2,396 Bytes
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
61172af
 
 
31086ae
 
 
 
 
61172af
 
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61172af
31086ae
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import List

import pandas as pd

from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
from graphgen.bases.datatypes import Chunk
from graphgen.common import init_llm, init_storage
from graphgen.utils import logger

from .build_mm_kg import build_mm_kg
from .build_text_kg import build_text_kg


class BuildKGService(BaseOperator):
    """Operator that extracts a knowledge graph from chunks into graph storage.

    Holds an LLM client (obtained via ``init_llm("synthesizer")``) and a graph
    storage handle (obtained via ``init_storage`` with namespace ``"graph"``),
    and feeds incoming chunks to the text and multi-modal KG builders.
    """

    def __init__(
        self, working_dir: str = "cache", graph_backend: str = "kuzu", **build_kwargs
    ):
        """Set up the LLM client, the graph storage and the build options.

        Args:
            working_dir: Directory handed to the base operator and to the
                graph storage initializer.
            graph_backend: Backend name forwarded to ``init_storage``.
            **build_kwargs: Extra build options, kept on ``self.build_kwargs``.
                Only ``max_loop`` (default 3, coerced to ``int``) is read here.
        """
        super().__init__(working_dir=working_dir, op_name="build_kg_service")
        self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
        self.graph_storage: BaseGraphStorage = init_storage(
            backend=graph_backend,
            working_dir=working_dir,
            namespace="graph",
        )
        self.build_kwargs = build_kwargs
        # Forwarded to build_text_kg only; presumably bounds its extraction
        # rounds -- TODO confirm against build_text_kg.
        self.max_loop: int = int(build_kwargs.get("max_loop", 3))

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Build the KG from one batch of chunk rows.

        Each row is converted to a ``Chunk`` through ``Chunk.from_dict``, keyed
        by the row's ``"_chunk_id"`` value, and the resulting list is passed to
        ``build_kg``.

        Args:
            batch: Rows describing chunks; every row must carry ``"_chunk_id"``.

        Returns:
            A one-row DataFrame with a ``status`` column set to
            ``"kg_building_completed"``.
        """
        records = batch.to_dict(orient="records")
        parsed_chunks = [
            Chunk.from_dict(record["_chunk_id"], record) for record in records
        ]

        # The chunks are consumed here; only a status marker is returned.
        self.build_kg(parsed_chunks)
        completion_row = {"status": "kg_building_completed"}
        return pd.DataFrame([completion_row])

    def build_kg(self, chunks: List[Chunk]) -> None:
        """Run entity/relation extraction for text and multi-modal chunks.

        Chunks whose ``type`` is ``"text"`` go to ``build_text_kg``; chunks
        whose ``type`` is ``"image"``, ``"video"``, ``"table"`` or ``"formula"``
        go to ``build_mm_kg``. Chunks of any other type are not processed.
        Both builders receive ``self.graph_storage`` as ``kg_instance``.

        Args:
            chunks: Chunks to extract from.
        """
        multimodal_kinds = ("image", "video", "table", "formula")
        text_chunks = [item for item in chunks if item.type == "text"]
        mm_chunks = [item for item in chunks if item.type in multimodal_kinds]

        if text_chunks:
            logger.info("[Text Entity and Relation Extraction] processing ...")
            build_text_kg(
                llm_client=self.llm_client,
                kg_instance=self.graph_storage,
                chunks=text_chunks,
                max_loop=self.max_loop,
            )
        else:
            # Emitted whenever this call received no text chunks.
            logger.info("All text chunks are already in the storage")

        if mm_chunks:
            logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
            build_mm_kg(
                llm_client=self.llm_client,
                kg_instance=self.graph_storage,
                chunks=mm_chunks,
            )
        else:
            # Emitted whenever this call received no multi-modal chunks.
            logger.info("All multi-modal chunks are already in the storage")

        # Always invoked, even when nothing was extracted; presumably lets the
        # storage finalize/persist its index -- TODO confirm in BaseGraphStorage.
        self.graph_storage.index_done_callback()