AIエージェントをREST APIとして提供する:LangChainとLangGraphによる構築入門(第3回:リファクタリング)

本記事では、前回実装した、LangChainとLangGraphをリファクタリングします。
目的は以下の2つです。

  1. 機能ごとに分割し、今後も使えるようにする。
  2. 不要な処理を消す。

結果的に大きな修正にはなりませんでしたので、第2回とあまり変わりがないですが、これでも動くのだという感覚で見て頂ければ幸いです。

1.RestAPI

大きな修正はしていません。

  • インポート文を成形したり、load_dotenvの実行をサーバ起動の直前にしました。
from traceback import print_exception
from dotenv import load_dotenv
from pydantic import BaseModel
from typing import Optional

from fastapi import FastAPI, HTTPException

from search_agent import define_langraph_workflow, AgentState

load_dotenv(dotenv_path='.env')
app = FastAPI()

class Query(BaseModel):
    text: str
    max_length: Optional[int] = 50

@app.post("/generate")
async def generate_text(query: Query):
    """
    テキスト生成エンドポイント。
    """
    try:
        # LangGraphワークフローの定義
        workflow = define_langraph_workflow()

        # ワークフローの実行
        result = workflow.invoke(AgentState(input=query.text))
        print(result)  # AgentStateインスタンス

        # 結果の整形
        summary = result['summary']  # 最終結果はsummaryに格納されていると仮定
        return {"result": summary}  # 要約結果を返す

    except Exception as e:
        print_exception(e)
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """
    ヘルスチェックエンドポイント。
    """
    return {"status": "ok"}

コードの説明:

  • print(result) : 実運用の時は削除するか、loggerに出力します。動作確認の時には AgentStateインスタンスを確認するため出力しているとDEBUGしやすいです。

2.Agentの処理部分

  • 処理を今後も使えるように、AgentのNodeごとのクラスにしました。
  • 実行方法を統一するためにcallメソッドで実行できるようにしました。

from dataclasses import dataclass
from typing import Optional

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_core.output_parsers import StrOutputParser

from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph

@dataclass
class AgentState:
    """
    LangGraphの状態を定義する。
    """
    input: Optional[str] = None
    keywords: Optional[str] = None
    search_results: Optional[str] = None
    summary: Optional[str] = None

class LLMNode:
    def __init__(self, llm_model="gpt-4o-mini"): # デフォルトをgpt-4o-miniにする
        """
        LLMNodeKeywordMakerの初期化。

        Args:
            llm_model (str): 使用するLLMモデルの名前。
        """
        self.llm = self._select_llm(llm_model)

    def _select_llm(self, llm_model):
        """
        LLMを選択する。
        """
        if llm_model == "gpt-4o-mini":
            # OpenAIのモデルを利用
            try:
                llm = ChatOpenAI(model="gpt-4o", temperature=0.0)
                return llm
            except KeyError:
                raise KeyError("OPENAI_API_KEYが設定されていません。")
        else:
            raise ValueError(f"サポートされていないLLMモデル: {self.llm_model}")

class LLMNodeSearchKeywordMaker(LLMNode):

    def __call__(self, input: str) -> dict:
        """
        入力されたテキストから、WEB検索に適したキーワードを生成する。

        Args:
            input: 入力テキスト。

        Returns:
            keywords (str): WEB検索キーワード。
        """
        prompt = ChatPromptTemplate.from_template("入力されたテキスト: {input_text} から、WEB検索に適したキーワードを3つ生成してください。")
        chain = prompt | self.llm | StrOutputParser()
        keywords = chain.invoke({"input_text": input})
        return {"keywords": keywords}

class LLMNodeResultsMaker(LLMNode):

    def __call__(self, search_results: str) -> dict:
        """
        検索結果を要約する。

        Args:
            search_results: 検索結果。

        Returns:
            summary(str): 要約された結果。
        """
        prompt = ChatPromptTemplate.from_template("以下の検索結果: {search_results} を要約してください。3文以内で。")
        chain = prompt | self.llm | StrOutputParser()
        summary = chain.invoke({"search_results": search_results})
        return {"summary": summary}

class WebNodeSearch:
    def __init__(self):
        self.search = DuckDuckGoSearchAPIWrapper()

    def __call__(self, keywords: str) -> dict:
        """
        WEB検索キーワードを用いて、DuckDuckGoで検索を行う。

        Args:
            keywords: WEB検索キーワード。

        Returns:
            search_results(str): 検索結果。
        """
        results = self.search.run(keywords)
        return {"search_results": results}

def define_langraph_workflow() -> CompiledStateGraph:
    """
    LangGraphのワークフローを定義する。

    Returns:
        StateGraph: LangGraphのワークフロー。
    """
    builder = StateGraph(AgentState)

    builder.add_node("create_web_search_keywords", LLMNodeSearchKeywordMaker("gpt-4o-mini"))
    builder.add_node("web_search", WebNodeSearch())
    builder.add_node("summarize_results", LLMNodeResultsMaker("gpt-4o-mini"))

    builder.set_entry_point("create_web_search_keywords")

    builder.add_edge("create_web_search_keywords", "web_search")
    builder.add_edge("web_search", "summarize_results")
    builder.add_edge("summarize_results", END)

    graph = builder.compile()
    return graph

コードの説明:

  • Nodeになるクラスごとに処理を分けました。
  • 実行はcallメソッドに統一しました。
    • LLMNodeSearchKeywordMaker : LLMNodeを親クラスにしてモデルアクセサの作成処理は共通化しました。
    • LLMNodeResultsMaker : LLMNodeを親クラスにしてモデルアクセサの作成処理は共通化しました。
    • WebNodeSearch : Web検索だけのNodeにしました。
  • define_langraph_workflow : 上記のクラス修正に合わせて、Graphへの設定を修正しました。

まとめ

今回はリファクタリングのみを実施しました。少しは奇麗になったと思いますが、処理を付け足していくと更にリファクタリングが必要になりそうです。 初めから奇麗に作ることを心掛けていればいいのですが、ついつい動けばいいやという気持ちで作成してしまうのが反省点です。(変えるつもりはないですが)

コメント

タイトルとURLをコピーしました