| """Bio-Experiment Planning Environment.
|
|
|
| Implements the OpenEnv ``Environment`` interface as a POMDP where the
|
| agent proposes one structured experiment / analysis step at a time and
|
| receives simulated intermediate outputs from a latent biological world.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| from typing import Any, Dict, List, Optional
|
| from uuid import uuid4
|
|
|
| from openenv.core.env_server.interfaces import Environment
|
| from openenv.core.env_server.types import State
|
|
|
| from models import (
|
| ActionType,
|
| ConclusionClaim,
|
| ExperimentAction,
|
| ExperimentObservation,
|
| IntermediateOutput,
|
| PipelineStepRecord,
|
| ResourceUsage,
|
| TaskSpec,
|
| )
|
|
|
| from server.rules.engine import RuleEngine
|
| from server.rewards.reward import RewardBreakdown, RewardComputer
|
| from server.simulator.latent_state import FullLatentState
|
| from server.simulator.noise import NoiseModel
|
| from server.simulator.transition import ACTION_COSTS, TransitionEngine, compute_action_cost
|
| from server.tasks.generator import TaskGenerator
|
|
|
|
|
# Hard cap on agent actions per episode; step() forces ``done`` once
# the episode's step count reaches this value.
MAX_STEPS = 30
|
|
|
|
|
class BioExperimentEnvironment(Environment):
    """POMDP environment for iterative biological experiment planning.

    The agent observes ``ExperimentObservation`` (a partial view) while the
    environment maintains a ``FullLatentState`` (hidden ground truth). Each
    ``step()`` applies one structured experiment / analysis action, advances
    the latent state through the transition engine, and returns simulated
    intermediate outputs plus shaped rewards.
    """

    # Each session holds only per-instance state, so sessions can run
    # concurrently without sharing mutable data.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        scenario_name: Optional[str] = None,
        *,
        domain_randomise: bool = True,
    ) -> None:
        """Initialise engines and empty episode accumulators.

        Args:
            scenario_name: Fixed scenario to use for task generation, or
                ``None`` to let the generator choose.
            domain_randomise: Forwarded to ``TaskGenerator``; controls
                per-episode domain randomisation.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._scenario_name = scenario_name
        self._noise = NoiseModel()
        self._engine = TransitionEngine(self._noise)
        self._rules = RuleEngine()
        self._rewards = RewardComputer()
        self._task_gen = TaskGenerator(domain_randomise=domain_randomise)

        # Episode-scoped accumulators; cleared on every reset().
        self._history: List[PipelineStepRecord] = []
        self._outputs: List[IntermediateOutput] = []
        self._conclusions: List[ConclusionClaim] = []
        self._subagent_outputs: List[Dict[str, Any]] = []
        self._discovered_markers: List[str] = []
        self._candidate_mechanisms: List[str] = []
        self._cumulative_reward: float = 0.0

    def reset(self, seed: Optional[int] = None) -> ExperimentObservation:
        """Start a new episode and return the initial observation.

        Args:
            seed: RNG seed for the noise model and task generator. When
                omitted a fresh random 31-bit seed is drawn; the chosen
                seed is stored on the latent state so the episode stays
                reproducible from it.
        """
        if seed is None:
            # Derive the seed directly from the UUID's random bits
            # rather than via hash(uuid4()) — clearer, and keeps the
            # seed a plain function of the UUID value.
            seed = uuid4().int % (2**31)
        self._noise.reseed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)

        self._task, self._latent = self._task_gen.generate(
            seed=seed,
            scenario_name=self._scenario_name,
        )
        # Record the seed on the latent state for reproducibility.
        self._latent.rng_seed = seed

        self._history.clear()
        self._outputs.clear()
        self._conclusions.clear()
        self._subagent_outputs.clear()
        self._discovered_markers.clear()
        self._candidate_mechanisms.clear()
        self._cumulative_reward = 0.0

        return self._build_observation(reward=0.0, done=False)

    def step(self, action: ExperimentAction) -> ExperimentObservation:
        """Apply one agent action and return the next observation.

        Runs rule checks, advances the hidden latent state, records the
        step in the pipeline history, harvests discoveries/conclusions,
        and combines the per-step reward with the terminal reward when
        the episode ends.

        Raises:
            AssertionError: If called before ``reset()``.
        """
        assert self._latent is not None, "Call reset() before step()"
        assert self._task is not None

        self._state.step_count += 1
        # Deep-copy the hidden state so the step reward can be computed
        # as a delta between pre- and post-transition latent states.
        prev_state = self._latent.model_copy(deep=True)

        violations = self._rules.check(action, self._latent)
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)

        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state

        step_rb = self._rewards.step_reward(
            action, prev_state, self._latent, result.output, hard_v, soft_v,
        )

        cost_budget, cost_time = compute_action_cost(action)
        self._history.append(PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=result.output.summary,
            output_type=result.output.output_type,
            success=result.output.success,
            quality_score=result.output.quality_score,
            resource_cost=cost_budget,
            time_cost_days=cost_time,
        ))
        self._outputs.append(result.output)
        self._update_discoveries(action, result.output)

        # A successful synthesis action may carry structured claim dicts
        # in its parameters; non-dict entries are ignored.
        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION and result.output.success:
            for claim in action.parameters.get("claims", []):
                if isinstance(claim, dict):
                    self._conclusions.append(ConclusionClaim(**claim))

        done = result.done or self._state.step_count >= MAX_STEPS

        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent,
                self._conclusions,
                self._task.success_criteria,
                discovered_markers=self._discovered_markers,
                candidate_mechanisms=self._candidate_mechanisms,
            )

        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward

        # Merge step and terminal components into one flat dict; the
        # "term_" prefix keeps the two sets of keys from colliding.
        breakdown = step_rb.to_dict()
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})

        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
            metadata_extra={"reward_breakdown": breakdown},
        )

    @property
    def state(self) -> State:
        """Public episode bookkeeping (episode id and step count)."""
        return self._state

    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset."""
        self._scenario_name = scenario_name

    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
        metadata_extra: Optional[Dict[str, Any]] = None,
    ) -> ExperimentObservation:
        """Assemble the agent-visible observation from internal state.

        Internal lists are shallow-copied so later mutation inside the
        environment cannot leak into an observation already handed out.
        """
        assert self._task is not None
        assert self._latent is not None
        res = self._latent.resources
        meta: Dict[str, Any] = {
            "episode_id": self._state.episode_id,
            "step": self._state.step_count,
            "cumulative_reward": self._cumulative_reward,
        }
        if metadata_extra:
            meta.update(metadata_extra)
        return ExperimentObservation(
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_assays=list(self._task.available_assays),
            available_tools=list(self._task.available_tools),
            resource_usage=ResourceUsage(
                budget_used=res.budget_used,
                budget_remaining=res.budget_remaining,
                time_used_days=res.time_used_days,
                time_remaining_days=res.time_remaining_days,
                samples_consumed=res.samples_consumed,
                compute_hours_used=res.compute_hours_used,
            ),
            latest_output=latest_output,
            all_outputs=list(self._outputs),
            discovered_markers=list(self._discovered_markers),
            candidate_mechanisms=list(self._candidate_mechanisms),
            uncertainty_summary=self._compute_uncertainty_summary(),
            subagent_outputs=list(self._subagent_outputs),
            conclusions=list(self._conclusions),
            rule_violations=rule_violations or [],
            step_reward_breakdown=reward_breakdown or {},
            done=done,
            reward=reward,
            metadata=meta,
        )

    def _compute_uncertainty_summary(self) -> Dict[str, float]:
        """Average uncertainty/quality over the five most recent outputs.

        Returns an empty dict before any output exists.
        """
        if not self._outputs:
            return {}
        recent = self._outputs[-5:]
        avg_unc = sum(o.uncertainty for o in recent) / len(recent)
        avg_qual = sum(o.quality_score for o in recent) / len(recent)
        return {"avg_uncertainty": avg_unc, "avg_quality": avg_qual}

    @staticmethod
    def _merge_unique(target: List[str], items: List[str]) -> None:
        """Append *items* to *target* in order, skipping duplicates."""
        existing = set(target)
        for item in items:
            if item not in existing:
                target.append(item)
                existing.add(item)

    def _update_discoveries(
        self, action: ExperimentAction, output: IntermediateOutput
    ) -> None:
        """Harvest markers / candidate mechanisms from an action's output."""
        if action.action_type == ActionType.MARKER_SELECTION:
            self._merge_unique(
                self._discovered_markers, output.data.get("markers", [])
            )
        if action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
            self._merge_unique(
                self._candidate_mechanisms, output.data.get("top_regulators", [])
            )
        if action.action_type == ActionType.PATHWAY_ENRICHMENT:
            pathways = output.data.get("top_pathways", [])
            # Use .get so a malformed entry (a dict without a "pathway"
            # key) is skipped instead of raising KeyError mid-step.
            names = [p.get("pathway") for p in pathways if isinstance(p, dict)]
            self._merge_unique(
                self._candidate_mechanisms, [n for n in names if n is not None]
            )
|
|