Fix forward references and circular dependencies in Pydantic models

Open rjurney opened this pull request 7 months ago • 4 comments

Summary

  • Fixes issue #798 where BAML-generated models with forward references fail to convert to PySpark schemas
  • Adds proper handling for circular references to prevent infinite recursion
  • Ensures enums in nested models are properly converted to string types

Changes

  1. Forward Reference Resolution: Added model_rebuild() call to resolve forward references before processing
  2. ForwardRef Type Handling: Explicitly handle ForwardRef types, defaulting them to string type when unresolved
  3. Type Safety: Added inspect.isclass() check before issubclass() to prevent TypeError with non-class types
  4. Circular Reference Detection: Implemented visited models tracking to prevent infinite recursion in self-referential models
  5. Specific Exception Handling: Catch specific Pydantic exceptions (PydanticUndefinedAnnotation, PydanticSchemaGenerationError) instead of generic Exception (see the combined sketch of these changes below)
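
For orientation, a minimal sketch of how these changes could fit together (illustrative only, assuming pydantic v2; _json_type, the visited set, and the bare "string" fallback are hypothetical names, not sparkdantic's internals):

import inspect
from typing import ForwardRef

from pydantic import (
    BaseModel,
    PydanticSchemaGenerationError,
    PydanticUndefinedAnnotation,
)

def _json_type(annotation, visited=frozenset()):
    if isinstance(annotation, ForwardRef):
        return "string"  # change 2: unresolved forward refs default to string
    # change 3: guard issubclass() against non-class annotations (e.g. Optional[...])
    if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
        if annotation in visited:
            return "string"  # change 4: self-referential model; stop recursing
        try:
            annotation.model_rebuild()  # change 1: resolve forward refs up front
        except (PydanticUndefinedAnnotation, PydanticSchemaGenerationError):
            pass  # change 5: catch specific errors; leftovers hit the ForwardRef branch
        return {
            name: _json_type(field.annotation, visited | {annotation})
            for name, field in annotation.model_fields.items()
        }
    return "string"  # placeholder: real code maps scalars, enums, containers, etc.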

Test plan

Added a comprehensive test suite (tests/test_forward_references.py) covering:

  • [x] Models with forward references
  • [x] Enums in nested models
  • [x] Undefined forward references
  • [x] Circular references (self-referential models)

Also included an integration test (test_forward_ref_fix.py) demonstrating that the fix works with real BAML-generated Company/Ticker models.

All tests pass successfully.

Related Issues

Fixes #798

rjurney avatar Sep 11 '25 21:09 rjurney

Thanks for identifying an issue and raising a PR to address it, @rjurney!

There appear to be three main changes in this PR:

  1. Resolve forward references before schema generation
  2. Fix recursive (self-referencing) models
  3. Fix string Enum

On 1, if a model's forward references can be resolved via BaseModel.model_rebuild, is this change required? If a client expects forward refs, can they just call that method before creating the Spark schema? Correct me if I'm wrong, but calling model_rebuild alone solves the original problem in the issue you raised.

On 2 (related to 1 due to forward references), unbounded recursive models don't make much sense to me w.r.t. generating structured Spark schemas. A recursive model could either map to an unstructured type (i.e. StringType, I like your choice here 👍) or a semi-structured type, e.g. VariantType. I'm reluctant to handle this scenario by extending create_json_spark_schema with a private function parameter; happy to discuss, design, and reach agreement on it in a separate issue. Additionally, do you currently face this (self-reference) problem in your original issue?
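
For illustration, this is the kind of self-referential model in question (a hypothetical example, not from the PR); expanding it into a StructType would nest forever, so only an unstructured or semi-structured type gives a finite schema:

from typing import Optional
from pydantic import BaseModel

class Node(BaseModel):
    value: str
    # Each level nests another Node, so a fully structured schema never terminates.
    next: Optional["Node"] = None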

On 3, I’m not sure if there was a problem with the existing code. After rebuilding the model in your original issue, did the string Enum field still fail schema generation?

mitchstockdale avatar Sep 14 '25 03:09 mitchstockdale

For one thing, you have to instantiate an object and call model_rebuild on it, which isn't the published class-based API and will trip many people up. The recursive stuff I can remove; Claude did that. As to the Enum: yes, the string name of the class was the issue with the string enum. It wasn't pulling in the class.

rjurney avatar Sep 14 '25 03:09 rjurney

The recursive stuff I can remove, Claude did that.

Thanks 👍

For one thing, you have to instantiate an object and call model_rebuild on it, which isn't the published class-based API and will trip many people up... As to the Enum: yes, the string name of the class was the issue with the string enum. It wasn't pulling in the class.

I see, it's because the Enum is also a forward reference (not quite what the PR summary suggests). Thanks for clarifying the issue.

BaseModel.model_rebuild is a classmethod, is quite well documented, and tries to rebuild the schema. It doesn't require model instantiation. Using a similar example to yours, the workaround for this is:

from enum import StrEnum
from typing import Optional
from pydantic import BaseModel
from sparkdantic import SparkModel

class Parent(BaseModel):
    child: Optional["Child"] = None

class Child(BaseModel):
    bar: Optional["Bar"] = None

class Bar(StrEnum):
    FOO = "FOO"

class SparkParent(Parent, SparkModel):
    pass

Child.model_rebuild()
SparkParent.model_rebuild(force=True)
SparkParent.model_json_spark_schema()
# {'type': 'struct', 'fields': [{'name': 'child', 'type': {'type': 'struct', 'fields': [{'name': 'bar', 'type': 'string', 'nullable': True, 'metadata': {}}]}, 'nullable': True, 'metadata': {}}]}

Note:

  • Unfortunately, it doesn't seem possible to resolve forward references for nested models in one go (i.e. you can't just call model_rebuild on SparkParent and skip Child).
  • If Bar and Child were defined in the right order, model_rebuild is not required (see the sketch below).
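
For comparison, the same models declared bottom-up need no string annotations and no rebuild calls (same sparkdantic API as in the example above):

from enum import StrEnum
from typing import Optional
from pydantic import BaseModel
from sparkdantic import SparkModel

class Bar(StrEnum):
    FOO = "FOO"

class Child(BaseModel):
    bar: Optional[Bar] = None  # direct reference instead of a forward ref

class Parent(BaseModel):
    child: Optional[Child] = None

class SparkParent(Parent, SparkModel):
    pass

SparkParent.model_json_spark_schema()  # no model_rebuild required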

I'm aware that this might require a complex workaround if the hierarchy of your models is unknown at runtime, which would be the case for generated models.

My strong preference would be that, given that model_rebuild is already provided by pydantic's BaseModel, we should encourage its use. If a ForwardRef field is encountered, we could instead raise an exception whose message:

  • includes the model name; and
  • includes the field name; and
  • suggests the client call model_rebuild to fix the forward references before Spark schema generation; or
  • suggests the client use a spark_type override (sketched below).
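
For instance, the raised error might look roughly like this (a hypothetical sketch of the proposal; the helper name and message are illustrative, not sparkdantic code):

def _raise_unresolved_forward_ref(model: type, field_name: str) -> None:
    # Hypothetical helper sketching the proposed error message.
    raise TypeError(
        f"Field '{field_name}' of model '{model.__name__}' is an unresolved ForwardRef. "
        f"Call {model.__name__}.model_rebuild() before Spark schema generation, "
        "or set a spark_type override for this field."
    )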

I'm reluctant to default ForwardRef fields to string type as is done here. In the above example, the schema would be: {'type': 'struct', 'fields': [{'name': 'child', 'type': 'string', 'nullable': True, 'metadata': {}}]}

I would prefer this to be an explicit option/flag for clients instead, either in the field definition (as an override) or another parameter in create_json_spark_schema.
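
Sketching both shapes (names are hypothetical; neither parameter exists in sparkdantic today):

# Option A: an explicit parameter on schema generation (hypothetical signature).
def create_json_spark_schema(model: type, *, unresolved_forward_refs_as_string: bool = False) -> dict:
    ...

# Option B: a per-field spark_type override in the model definition, using
# sparkdantic's existing override mechanism.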

mitchstockdale avatar Sep 15 '25 23:09 mitchstockdale

Let me double-check, but the only way I could use the model at all was to subclass both my model and SparkModel, instantiate it, and then serialize the schema.

rjurney avatar Sep 16 '25 04:09 rjurney