Skip to content
29 changes: 24 additions & 5 deletions python/flink_agents/api/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@
# limitations under the License.
#################################################################################
from abc import ABC
from typing import Any
from typing import Any, Dict

try:
from typing import override
except ImportError:
from typing_extensions import override
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, model_validator
from pydantic_core import PydanticSerializationError
from pyflink.common import Row


Expand All @@ -35,13 +41,26 @@ class Event(BaseModel, ABC, extra="allow"):

id: UUID = Field(default_factory=uuid4)

@staticmethod
def __serialize_unknown(field: Any) -> Dict[str, Any]:
"""Handle serialization of unknown types, specifically Row objects."""
if isinstance(field, Row):
return {"type": "Row", "values": field._values}
else:
err_msg = f"Unable to serialize unknown type: {field.__class__}"
raise PydanticSerializationError(err_msg)

@override
def model_dump_json(self, **kwargs: Any) -> str:
"""Override model_dump_json to handle Row objects using fallback."""
# Set fallback if not provided in kwargs
if 'fallback' not in kwargs:
kwargs['fallback'] = self.__serialize_unknown
return super().model_dump_json(**kwargs)

@model_validator(mode="after")
def validate_extra(self) -> "Event":
"""Ensure init fields is serializable."""
# TODO: support Event contains Row field be json serializable
for value in self.model_dump().values():
if isinstance(value, Row):
return self
self.model_dump_json()
return self

Expand Down
77 changes: 76 additions & 1 deletion python/flink_agents/api/tests/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
from typing import Type
from typing import Any, Type

import pytest
from pydantic import ValidationError
Expand Down Expand Up @@ -47,3 +47,78 @@ def test_event_setattr_non_serializable() -> None: # noqa D103

def test_input_event_ignore_row_unserializable() -> None: # noqa D103
InputEvent(input=Row({"a": 1}))


def test_event_row_with_non_serializable_fails() -> None: # noqa D103
with pytest.raises(ValidationError):
Event(row_field=Row({"a": 1}), non_serializable_field=Type[InputEvent])


def test_event_multiple_rows_serializable() -> None: # noqa D103
Event(row1=Row({"a": 1}), row2=Row({"b": 2}), normal_field="test")


def test_event_setattr_row_serializable() -> None: # noqa D103
event = Event(a=1)
event.row_field = Row({"key": "value"})


def test_event_json_serialization_with_row() -> None: # noqa D103
event = InputEvent(input=Row({"test": "data"}))
json_str = event.model_dump_json()
assert "test" in json_str
assert "Row" in json_str


def test_efficient_row_serialization_with_fallback() -> None:
"""Test that the new fallback-based serialization works efficiently."""
row_data = {"a": 1, "b": "test", "c": [1, 2, 3]}
event = InputEvent(input=Row(row_data))

json_str = event.model_dump_json()
import json
parsed = json.loads(json_str)

assert parsed["input"]["type"] == "Row"
assert parsed["input"]["values"] == [row_data]
assert "id" in parsed # UUID should be present

def custom_fallback(obj: Any) -> dict[str, Any]:
if isinstance(obj, Row):
return {"custom_type": "CustomRow", "data": obj._values}
msg = "Unknown type"
raise ValueError(msg)

custom_json = event.model_dump_json(fallback=custom_fallback)
custom_parsed = json.loads(custom_json)

assert custom_parsed["input"]["custom_type"] == "CustomRow"
assert custom_parsed["input"]["data"] == [row_data]


def test_event_with_mixed_serializable_types() -> None:
"""Test event with mix of normal and Row types."""
event = InputEvent(input={
"normal_data": {"key": "value"},
"row_data": Row({"test": "data"}),
"list_data": [1, 2, 3],
"nested_row": {"inner": Row({"nested": True})}
})

json_str = event.model_dump_json()

import json
parsed = json.loads(json_str)

# Normal data should be serialized normally
assert parsed["input"]["normal_data"]["key"] == "value"
assert parsed["input"]["list_data"] == [1, 2, 3]

# Row data should use fallback serializer
assert parsed["input"]["row_data"]["type"] == "Row"
assert parsed["input"]["nested_row"]["inner"]["type"] == "Row"


def test_input_event_ignore_row_unserializable() -> None: # noqa D103
InputEvent(input=Row({"a": 1}))

Loading