railtracks.guardrails.llm
class
LLMGuardrailsMixin:
class LLMGuardrailsMixin:
    """
    Mixin for nodes that invoke an LLM. Overrides _pre_invoke and _post_invoke to run
    input and output guardrails. Set guardrails= when building the node.

    The host node class is expected to provide ``_details``, ``llm_model``,
    ``uuid`` and ``name``; they are declared here as annotations only and are
    not assigned by this mixin.
    """

    # Guard configuration. When None, or when the relevant ``input`` / ``output``
    # rail collection is empty, the corresponding hook is a pass-through.
    guardrails: Guard | None = None
    # Attributes supplied by the host node class.
    _details: dict[str, Any]
    llm_model: ModelBase
    uuid: str
    name: Callable[[], str]

    def _append_guard_traces(self, traces: list[GuardrailTrace]) -> None:
        """Append guardrail traces to ``self._details["guard_details"]``.

        The list is created on first use so a host whose ``_details`` was not
        pre-seeded with the ``"guard_details"`` key does not fail with KeyError.
        Empty ``traces`` are ignored.
        """
        if not traces:
            return
        self._details.setdefault("guard_details", []).extend(traces)

    def _guardrail_agent_kind(self) -> str:
        """Classify this node for the ``agent_kind`` event tag.

        Returns "structured" or "terminal" when that word appears
        (case-insensitively) in the concrete class name, checked in that order;
        otherwise "llm".
        """
        cls_name = self.__class__.__name__.lower()
        if "structured" in cls_name:
            return "structured"
        if "terminal" in cls_name:
            return "terminal"
        return "llm"

    def _read_model_attr(self, attr: str) -> Any:
        """Read ``attr`` from ``self.llm_model`` (None if absent), calling it if callable."""
        value = getattr(self.llm_model, attr, None)
        return value() if callable(value) else value

    def _resolve_model_metadata(self) -> tuple[str | None, str | None]:
        """Return ``(model_name, model_provider)`` taken from ``self.llm_model``.

        Each attribute may be a plain value or a zero-argument callable. Missing
        attributes yield None; any other value is normalised with ``str()`` so
        the declared ``str | None`` return type actually holds.
        """
        model_name = self._read_model_attr("model_name")
        model_provider = self._read_model_attr("model_provider")
        return (
            str(model_name) if model_name is not None else None,
            str(model_provider) if model_provider is not None else None,
        )

    def _build_input_event(self, context: Any) -> LLMGuardrailEvent:
        """Build LLMGuardrailEvent for input phase from context (MessageHistory)."""
        model_name, model_provider = self._resolve_model_metadata()
        return LLMGuardrailEvent(
            phase=LLMGuardrailPhase.INPUT,
            messages=context,
            node_name=self.name(),
            node_uuid=self.uuid,
            model_name=model_name,
            model_provider=model_provider,
            tags={"agent_kind": self._guardrail_agent_kind()},
        )

    def _build_output_event(
        self, context: Any, assistant_message: Message
    ) -> LLMGuardrailEvent:
        """Build LLMGuardrailEvent for output phase: context is message history; assistant_message is this turn's output."""
        model_name, model_provider = self._resolve_model_metadata()
        return LLMGuardrailEvent(
            phase=LLMGuardrailPhase.OUTPUT,
            messages=context,
            output_message=assistant_message,
            node_name=self.name(),
            node_uuid=self.uuid,
            model_name=model_name,
            model_provider=model_provider,
            tags={"agent_kind": self._guardrail_agent_kind()},
        )

    @staticmethod
    def _raise_if_blocked(decision: Any, traces: list[GuardrailTrace]) -> None:
        """Raise GuardrailBlockedError if ``decision`` is a BLOCK; otherwise return.

        ``rail_name`` is taken from the last trace. This presumably identifies the
        blocking rail because the runner stops there; confirm against GuardRunner.

        Raises:
            GuardrailBlockedError: when ``decision.action`` is ``GuardrailAction.BLOCK``.
        """
        if decision is None or decision.action != GuardrailAction.BLOCK:
            return
        rail_name = traces[-1].rail_name if traces else None
        raise GuardrailBlockedError(
            rail_name=rail_name,
            reason=decision.reason,
            user_facing_message=decision.user_facing_message,
            traces=traces,
            meta=decision.meta,
        )

    def _pre_invoke(self, context: Any) -> Any:
        """Run input guardrails on ``context`` before the LLM call.

        Returns the context produced by the guard runner (which may be rewritten),
        or ``context`` unchanged when no input guardrails are configured.

        Raises:
            GuardrailBlockedError: if the input guardrails decide to BLOCK.
        """
        if self.guardrails is None or not self.guardrails.input:
            return context
        event = self._build_input_event(context)
        new_context, traces, decision = GuardRunner(self.guardrails).run_llm_input(
            event
        )
        # Traces are recorded before raising so a blocked call is still auditable.
        self._append_guard_traces(traces)
        self._raise_if_blocked(decision, traces)
        return new_context

    def _post_invoke(self, context: Any, result: Any) -> Any:
        """Run output guardrails on the LLM's reply after the call.

        Only ``Response`` results are checked; any other result (or any result
        when no output guardrails are configured) is returned unchanged. For a
        ``Response``, a new ``Response`` is returned carrying the guard runner's
        (possibly rewritten) message and the original ``message_info``.

        Raises:
            GuardrailBlockedError: if the output guardrails decide to BLOCK.
        """
        if self.guardrails is None or not self.guardrails.output:
            return result
        if not isinstance(result, Response):
            return result
        event = self._build_output_event(context, result.message)
        new_message, traces, decision = GuardRunner(self.guardrails).run_llm_output(
            event, result.message
        )
        # Traces are recorded before raising so a blocked call is still auditable.
        self._append_guard_traces(traces)
        self._raise_if_blocked(decision, traces)
        return Response(message=new_message, message_info=result.message_info)
Mixin for nodes that invoke an LLM. It overrides `_pre_invoke` and `_post_invoke` to run input and output guardrails; pass `guardrails=` when building the node to enable them.
llm_model: railtracks.llm.ModelBase