railtracks.guardrails.llm

# Public API of railtracks.guardrails.llm: the input/output guardrail
# submodules and the mixin that wires them into LLM-invoking nodes.
from . import input, output
from .mixin import LLMGuardrailsMixin

__all__ = ["input", "output", "LLMGuardrailsMixin"]
class LLMGuardrailsMixin:
 24class LLMGuardrailsMixin:
 25    """
 26    Mixin for nodes that invoke an LLM. Overrides _pre_invoke and _post_invoke to run
 27    input and output guardrails. Set guardrails= when building the node.
 28    """
 29
 30    guardrails: Guard | None = None
 31    _details: dict[str, Any]
 32    llm_model: ModelBase
 33    uuid: str
 34    name: Callable[[], str]
 35
 36    def _append_guard_traces(self, traces: list[GuardrailTrace]) -> None:
 37        if not traces:
 38            return
 39        self._details["guard_details"].extend(traces)
 40
 41    def _guardrail_agent_kind(self) -> str:
 42        cls_name = self.__class__.__name__.lower()
 43        if "structured" in cls_name:
 44            return "structured"
 45        if "terminal" in cls_name:
 46            return "terminal"
 47        return "llm"
 48
 49    def _resolve_model_metadata(self) -> tuple[str | None, str | None]:
 50        model_name = getattr(self.llm_model, "model_name", None)
 51        if callable(model_name):
 52            model_name = model_name()
 53        model_provider = getattr(self.llm_model, "model_provider", None)
 54        if callable(model_provider):
 55            model_provider = model_provider()
 56        return (
 57            cast(str | None, model_name),
 58            str(model_provider) if model_provider is not None else None,
 59        )
 60
 61    def _build_input_event(self, context: Any) -> LLMGuardrailEvent:
 62        """Build LLMGuardrailEvent for input phase from context (MessageHistory)."""
 63        model_name, model_provider = self._resolve_model_metadata()
 64        return LLMGuardrailEvent(
 65            phase=LLMGuardrailPhase.INPUT,
 66            messages=context,
 67            node_name=self.name(),
 68            node_uuid=self.uuid,
 69            model_name=model_name,
 70            model_provider=model_provider,
 71            tags={"agent_kind": self._guardrail_agent_kind()},
 72        )
 73
 74    def _build_output_event(
 75        self, context: Any, assistant_message: Message
 76    ) -> LLMGuardrailEvent:
 77        """Build LLMGuardrailEvent for output phase: context is message history; assistant_message is this turn's output."""
 78        model_name, model_provider = self._resolve_model_metadata()
 79        return LLMGuardrailEvent(
 80            phase=LLMGuardrailPhase.OUTPUT,
 81            messages=context,
 82            output_message=assistant_message,
 83            node_name=self.name(),
 84            node_uuid=self.uuid,
 85            model_name=model_name,
 86            model_provider=model_provider,
 87            tags={"agent_kind": self._guardrail_agent_kind()},
 88        )
 89
 90    def _pre_invoke(self, context: Any) -> Any:
 91        if self.guardrails is None or not self.guardrails.input:
 92            return context
 93        event = self._build_input_event(context)
 94        new_context, traces, decision = GuardRunner(self.guardrails).run_llm_input(
 95            event
 96        )
 97        self._append_guard_traces(traces)
 98        if decision is not None and decision.action == GuardrailAction.BLOCK:
 99            rail_name = traces[-1].rail_name if traces else None
100            raise GuardrailBlockedError(
101                rail_name=rail_name,
102                reason=decision.reason,
103                user_facing_message=decision.user_facing_message,
104                traces=traces,
105                meta=decision.meta,
106            )
107
108        return new_context
109
110    def _post_invoke(self, context: Any, result: Any) -> Any:
111        if self.guardrails is None or not self.guardrails.output:
112            return result
113        if not isinstance(result, Response):
114            return result
115        event = self._build_output_event(context, result.message)
116        new_message, traces, decision = GuardRunner(self.guardrails).run_llm_output(
117            event, result.message
118        )
119        self._append_guard_traces(traces)
120        if decision is not None and decision.action == GuardrailAction.BLOCK:
121            rail_name = traces[-1].rail_name if traces else None
122            raise GuardrailBlockedError(
123                rail_name=rail_name,
124                reason=decision.reason,
125                user_facing_message=decision.user_facing_message,
126                traces=traces,
127                meta=decision.meta,
128            )
129
130        return Response(message=new_message, message_info=result.message_info)

Mixin for nodes that invoke an LLM. It overrides `_pre_invoke` and `_post_invoke` to run the input and output guardrails, respectively. Set `guardrails=` when building the node.

Class attributes:

- `guardrails: railtracks.guardrails.Guard | None = None` — the guardrail configuration to apply; `None` disables all checks.
- `uuid: str` — the node's unique identifier.
- `name: Callable[[], str]` — a callable that returns the node's name.