[BUG] Overhaul MySQLSearchTool: Fixing Config, Embedchain Integration, and Robustness

Open • mouramax opened this issue 9 months ago • 0 comments

Description

The current MySQLSearchTool is kinda broken and needs a serious look. Its config handling doesn't align with the config format Embedchain's MySQLLoader expects. Its database URI parsing breaks on basic cases, especially special characters like # in passwords. Its Pydantic setup leads to validation errors and general fragility. Together, these issues prevent the tool from being used reliably out of the box.
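The issue doesn't pin down exactly which parser the current tool uses. As a hedged illustration only, here is why a raw # is ambiguous for any RFC 3986-style parser. Python's standard `urllib.parse.urlsplit` treats # as the start of a fragment, so a password containing # cuts the netloc short and the real host disappears. Percent-encoding it as %23 avoids that:

```python
from urllib.parse import unquote, urlsplit

# A password containing a raw '#', as suggested in the reproduction steps.
raw = urlsplit("mysql://user:pa#ss@localhost:3306/pets")
print(raw.hostname)  # -> user  (the netloc was cut off at '#')
print(raw.username)  # -> None  (no '@' survived in the netloc)
print(raw.fragment)  # -> ss@localhost:3306/pets

# Percent-encoding '#' as %23 keeps the URI unambiguous.
encoded = urlsplit("mysql://user:pa%23ss@localhost:3306/pets")
print(encoded.hostname, encoded.port, encoded.path)  # -> localhost 3306 /pets
print(unquote(encoded.password))  # -> pa#ss
```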

Steps to Reproduce

Assuming you've got a standard MySQL setup running (like the pets DB from the MySQL Getting Started tutorial with a populated cats table), try running this snippet with the current crewai-tools version. Make sure to replace USER:PASSWORD with valid credentials, potentially including a special character in the password (like #) to expose the parsing issue.

from crewai_tools import MySQLSearchTool
import os

# Satisfy both LiteLLM and Embedchain
os.environ["GEMINI_API_KEY"] = "YOUR_KEY"
os.environ["GOOGLE_API_KEY"] = os.environ["GEMINI_API_KEY"]

embedchain_config = {
    "embedder": {
        "provider": "google",
        "config": {
            "model": "models/text-embedding-004",
            "task_type": "RETRIEVAL_DOCUMENT"
        }
    }
}

rag_tool = MySQLSearchTool(
    config=embedchain_config,
    db_uri="mysql://USER:PASSWORD@localhost:3306/pets",
    table_name="cats"
)

#
# 1 - Test if `RagTool.run()` works standalone
#

user_question = "Who owns Cookie?"

relevant_chunks = rag_tool.run(user_question)

print("--- RagTool.run() Result ---")
print(relevant_chunks)
print("----------------------------")
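One possible stopgap is to percent-encode the credentials when building `db_uri`. This is an assumption: it only helps if the tool URL-decodes the user and password it extracts, and that handling is part of what this issue questions. `build_mysql_uri` is a hypothetical helper, not part of crewai-tools:

```python
from urllib.parse import quote


def build_mysql_uri(user: str, password: str, host: str, port: int, database: str) -> str:
    """Hypothetical helper: build a mysql:// URI with percent-encoded credentials."""
    # safe="" forces reserved characters such as '/', ':', '@' and '#' to be encoded.
    return (
        f"mysql://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{database}"
    )


db_uri = build_mysql_uri("USER", "pa#ss@1", "localhost", 3306, "pets")
print(db_uri)  # -> mysql://USER:pa%23ss%401@localhost:3306/pets
```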

Expected behavior

The tool should initialize without throwing ValidationErrors or other exceptions. It needs to:

  1. Correctly parse the db_uri, reliably handling standard URI components and common special characters (like #, @) within the password field without requiring manual pre-encoding.
  2. Use the parsed credentials and DB info to configure an embedchain.loaders.mysql.MySQLLoader instance correctly (passing a config dictionary rather than a url argument).
  3. Successfully use this loader via the EmbedchainAdapter (the default RagTool adapter) to load data from the specified table_name using a SELECT * FROM ... query during or shortly after initialization.
  4. Be ready to execute semantic search queries via its run() method against the loaded data.
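For step 3, backticks alone don't make an interpolated table name safe. MySQL allows a backtick inside a backtick-quoted identifier only when it is doubled. Here is a minimal sketch of quoting the name before building the SELECT * FROM ... query. `quote_mysql_identifier` and `select_all_query` are hypothetical helpers:

```python
def quote_mysql_identifier(name: str) -> str:
    """Hypothetical helper: wrap a MySQL identifier in backticks, doubling embedded ones."""
    if not name:
        raise ValueError("Identifier must be a non-empty string.")
    return "`" + name.replace("`", "``") + "`"


def select_all_query(table_name: str) -> str:
    # Placeholders bind values, not identifiers, so the table name is quoted instead.
    return f"SELECT * FROM {quote_mysql_identifier(table_name)};"


print(select_all_query("cats"))  # -> SELECT * FROM `cats`;
```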

Screenshots/Code snippets

A screenshot of the error is in the "Evidence" section.

Operating System

Ubuntu 24.04

Python Version

3.12

crewAI Version

0.114.0

crewAI Tools Version

0.40.1

Virtual Environment

Venv

Evidence

[Screenshot of the error]

Possible Solution

Here is a refactoring suggestion. The goal was to make it robust, align strictly with Pydantic best practices and the Embedchain API, and handle that tricky URI parsing properly using regex. No more weird kwargs propagation either (assuming the parent RagTool is also fixed as discussed elsewhere).

Key improvements in this proposed code:

  • Uses Pydantic Field/PrivateAttr cleanly for config vs state.
  • Validates/parses db_uri with regex to handle special characters robustly.
  • Initializes MySQLLoader the way Embedchain expects (using config dict).
  • Eliminates kwargs issues in the RAG flow (requires parent RagTool fix).
  • Maintains backward compatibility with the db_uri param.
  • Adds specific error messages for better DX.

import re
from typing import Any, Dict, Optional, Type

from embedchain.loaders.mysql import MySQLLoader
from pydantic import (
    BaseModel,
    Field,
    PrivateAttr,
    ValidationInfo,
    field_validator,
)
from pydantic_core import PydanticCustomError

from crewai_tools import RagTool # Original: from ..rag.rag_tool import RagTool


class MySQLSearchToolSchema(BaseModel):
    """Input schema for the MySQLSearchTool."""

    search_query: str = Field(..., description="Mandatory semantic search query.")


class MySQLSearchTool(RagTool):
    """
    A tool for performing semantic searches on the content of a specific
    table within a MySQL database.

    Requires a database URI and table name during initialization. Data is
    loaded lazily upon the first search execution.
    """

    # --- Pydantic Field Declarations ---

    name: str = "Search MySQL Database Table Content"
    description: str = "Performs semantic search on a specific MySQL table's content."
    args_schema: Type[BaseModel] = MySQLSearchToolSchema
    db_uri: str = Field(
        ...,
        description=(
            "Mandatory database connection URI. Format: "
            "mysql://[user[:password]@]host[:port]/database."
        ),
    )
    table_name: str = Field(
        ...,
        description="The specific table name to search within the database.",
    )

    # --- Private Attributes ---

    _mysql_loader: Optional[MySQLLoader] = PrivateAttr(default=None)
    _parsed_db_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
    # Flag to track if data for the initial table has been loaded.
    _initial_data_added: bool = PrivateAttr(default=False)

    # --- Validator for Database URI ---

    @field_validator("db_uri")
    @classmethod
    def _validate_db_uri_format(cls, v: str, info: ValidationInfo) -> str:
        """
        Validates the MySQL URI format using a regular expression.

        Args:
            v: The database URI string to validate.
            info: Pydantic validation information (unused here).

        Returns:
            The validated database URI string.

        Raises:
            PydanticCustomError: If the URI format is invalid.
        """
        try:
            cls._parse_uri_to_config(v)
            return v
        except ValueError as e:
            raise PydanticCustomError(
                "value_error",
                "Invalid MySQL URI: {error}. Expected format: "
                "mysql://[user[:password]@]host[:port]/database. "
                "Ensure all parts are present and correctly formatted. "
                "URL-encode special characters if needed.",
                {"error": str(e)},
            ) from e
        except Exception as e:
            # Catch unexpected errors during validation
            raise PydanticCustomError(
                "value_error",
                "An unexpected error occurred validating the MySQL URI: '{uri}'.",
                {"uri": v},
            ) from e

    # --- Post-Initialization Hook ---

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes the MySQL loader after Pydantic validation.

        Parses the validated URI, creates the MySQLLoader instance, and
        updates the tool's description. Defers adding data until the first run.

        Args:
            __context: Pydantic model validation context (unused here).

        Raises:
            RuntimeError: If URI parsing or loader initialization fails.
        """
        try:
            self._parsed_db_config = self._parse_uri_to_config(self.db_uri)
        except ValueError as e:
            # Should not happen if validation passed, but handle defensively.
            raise RuntimeError(
                f"Could not parse database URI '{self.db_uri}' "
                f"during tool initialization: {e}"
            ) from e

        if not self._parsed_db_config:
            # Should be caught by the exception above, but double-check.
            raise RuntimeError("Database configuration parsing failed unexpectedly.")

        # Initialize the Embedchain MySQLLoader
        try:
            self._mysql_loader = MySQLLoader(config=self._parsed_db_config)
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize the underlying MySQLLoader: {e}"
            ) from e

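        # Update description to be specific to the initialized table.
        self.description = (
            f"Performs semantic search on the '{self.table_name}' table "
            f"in the specified MySQL database. Input is the search query."
        )
        # Let parent classes run their own post-init logic (crewAI's BaseTool
        # may use this hook, e.g. to regenerate its description). Pydantic's
        # default BaseModel implementation is a no-op, so this call is safe.
        super().model_post_init(__context)
        # Data loading is deferred to the first _run call.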

    # --- Static Helper Method for URI Parsing (using Regex) ---

    @staticmethod
    def _parse_uri_to_config(db_uri: str) -> Dict[str, Any]:
        """
        Parses a MySQL URI into a config dictionary for MySQLLoader.

        Uses regular expressions for robust parsing, handling optional
        user, password, and port components.

        Args:
            db_uri: The MySQL connection string (e.g.,
                    mysql://user:pass@host:port/db).

        Returns:
            A dictionary with 'host', 'port', 'database', and optionally
            'user' and 'password'.

        Raises:
            ValueError: If the URI format is invalid or missing required parts
                        (host, database).
        """
        # Regex breakdown:
        # ^mysql://              - Anchor to start, match scheme
        # (?:                    - Optional non-capturing group for auth
        #   ([^:/@]+)            - Group 1: Username (no :, /, @)
        #   (?::(.*))?           - Optional non-capturing group for password
        #     :                  -   Literal colon
        #     (.*)               -   Group 2: Password (greedy, so it runs up to
        #                            the LAST @ and may itself contain @ or #)
        #   @                    - Literal @ separator
        # )?                     - End optional auth group
        # ([^:/?#]+)             - Group 3: Host (no :, /, ?, #)
        # (?::(\d+))?            - Optional non-capturing group for port
        #   :                    -   Literal colon
        #   (\d+)                -   Group 4: Port (digits)
        # /                      - Literal / separator
        # ([^?#]+)               - Group 5: Database (no ?, #)
        # (?:[?#].*)?            - Optional non-capturing group for query/fragment
        # $                      - Anchor to end
        pattern = re.compile(
            r"^mysql://"
            r"(?:([^:/@]+)(?::(.*))?@)?"
            r"([^:/?#]+)"
            r"(?::(\d+))?"
            r"/([^?#]+)"
            r"(?:[?#].*)?"
            r"$"
        )
        match = pattern.match(db_uri)

        if not match:
            raise ValueError(
                "URI does not match expected format: "
                "mysql://[user[:password]@]host[:port]/database"
            )

        groups = match.groups()
        username, password, hostname, port_str, database = groups

        if not hostname:
            # Should be caught by regex, but defensive check.
            raise ValueError("Hostname missing in the URI.")
        if not database:
            raise ValueError("Database name missing in the URI path.")

        try:
            port = int(port_str) if port_str else 3306  # Default MySQL port
        except ValueError:
            raise ValueError(f"Invalid port number: '{port_str}'.") from None

        config: Dict[str, Any] = {
            "host": hostname,
            "port": port,
            "database": database,
        }
        if username is not None:
            config["user"] = username
            # Password can be None if only username is provided (e.g., mysql://user@host/db)
            if password is not None:
                config["password"] = password

        return config

    # --- Core RagTool Methods ---

    def add(self, table_name: Optional[str] = None) -> None:
        """
        Adds data from a MySQL table to the RAG adapter.

        Defaults to the table specified during tool initialization if
        `table_name` is not provided.

        Args:
            table_name: The name of the table to load data from.

        Raises:
            ValueError: If no table name can be determined.
            RuntimeError: If the loader/adapter isn't ready or adding fails.
        """
        target_table = table_name or self.table_name
        if not target_table:
            raise ValueError("Table name must be provided during init or to add().")

        if not self._mysql_loader:
            raise RuntimeError("MySQLLoader is not initialized.")
        if not hasattr(self.adapter, "add"):
            adapter_type = type(self.adapter).__name__
            raise RuntimeError(
                f"Configured adapter ('{adapter_type}') lacks 'add' method."
            )
        # Avoid running add with the placeholder adapter.
        if isinstance(self.adapter, RagTool._AdapterPlaceholder):
            raise RuntimeError(
                "RAG adapter placeholder detected. Tool not fully initialized."
            )

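        # Quote the table name with backticks, doubling any embedded backtick
        # so the name cannot break out of the quoted identifier.
        quoted_table = "`" + target_table.replace("`", "``") + "`"
        query = f"SELECT * FROM {quoted_table};"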

        try:
            self.adapter.add(query, data_type="mysql", loader=self._mysql_loader)

            # Mark initial data as loaded if this was the default table.
            if target_table == self.table_name:
                self._initial_data_added = True
        except NotImplementedError:
            # Should be caught by hasattr, but handle explicit raise.
            adapter_type = type(self.adapter).__name__
            raise RuntimeError(
                f"Adapter '{adapter_type}' claims 'add' but did not implement it."
            ) from None
        except Exception as e:
            raise RuntimeError(
                f"Failed to add data from table '{target_table}': {e}"
            ) from e

    def _run(self, search_query: str) -> str:
        """
        Executes the semantic search query against the configured table.

        Loads data lazily on the first call if it hasn't been loaded yet.

        Args:
            search_query: The semantic query string.

        Returns:
            A string containing the relevant search results, or an
            error message if the search fails.

        Raises:
            RuntimeError: If lazy loading of initial data fails.
            NotImplementedError: If the adapter's query method isn't implemented.
        """
        # --- Lazy Loading ---
        if not self._initial_data_added:
            try:
                # Load data for the table configured at initialization.
                self.add()
            except Exception as e:
                # If lazy loading fails, the tool cannot proceed reliably.
                raise RuntimeError(
                    f"Failed to automatically load initial data for table "
                    f"'{self.table_name}' before query execution: {e}"
                ) from e

        try:
            # Delegate to RagTool._run, which queries the adapter and (in current
            # crewai-tools) already prefixes its result with "Relevant Content:",
            # so the result is returned as-is to avoid a duplicated prefix.
            return super()._run(query=search_query)
        except NotImplementedError:
            adapter_type = type(self.adapter).__name__
            raise NotImplementedError(
                f"The configured RAG adapter ('{adapter_type}') does not "
                f"implement the required 'query' method."
            ) from None
        except Exception as e:
            # Provide a user-friendly error message for search failures.
            return (
                f"Error executing search query '{search_query}' on table "
                f"'{self.table_name}'. Failed to retrieve results. "
                f"Details: {type(e).__name__}"  # Avoid leaking full error details
            )
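To sanity-check the regex parsing approach outside the tool, here is a standalone sketch of the same idea. The password group is greedy (`.*`), so it runs up to the last @. A `[^@]*` password group would reject an @ inside the password. `parse_mysql_uri` is a hypothetical name, and the returned dict only mirrors the config shape described above:

```python
import re
from typing import Any, Dict

# Greedy password group: backtracks to the LAST '@' before the host.
_MYSQL_URI = re.compile(
    r"^mysql://"
    r"(?:([^:/@]+)(?::(.*))?@)?"  # optional user[:password]@
    r"([^:/?#]+)"                 # host
    r"(?::(\d+))?"                # optional port
    r"/([^?#]+)"                  # database
    r"(?:[?#].*)?$"               # optional query/fragment
)


def parse_mysql_uri(db_uri: str) -> Dict[str, Any]:
    """Hypothetical standalone parser returning a MySQLLoader-style config dict."""
    match = _MYSQL_URI.match(db_uri)
    if not match:
        raise ValueError(f"Invalid MySQL URI: {db_uri!r}")
    user, password, host, port, database = match.groups()
    config: Dict[str, Any] = {"host": host, "port": int(port or 3306), "database": database}
    if user is not None:
        config["user"] = user
        if password is not None:
            config["password"] = password
    return config


print(parse_mysql_uri("mysql://root:p@ss#1@localhost:3306/pets"))
# -> {'host': 'localhost', 'port': 3306, 'database': 'pets', 'user': 'root', 'password': 'p@ss#1'}
print(parse_mysql_uri("mysql://localhost/pets"))
# -> {'host': 'localhost', 'port': 3306, 'database': 'pets'}
```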

Additional context

A couple more thoughts:

  1. Pydantic & Verbosity: Yeah, the refactored code is way more verbose than the original attempt. That's the cost of leveraging Pydantic properly for validation and type safety. It leans into what looks like CrewAI's design choice of trading brevity for long-term reliability and maintainability.
  2. Error Messages: Explicitly catching parsing/validation errors and raising PydanticCustomError or RuntimeError with clearer messages is crucial. Pydantic's default errors can sometimes be cryptic, so improving the DX here helps users debug better.
  3. The Bigger RAG Picture (Chunking Tabular Data): This fix addresses the tool's initialization and configuration, but there's a potential downstream issue for all RAG tools handling tabular data (MySQL, CSV, etc.). When you SELECT * and just chunk the resulting rows (e.g., Peter,40,89,Lucy\nSusan,35,11,Buddy), later chunks lose the context of which value belongs to which column (Name, Age, ID, Pet). This severely limits the LLM's ability to answer questions accurately (like "What is Susan's Pet?"). I discussed this regarding the CSVSearchTool here. A potential future enhancement for MySQLSearchTool (and others) could be to transform each fetched database row into a JSON object ({"name": "Peter", "age": 40, ...}) before passing it to the Embedchain chunker. This would preserve the column context within each chunk, likely leading to much better retrieval results. It's probably beyond the scope of this immediate bugfix, but it's definitely worth considering for the next evolution of these tools.
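Here is a minimal sketch of the row-to-JSON idea, assuming a DB-API 2.0 cursor (such as the one `mysql.connector` provides), where `cursor.description` carries the column names. `rows_as_json_lines` is a hypothetical helper. Wiring it into Embedchain's loading path is not shown, and the demo uses an in-memory SQLite table only so it runs without a MySQL server:

```python
import json
import sqlite3
from typing import Any, List


def rows_as_json_lines(cursor: Any) -> List[str]:
    """Hypothetical helper: turn each fetched row into a self-describing JSON object."""
    # Per DB-API 2.0, the first item of each cursor.description entry is the column name.
    columns = [col[0] for col in cursor.description]
    # default=str covers values json can't serialize natively (e.g. dates, Decimals).
    return [json.dumps(dict(zip(columns, row)), default=str) for row in cursor.fetchall()]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE people (name TEXT, age INTEGER, id INTEGER, pet TEXT)")
conn.executemany(
    "INSERT INTO people VALUES (?, ?, ?, ?)",
    [("Peter", 40, 89, "Lucy"), ("Susan", 35, 11, "Buddy")],
)
for line in rows_as_json_lines(conn.execute("SELECT * FROM people ORDER BY rowid")):
    print(line)
# -> {"name": "Peter", "age": 40, "id": 89, "pet": "Lucy"}
# -> {"name": "Susan", "age": 35, "id": 11, "pet": "Buddy"}
```

Each chunk then carries its own column names, so a question like "What is Susan's Pet?" no longer depends on a header row that may have landed in a different chunk.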

mouramax • Apr 25 '25 21:04