llm_invocation_stage

LLM invocation stage with concurrency and retry logic.

LLMInvocationStage

LLMInvocationStage(llm_client: LLMClient, concurrency: int = 5, rate_limiter: RateLimiter | None = None, retry_handler: RetryHandler | None = None, error_policy: ErrorPolicy = ErrorPolicy.SKIP, max_retries: int = 3)

Bases: PipelineStage[list[PromptBatch], list[ResponseBatch]]

Invoke the LLM for batches of prompts, using concurrency and retries.

Responsibilities:

- Execute LLM calls with rate limiting
- Handle retries for transient failures
- Track tokens and costs
- Support concurrent processing

Initialize LLM invocation stage.

Parameters:

Name            Type                  Description                    Default
llm_client      LLMClient             LLM client instance            required
concurrency     int                   Max concurrent requests        5
rate_limiter    RateLimiter | None    Optional rate limiter          None
retry_handler   RetryHandler | None   Optional retry handler         None
error_policy    ErrorPolicy           Policy for handling errors     SKIP
max_retries     int                   Maximum retry attempts         3
Source code in ondine/stages/llm_invocation_stage.py
def __init__(
    self,
    llm_client: LLMClient,
    concurrency: int = 5,
    rate_limiter: RateLimiter | None = None,
    retry_handler: RetryHandler | None = None,
    error_policy: ErrorPolicy = ErrorPolicy.SKIP,
    max_retries: int = 3,
):
    """
    Initialize LLM invocation stage.

    Args:
        llm_client: LLM client instance
        concurrency: Max concurrent requests
        rate_limiter: Optional rate limiter
        retry_handler: Optional retry handler
        error_policy: Policy for handling errors
        max_retries: Maximum retry attempts
    """
    super().__init__("LLMInvocation")
    self.llm_client = llm_client
    self.concurrency = concurrency
    self.rate_limiter = rate_limiter
    self.retry_handler = retry_handler or RetryHandler()
    self.error_handler = ErrorHandler(
        policy=error_policy,
        max_retries=max_retries,
        default_value_factory=lambda: LLMResponse(
            text="",
            tokens_in=0,
            tokens_out=0,
            model=llm_client.model,
            cost=Decimal("0.0"),
            latency_ms=0.0,
        ),
    )
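
For orientation, a minimal construction sketch follows. Only LLMInvocationStage's module path is documented on this page; `client` stands in for an already configured LLMClient implementation, and the keyword values are illustrative.

from ondine.stages.llm_invocation_stage import LLMInvocationStage

# `client` is a placeholder for an existing LLMClient instance; its concrete
# class and import path are not documented on this page.
stage = LLMInvocationStage(
    llm_client=client,
    concurrency=10,   # up to 10 in-flight requests
    max_retries=3,    # retries for transient failures
)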

process

process(batches: list[PromptBatch], context: Any) -> list[ResponseBatch]

Execute LLM calls for all prompt batches.

Source code in ondine/stages/llm_invocation_stage.py
def process(self, batches: list[PromptBatch], context: Any) -> list[ResponseBatch]:
    """Execute LLM calls for all prompt batches."""
    response_batches: list[ResponseBatch] = []

    for _batch_idx, batch in enumerate(batches):
        self.logger.info(
            f"Processing batch {batch.batch_id} ({len(batch.prompts)} prompts)"
        )

        # Process batch with concurrency
        responses = self._process_batch_concurrent(batch.prompts, context)

        # Notify progress after each batch
        if hasattr(context, "notify_progress"):
            context.notify_progress()

        # Calculate batch metrics
        total_tokens = sum(r.tokens_in + r.tokens_out for r in responses)
        total_cost = sum(r.cost for r in responses)
        latencies = [r.latency_ms for r in responses]

        # Create response batch
        response_batch = ResponseBatch(
            responses=[r.text for r in responses],
            metadata=batch.metadata,
            tokens_used=total_tokens,
            cost=total_cost,
            batch_id=batch.batch_id,
            latencies_ms=latencies,
        )
        response_batches.append(response_batch)

        # Update context
        context.add_cost(total_cost, total_tokens)
        context.update_row(batch.metadata[-1].row_index if batch.metadata else 0)

    return response_batches
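
The per-batch helper _process_batch_concurrent is called above but its source is not shown on this page. Below is a minimal sketch of a bounded, order-preserving fan-out; the acquire(), execute(), and invoke() calls are assumed interfaces of RateLimiter, RetryHandler, and LLMClient, not confirmed signatures.

from concurrent.futures import ThreadPoolExecutor

def process_batch_concurrent_sketch(stage, prompts):
    """Hypothetical stand-in for LLMInvocationStage._process_batch_concurrent."""

    def invoke(prompt):
        if stage.rate_limiter is not None:
            stage.rate_limiter.acquire()  # assumed RateLimiter interface
        # assumed RetryHandler.execute and LLMClient.invoke signatures
        return stage.retry_handler.execute(lambda: stage.llm_client.invoke(prompt))

    with ThreadPoolExecutor(max_workers=stage.concurrency) as pool:
        # Executor.map preserves input order, keeping responses aligned
        # with the batch metadata consumed in process() above.
        return list(pool.map(invoke, prompts))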

validate_input

validate_input(batches: list[PromptBatch]) -> ValidationResult

Validate prompt batches.

Source code in ondine/stages/llm_invocation_stage.py
def validate_input(self, batches: list[PromptBatch]) -> ValidationResult:
    """Validate prompt batches."""
    result = ValidationResult(is_valid=True)

    if not batches:
        result.add_error("No prompt batches provided")

    for batch in batches:
        if not batch.prompts:
            result.add_error(f"Batch {batch.batch_id} has no prompts")

        if len(batch.prompts) != len(batch.metadata):
            result.add_error(f"Batch {batch.batch_id} prompt/metadata mismatch")

    return result
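
A typical pre-flight use of this check, as a sketch (`stage` and `batches` as in the earlier examples):

result = stage.validate_input(batches)
if not result.is_valid:
    # Only is_valid and add_error are visible in the source above; inspect the
    # ValidationResult in your installation for the recorded error messages.
    raise ValueError("Prompt batches failed validation")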

estimate_cost

estimate_cost(batches: list[PromptBatch]) -> CostEstimate

Estimate LLM invocation cost.

Source code in ondine/stages/llm_invocation_stage.py
def estimate_cost(self, batches: list[PromptBatch]) -> CostEstimate:
    """Estimate LLM invocation cost."""
    total_input_tokens = 0
    total_output_tokens = 0

    # Estimate tokens for all prompts
    for batch in batches:
        for prompt in batch.prompts:
            input_tokens = self.llm_client.estimate_tokens(prompt)
            total_input_tokens += input_tokens

            # Assume average output length (can be made configurable)
            estimated_output = int(input_tokens * 0.5)
            total_output_tokens += estimated_output

    total_cost = self.llm_client.calculate_cost(
        total_input_tokens, total_output_tokens
    )

    return CostEstimate(
        total_cost=total_cost,
        total_tokens=total_input_tokens + total_output_tokens,
        input_tokens=total_input_tokens,
        output_tokens=total_output_tokens,
        rows=sum(len(b.prompts) for b in batches),
        confidence="estimate",
    )
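
The estimate can serve as a budget guard before processing. A sketch using the CostEstimate fields populated above (`stage` and `batches` as before; the threshold is illustrative):

from decimal import Decimal

estimate = stage.estimate_cost(batches)
print(
    f"~{estimate.total_tokens} tokens over {estimate.rows} rows, "
    f"cost {estimate.total_cost} ({estimate.confidence})"
)
if estimate.total_cost > Decimal("5.00"):  # illustrative budget threshold
    raise RuntimeError("Estimated cost exceeds budget")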