'''
PsiJ plugin for ORBIT — HPC job submission.

Three-class pattern
-------------------
PSIJSession
    Endpoint-side session: holds one PsiJ ``Executor`` per submit call,
    manages job state via callbacks and background polling, streams
    stdout/stderr incrementally.

PSIJClient
    Application-side thin HTTP wrapper: delegates to the endpoint service
    over the bridge (``submit_job``, ``get_job_status``, ``list_jobs``,
    ``cancel_job``, ``submit_tunneled``, ``tunnel_status``).

PluginPSIJ
    Registers the plugin with the endpoint, adds URL routes, and wires
    requests to the correct PSIJSession via ``_forward()``.
'''
import asyncio
import logging
import os
import pathlib
import shutil
import socket
import time
from datetime import timedelta
from typing import Any, Dict
from fastapi import FastAPI, HTTPException, Request
import psij
from .plugin_base import Plugin
from .plugin_session_base import PluginSession
from .client import PluginClient
from .tunnel import relay_dir as _relay_dir
log = logging.getLogger("radical.orbit")

# Seconds between two job-status polls (default).
PSIJ_POLL_INTERVAL = 5.0

# Number of consecutive STATE_UNKNOWN polls tolerated once a job has been
# seen RUNNING; past that we conclude the job left the queue and bail.
UNKNOWN_TOLERANCE = 3

# Root of the persistent per-session stdout/stderr capture directories.
_OUTPUT_BASE = pathlib.Path.home().joinpath(
    '.radical', 'orbit', 'psij', 'output')

# Output directories older than this many days are removed when a new
# session is created.
_OUTPUT_MAX_AGE_DAYS = 7

# Diagnostic switch: a truthy RADICAL_ORBIT_PSIJ_KEEP_FILES makes us pass
# ``keep_files=True`` into the batch-scheduler executor config, so PsiJ
# leaves its generated submit scripts under ``~/.psij/work/<scheduler>/``
# for inspection. Off by default to keep the workdir tidy.
_KEEP_PSIJ_FILES = (
    os.environ.get('RADICAL_ORBIT_PSIJ_KEEP_FILES', '').lower()
    in ('1', 'true', 'yes', 'on')
)

# States after which a job needs no further polling.
TERMINAL_STATES = {'COMPLETED', 'FAILED', 'CANCELED'}
def _normalize_state(state) -> str:
    """Return *state* as a plain string without a leading 'JobState.'.

    Args:
        state: A PsiJ ``JobState`` (or anything else ``str()`` accepts).

    Returns:
        ``str(state)``, with the ``'JobState.'`` prefix removed when present.
    """
    prefix = 'JobState.'
    text = str(state)
    if text.startswith(prefix):
        return text[len(prefix):]
    return text
def _read_output_file(job, attr: str, offset: int = 0) -> str:
    """Read stdout or stderr from a job's spec path attribute.

    The file is read in binary mode so that *offset* really is a byte
    offset (matching the sizes reported by ``_output_file_size``). A
    text-mode ``seek`` to an arbitrary byte position may land inside a
    multi-byte UTF-8 sequence and fail to decode, which previously made
    the whole read come back empty.

    Args:
        job: PsiJ job object.
        attr: Attribute name on job.spec ('stdout_path' or 'stderr_path').
        offset: Byte offset to start reading from (0 or less = full file).

    Returns:
        Content read from the file starting at offset, decoded as UTF-8
        with undecodable bytes replaced by U+FFFD and line endings
        normalised to ``'\\n'``. Returns ``""`` when the attribute is
        unset, the file does not exist, or reading fails (best effort;
        failures are logged at debug level).
    """
    try:
        path = getattr(job.spec, attr, None)
        if path and os.path.exists(str(path)):
            with open(str(path), 'rb') as f:
                if offset > 0:
                    f.seek(offset)
                raw = f.read()
            text = raw.decode('utf-8', errors='replace')
            # Keep the universal-newline behaviour of the former
            # text-mode read.
            return text.replace('\r\n', '\n').replace('\r', '\n')
    except Exception as e:
        log.debug("Failed to read %s for job: %s", attr, e)
    return ""
def _output_file_size(job, attr: str) -> int:
    """Return the size in bytes of a job's stdout/stderr file.

    Args:
        job: PsiJ job object.
        attr: Attribute name on job.spec ('stdout_path' or 'stderr_path').

    Returns:
        The file size, or 0 when the attribute is unset, the file does
        not exist, or any error occurs while looking it up (best effort).
    """
    try:
        target = getattr(job.spec, attr, None)
        if not target:
            return 0
        location = str(target)
        if not os.path.exists(location):
            return 0
        return os.path.getsize(location)
    except Exception:
        return 0
[docs]
class PSIJSession(PluginSession):
    '''
    Session-specific PSIJ state.

    Keeps the PsiJ jobs submitted through this session together with
    their submission metadata and last reported state, captures job
    stdout/stderr under a per-session directory, and runs a background
    task that polls job status and dispatches ``job_status``
    notifications through the owning plugin.
    '''

    # Class-level default; a ``poll_interval=`` kwarg overrides it per instance.
    poll_interval = PSIJ_POLL_INTERVAL

    def __init__(self, sid: str, **kwargs: Any):
        '''
        Create the session and its output directory.

        Args:
            sid: Session id; also names the output sub-directory.
            **kwargs: Only ``poll_interval`` (seconds) is read here.
        '''
        super().__init__(sid)

        self._jobs: Dict[str, Any] = {}         # job_id -> psij.Job
        self._job_meta: Dict[str, dict] = {}    # job_id -> submission metadata
        self._job_states: Dict[str, str] = {}   # last state seen by the poller
        self._cancelled_jobs: set = set()       # job_ids the user asked to cancel
        self._poll_interval = kwargs.get('poll_interval', self.poll_interval)
        self._poll_task = None

        # Persistent output directory for this session's job stdout/stderr.
        # Stale directories of other sessions are pruned first.
        self._output_dir = _OUTPUT_BASE / sid
        self._cleanup_stale_output()
        self._output_dir.mkdir(parents=True, exist_ok=True)

    def _effective_state(self, job_id: str, raw_state: str) -> str:
        """Map a psij state to what the caller should see.

        psij's PBS backend only flags Exit_status == 265 as CANCELED
        (pbs_base.py:135-139); sites like Aurora return a different code,
        so a cancelled job surfaces as COMPLETED/FAILED. We remember
        which jobs the user asked to cancel and report those as CANCELED
        regardless of what the backend says.

        Args:
            job_id: PsiJ job id.
            raw_state: Normalised backend state string.

        Returns:
            ``'CANCELED'`` for a user-cancelled job the backend reports
            as COMPLETED or FAILED, else *raw_state* unchanged.
        """
        if job_id in self._cancelled_jobs and raw_state in ('COMPLETED', 'FAILED'):
            return 'CANCELED'
        return raw_state

    def _cleanup_stale_output(self) -> None:
        """Remove output directories older than _OUTPUT_MAX_AGE_DAYS.

        Age is judged by the directory's mtime. This session's own
        directory is never touched; failures are logged and ignored.
        """
        if not _OUTPUT_BASE.exists():
            return
        cutoff = time.time() - _OUTPUT_MAX_AGE_DAYS * 86400
        for entry in _OUTPUT_BASE.iterdir():
            if not entry.is_dir() or entry == self._output_dir:
                continue
            try:
                if entry.stat().st_mtime < cutoff:
                    shutil.rmtree(entry)
                    log.info("Cleaned up stale output dir: %s", entry)
            except Exception as e:
                log.debug("Failed to clean up %s: %s", entry, e)

    @staticmethod
    def _status_event(job, job_id: str, state_str: str, status) -> Dict[str, Any]:
        """Build the ``job_status`` notification payload for one transition.

        stdout/stderr and the exit code are only included once the job
        has reached a terminal state; before that they are empty / None.
        """
        is_terminal = state_str in TERMINAL_STATES
        stdout_content = ""
        stderr_content = ""
        if is_terminal:
            stdout_content = _read_output_file(job, 'stdout_path')
            stderr_content = _read_output_file(job, 'stderr_path')
        return {
            "job_id": job_id,
            "state": state_str,
            "exit_code": status.exit_code if is_terminal else None,
            "stdout": stdout_content,
            "stderr": stderr_content
        }

    async def submit_job(self, job_spec_dict: Dict[str, Any],
                         executor_name: str = 'local') -> Dict[str, Any]:
        '''
        Submit a job via PSIJ.

        Args:
            job_spec_dict: Job description. Keys read here: ``executable``,
                ``arguments``, ``directory``, ``environment``,
                ``attributes`` (``duration`` in seconds, ``queue_name``,
                ``account``, ``reservation_id``), ``resources`` (passed
                verbatim to ``psij.ResourceSpecV1``) and
                ``custom_attributes``.
            executor_name: PsiJ executor to submit through.

        Returns:
            dict with ``job_id`` and ``native_id``.

        Raises:
            HTTPException: 500 wrapping any error raised while building
                or submitting the job.
        '''
        try:
            spec = psij.JobSpec()
            executable = job_spec_dict.get('executable')
            arguments = job_spec_dict.get('arguments')
            spec.executable = executable
            if arguments:
                spec.arguments = arguments
            if 'directory' in job_spec_dict:
                spec.directory = job_spec_dict['directory']
            if 'environment' in job_spec_dict:
                spec.environment = job_spec_dict['environment']

            # An explicit ``"attributes": null`` is treated like an empty
            # mapping instead of crashing on ``None.get``.
            attribs = job_spec_dict.get('attributes') or {}
            if 'attributes' in job_spec_dict:
                spec.attributes = psij.JobAttributes()
                duration = attribs.get("duration")
                if duration:
                    spec.attributes.duration = timedelta(seconds=int(duration))
                spec.attributes.queue_name = attribs.get("queue_name")
                spec.attributes.account = attribs.get("account")
                spec.attributes.reservation_id = attribs.get("reservation_id")

            # Resource spec: pass any ``ResourceSpecV1`` field through
            # verbatim (``node_count``, ``exclusive_node_use``, etc).
            # An unknown key here raises TypeError -- caller bug, not
            # something to silently swallow.
            res = job_spec_dict.get('resources') or {}
            if res:
                spec.resources = psij.ResourceSpecV1(**res)

            # Merge site defaults for PSIJ custom_attributes with the
            # caller's (caller wins on conflict). Defaults come from the
            # detected batch_system backend — e.g. Aurora's PBS requires
            # filesystems= on every submission, which the UI / Python API
            # users are not expected to know. Only applied when the
            # chosen executor matches the detected backend.
            from .batch_system import detect_batch_system
            backend = detect_batch_system()
            defaults = (backend.default_custom_attributes()
                        if backend.psij_executor == executor_name else {})
            caller_ca = dict(job_spec_dict.get('custom_attributes') or {})
            merged_ca = {**defaults, **caller_ca}
            if merged_ca:
                if spec.attributes is None:
                    spec.attributes = psij.JobAttributes()
                spec.attributes.custom_attributes = merged_ca
                if defaults:
                    added = {k: v for k, v in defaults.items()
                             if k not in caller_ca}
                    if added:
                        log.info("[psij] backend=%s injected defaults: %s",
                                 backend.name, added)

            job = psij.Job(spec)

            # The capture paths are named after the job id, so they can
            # only be set once the Job object exists.
            out_path = str(self._output_dir / f"{job.id}.out")
            err_path = str(self._output_dir / f"{job.id}.err")
            spec.stdout_path = out_path
            spec.stderr_path = err_path

            # ``keep_files=True`` only meaningful for batch-scheduler
            # executors (slurm/pbs/lsf/cobalt/...). ``local`` ignores it.
            ex_config = None
            if _KEEP_PSIJ_FILES and executor_name in ('slurm', 'pbs', 'lsf',
                                                      'cobalt', 'flux'):
                from psij.executors.batch.batch_scheduler_executor \
                    import BatchSchedulerExecutorConfig
                ex_config = BatchSchedulerExecutorConfig(keep_files=True)
                log.info("[psij] RADICAL_ORBIT_PSIJ_KEEP_FILES set: "
                         "executor=%s keep_files=True", executor_name)
            if ex_config is not None:
                ex = psij.JobExecutor.get_instance(executor_name,
                                                   config=ex_config)
            else:
                ex = psij.JobExecutor.get_instance(executor_name)

            # Set poll interval for status updates
            if hasattr(ex, 'poll_interval'):
                ex.poll_interval = self._poll_interval

            self._jobs[job.id] = job

            # Store submission metadata for later retrieval. The node
            # count is taken from ``attributes`` when given there and
            # otherwise from ``resources``, where the resource spec that
            # is actually submitted lives.
            node_count = attribs.get('node_count')
            if node_count is None:
                node_count = res.get('node_count')
            self._job_meta[job.id] = {
                'executable': executable,
                'arguments': arguments or [],
                'executor': executor_name,
                'directory': job_spec_dict.get('directory'),
                'queue_name': attribs.get('queue_name'),
                'account': attribs.get('account'),
                'node_count': node_count,
                'duration': attribs.get('duration'),
            }

            # Register status callback BEFORE submit so no transitions are missed
            plugin = self._plugin
            job_id = job.id
            last_state = None

            # NOTE(review): this callback de-duplicates on its own
            # ``last_state`` while ``_poll_jobs`` de-duplicates on
            # ``self._job_states``, so one transition may be notified
            # by both paths — confirm whether that is intended.
            def _on_status(j, status):
                nonlocal last_state
                state_str = _normalize_state(status.state)
                state_str = self._effective_state(job_id, state_str)
                # Skip if state hasn't changed
                if state_str == last_state:
                    return
                last_state = state_str
                if plugin:
                    plugin._dispatch_notify(
                        "job_status",
                        self._status_event(j, job_id, state_str, status))

            job.set_job_status_callback(_on_status)
            ex.submit(job)

            # Start background polling for job status updates
            self._start_polling()

            log.info("Submitted job %s to %s", job.id, executor_name)
            return {"job_id": job.id, "native_id": job.native_id}
        except Exception as e:
            log.exception("Job submission failed: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def get_job_status(self, job_id: str,
                             stdout_offset: int = 0,
                             stderr_offset: int = 0) -> Dict[str, Any]:
        '''
        Get job status with metadata and optional stdout/stderr offset.

        Args:
            job_id: PsiJ job id.
            stdout_offset: Byte offset to start reading stdout from.
            stderr_offset: Byte offset to start reading stderr from.

        Returns:
            dict with state, message, exit code, submission metadata,
            the stdout/stderr content from the given offsets, and the
            current file sizes as ``stdout_offset`` / ``stderr_offset``
            (usable as the offsets of the next call).

        Raises:
            HTTPException: 404 if the job is unknown to this session.
        '''
        job = self._jobs.get(job_id)
        if not job:
            raise HTTPException(status_code=404, detail=f"Job {job_id} not found")

        status = job.status
        state_str = self._effective_state(job_id,
                                          _normalize_state(status.state))

        stdout_content = _read_output_file(job, 'stdout_path', stdout_offset)
        stderr_content = _read_output_file(job, 'stderr_path', stderr_offset)

        meta = self._job_meta.get(job_id, {})
        return {
            "job_id": job_id,
            "native_id": job.native_id,
            "state": state_str,
            "message": status.message,
            "exit_code": status.exit_code,
            "time": status.time,
            "executable": meta.get('executable'),
            "arguments": meta.get('arguments', []),
            "executor": meta.get('executor'),
            "directory": meta.get('directory'),
            "queue_name": meta.get('queue_name'),
            "account": meta.get('account'),
            "node_count": meta.get('node_count'),
            "duration": meta.get('duration'),
            "stdout": stdout_content,
            "stderr": stderr_content,
            "stdout_offset": _output_file_size(job, 'stdout_path'),
            "stderr_offset": _output_file_size(job, 'stderr_path'),
        }

    async def list_jobs(self) -> Dict[str, Any]:
        '''
        List all jobs in this session with current state and metadata.

        Returns:
            dict with a ``jobs`` list, one entry per submitted job.
        '''
        jobs = []
        for job_id, job in self._jobs.items():
            # Read the status once so that state and exit code come
            # from the same snapshot.
            status = job.status
            state_str = self._effective_state(job_id,
                                              _normalize_state(status.state))
            meta = self._job_meta.get(job_id, {})
            jobs.append({
                "job_id": job_id,
                "native_id": job.native_id,
                "state": state_str,
                "exit_code": status.exit_code,
                "executable": meta.get('executable'),
                "arguments": meta.get('arguments', []),
                "executor": meta.get('executor'),
                "queue_name": meta.get('queue_name'),
                "account": meta.get('account'),
                "node_count": meta.get('node_count'),
            })
        return {"jobs": jobs}

    async def cancel_job(self, job_id: str) -> Dict[str, Any]:
        '''
        Cancel a job.

        Args:
            job_id: PsiJ job id.

        Returns:
            dict with ``job_id`` and ``status`` (``"canceled"``).

        Raises:
            HTTPException: 404 if the job is unknown to this session,
                500 if the cancel call itself fails.
        '''
        job = self._jobs.get(job_id)
        if not job:
            raise HTTPException(status_code=404, detail=f"Job {job_id} not found")

        # Record intent *before* calling cancel() so any status update
        # that races with qdel gets mapped through _effective_state.
        already_flagged = job_id in self._cancelled_jobs
        self._cancelled_jobs.add(job_id)
        try:
            job.cancel()
        except Exception as e:
            # The cancel did not go through: withdraw the intent we just
            # recorded, otherwise a later COMPLETED/FAILED of this job
            # would be misreported as CANCELED. An intent recorded by an
            # earlier cancel request is left in place.
            if not already_flagged:
                self._cancelled_jobs.discard(job_id)
            log.exception("Job cancellation failed: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e
        return {"job_id": job_id, "status": "canceled"}

    async def close(self) -> dict:
        '''
        Close the session and stop polling.

        Cancels the background poll task, removes this session's output
        directory (best effort) and then delegates to the base class.
        '''
        if self._poll_task:
            self._poll_task.cancel()
            try:
                await self._poll_task
            except asyncio.CancelledError:
                pass
            self._poll_task = None

        # Clean up this session's output directory
        if self._output_dir.exists():
            try:
                shutil.rmtree(self._output_dir)
            except Exception as e:
                log.debug("Failed to remove output dir %s: %s",
                          self._output_dir, e)

        return await super().close()

    def _start_polling(self):
        '''
        Start the background polling task if not already running.
        '''
        if self._poll_task is None or self._poll_task.done():
            self._poll_task = asyncio.create_task(self._poll_jobs())

    async def _poll_jobs(self):
        '''
        Background task that polls job status and sends notifications.

        Runs until cancelled or until every job of the session has been
        seen in a terminal state.
        '''
        first = True
        while True:
            try:
                if first:
                    # Short delay on first poll to catch fast state transitions
                    await asyncio.sleep(0.5)
                    first = False
                else:
                    await asyncio.sleep(self._poll_interval)

                # Look at every job; those whose state did not change
                # since the last poll are skipped below.
                for job_id, job in list(self._jobs.items()):
                    try:
                        status = job.status
                        state_str = self._effective_state(
                            job_id, _normalize_state(status.state))

                        # Skip if state hasn't changed
                        if state_str == self._job_states.get(job_id):
                            continue
                        self._job_states[job_id] = state_str

                        if self._plugin:
                            self._plugin._dispatch_notify(
                                "job_status",
                                self._status_event(job, job_id,
                                                   state_str, status))
                    except Exception as e:
                        log.debug("Error polling job %s: %s", job_id, e)

                # Check if all jobs are terminal - if so, stop polling
                if all(self._job_states.get(jid) in TERMINAL_STATES
                       for jid in self._jobs):
                    break
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.debug("Polling error: %s", e)
[docs]
class PSIJClient(PluginClient):
    """
    Client-side interface for the PSIJ plugin.

    Every method is a thin HTTP call against the plugin's routes; error
    responses are handed to ``self._raise`` together with a short
    description of the attempted operation.
    """

    def submit_job(self, job_spec: Dict[str, Any], executor: str = 'local') -> Dict[str, Any]:
        """
        Submit a job.

        Args:
            job_spec (dict): The job specification.
            executor (str): The executor to use.

        Returns:
            dict: Job submission result (job_id, native_id).
        """
        self._require_session()
        url = self._url(f"submit/{self.sid}")
        payload = {"job_spec": job_spec, "executor": executor}
        resp = self._http.post(url, json=payload)
        self._raise(resp, f"psij submit {job_spec.get('executable','?')!r} on {executor!r}")
        return resp.json()

    def get_job_status(self, job_id: str,
                       stdout_offset: int = 0,
                       stderr_offset: int = 0) -> Dict[str, Any]:
        """
        Get the status of a job.

        Args:
            job_id: The job ID to query.
            stdout_offset: Byte offset for stdout (0 = full).
            stderr_offset: Byte offset for stderr (0 = full).

        Returns:
            Job status info including metadata and stdout/stderr.
        """
        self._require_session()
        url = self._url(f"status/{self.sid}/{job_id}")
        # Zero offsets are the server-side default and are not sent.
        params = {}
        if stdout_offset:
            params['stdout_offset'] = str(stdout_offset)
        if stderr_offset:
            params['stderr_offset'] = str(stderr_offset)
        resp = self._http.get(url, params=params)
        self._raise(resp, f"job status {job_id!r}")
        return resp.json()

    def list_jobs(self) -> Dict[str, Any]:
        """
        List all jobs in this session.

        Returns:
            dict with 'jobs' list.
        """
        self._require_session()
        resp = self._http.get(self._url(f"list_jobs/{self.sid}"))
        # Pass a context string like every other call in this class.
        self._raise(resp, "list jobs")
        return resp.json()

    def cancel_job(self, job_id: str) -> Dict[str, Any]:
        """
        Cancel a job.

        Args:
            job_id: The job ID to cancel.

        Returns:
            Cancellation result.
        """
        self._require_session()
        url = self._url(f"cancel/{self.sid}/{job_id}")
        resp = self._http.post(url)
        self._raise(resp, f"cancel job {job_id!r}")
        return resp.json()

    def submit_tunneled(self, job_spec: Dict[str, Any],
                        executor: str = 'local',
                        tunnel: str = 'none') -> Dict[str, Any]:
        """Submit a job that launches a child Endpoint service on a compute node.

        The ``job_spec.arguments`` list *must* contain ``-n <endpoint_name>`` or
        ``--name <endpoint_name>`` so the child endpoint can register under the
        correct name.

        Args:
            job_spec: PsiJ job specification dict. ``arguments`` must
                include ``-n <endpoint_name>``.
            executor: PsiJ executor name (default: ``"local"``).
            tunnel: SSH tunnel mode for the child's bridge connection.
                One of:

                * ``'none'`` — child connects directly to the
                  bridge. No SSH spawned anywhere.
                * ``'forward'`` — child opens its own outbound
                  ``ssh -L`` to the login host (compute → login).
                  Suitable where outbound SSH from compute is
                  permitted and login → compute is blocked
                  (Aurora, Perlmutter).
                * ``'reverse'`` — login-side parent opens ``ssh -R``
                  to the compute host (login → compute). Suitable
                  where compute → login SSH is blocked but
                  login → compute works (Odo).

                Hard-rejects any other value (including ``True`` /
                ``False``) — there is no boolean back-compat.

        Returns:
            dict with ``job_id``, ``native_id``, and ``endpoint_name``.

        Raises:
            ValueError: If *tunnel* is not one of the three string values.
            RuntimeError: If the server returns an error response.
        """
        # Validate locally, before any session check or network traffic.
        if tunnel not in ('none', 'forward', 'reverse'):
            raise ValueError(
                f"tunnel must be one of 'none' / 'forward' / 'reverse'; "
                f"got {tunnel!r}")
        self._require_session()
        url = self._url(f"submit_tunneled/{self.sid}")
        payload = {"job_spec": job_spec, "executor": executor, "tunnel": tunnel}
        resp = self._http.post(url, json=payload)
        self._raise(resp, f"psij submit_tunneled on {executor!r}")
        return resp.json()

    def tunnel_status(self, endpoint_name: str) -> Dict[str, Any]:
        """Return the current tunnel status for a named endpoint.

        This endpoint is session-less (no session required).

        Args:
            endpoint_name: The logical name of the child endpoint service.

        Returns:
            dict with fields:

            - ``endpoint_name`` — echoed back.
            - ``status`` — one of ``"pending"``, ``"active"``, ``"failed"``,
              ``"done"``, or ``"no_tunnel"``.
            - ``port`` — assigned tunnel port (int) once active, else null.
            - ``pid`` — SSH process PID, once spawned, else null.
        """
        resp = self._http.get(self._url(f"tunnel_status/{endpoint_name}"))
        self._raise(resp, f"tunnel_status {endpoint_name!r}")
        return resp.json()
[docs]
class PluginPSIJ(Plugin):
'''
PSIJ plugin for ORBIT.
This plugin provides an interface to submit and manage jobs via the
`psij-python` library.
'''
plugin_name = "psij"
session_class = PSIJSession
client_class = PSIJClient
version = '0.0.1'
ui_config = {
"icon": "🚀",
"title": "PsiJ Jobs",
"description": "Submit and monitor HPC batch jobs via PsiJ.",
"forms": [{
"id": "submit",
"title": "📝 Submit Job",
"layout": "grid2",
"fields": [
{"name": "exec", "type": "text", "label": "Executable",
"default": "radical-orbit-endpoint-wrapper.sh", "css_class": "p-exec",
"column": 0},
{"name": "args", "type": "text", "label": "Arguments (space-separated)",
"placeholder": "auto-filled with --url and --name",
"css_class": "p-args", "column": 0},
{"name": "executor", "type": "select", "label": "Executor",
"options": ["local", "slurm", "pbs", "lsf"],
"css_class": "p-executor", "column": 0},
{"name": "queue", "type": "text", "label": "Queue / Partition",
"placeholder": "optional", "required": False,
"css_class": "p-queue", "column": 1},
{"name": "account", "type": "text", "label": "Account / Project",
"placeholder": "optional", "required": False,
"css_class": "p-account", "column": 1},
{"name": "duration", "type": "text", "label": "Duration (seconds)",
"placeholder": "e.g. 600", "required": False,
"css_class": "p-duration", "column": 1},
{"name": "node_count", "type": "number", "label": "Number of Nodes",
"placeholder": "e.g. 1", "required": False,
"css_class": "p-node-count", "column": 1},
{"name": "custom", "type": "custom_attributes", "label": "🔧 Custom Attributes",
"required": False, "css_class": "p-custom-attr", "column": 1},
],
"submit": {"label": "🚀 Submit Job", "style": "success"}
}],
"monitors": [{
"id": "jobs",
"title": "📊 Job Monitor",
"type": "task_list",
"css_class": "psij-output",
"empty_text": "No jobs submitted yet."
}],
"notifications": {
"topic": "job_status",
"id_field": "job_id",
"state_field": "state"
}
}
[docs]
@classmethod
def is_enabled(cls, app: FastAPI) -> bool:
    """PsiJ loads on endpoint nodes (login or compute) — not on the bridge.

    Returns:
        False when ``app.state.is_bridge`` is set and truthy, else True.
    """
    on_bridge = getattr(app.state, 'is_bridge', False)
    return not on_bridge
def __init__(self, app: FastAPI, instance_name: str = "psij"):
    """Set up plugin-level tunnel bookkeeping and register the routes.

    Args:
        app: The FastAPI application the plugin attaches to.
        instance_name: Name this plugin instance is registered under.
    """
    super().__init__(app, instance_name)

    # Plugin-level state (it survives session cleanup):
    #   _watchers        endpoint_name -> watcher task
    #   _tunnel_procs    endpoint_name -> reverse-tunnel SSH process
    #                    (parent side only — forward-mode tunnels live
    #                    in the child process and are invisible from
    #                    here)
    #   _failure_reasons job_id -> error message for jobs we cancelled
    #                    because their tunnel setup failed. Read by
    #                    ``get_job_status`` to override the underlying
    #                    CANCELLED state to FAILED with context. An
    #                    entry is overwritten on next cancel for the
    #                    same job_id; we never pop on read (repeated
    #                    reads return a stable result).
    self._watchers: dict = {}
    self._tunnel_procs: dict = {}
    self._failure_reasons: dict = {}

    # Ensure relay directory exists at startup
    _relay_dir()

    self._app.router.on_shutdown.append(self._cleanup_watchers)

    # Route table, registered in this order.
    routes = (
        (self.add_route_post, 'submit/{sid}', self.submit_job),
        (self.add_route_post, 'submit_tunneled/{sid}', self.submit_tunneled),
        (self.add_route_get, 'tunnel_status/{endpoint_name}', self.tunnel_status),
        (self.add_route_get, 'status/{sid}/{job_id}', self.get_job_status),
        (self.add_route_get, 'list_jobs/{sid}', self.list_jobs),
        (self.add_route_post, 'cancel/{sid}/{job_id}', self.cancel_job),
    )
    for register, path, handler in routes:
        register(path, handler)
[docs]
async def submit_job(self, request: Request) -> dict:
    """Route handler: submit a job through the session named in the URL.

    Reads ``job_spec`` (default ``{}``) and ``executor`` (default
    ``'local'``) from the JSON body and forwards them to
    ``PSIJSession.submit_job``.
    """
    body = await_body = await request.json()
    return await self._forward(
        request.path_params['sid'],
        PSIJSession.submit_job,
        job_spec_dict=body.get('job_spec', {}),
        executor_name=body.get('executor', 'local'))
[docs]
async def get_job_status(self, request: Request) -> dict:
    """Route handler: report the status of one job.

    Path params: ``sid``, ``job_id``. Optional query params:
    ``stdout_offset`` and ``stderr_offset`` (integers, default 0).

    Returns:
        The dict produced by ``PSIJSession.get_job_status``; when the
        job has an entry in ``_failure_reasons`` its ``state`` is
        replaced by ``'FAILED'`` and an ``error`` field is added.

    Raises:
        HTTPException: 400 if an offset is not an integer (previously
            the bare ``ValueError`` escaped the handler).
    """
    sid = request.path_params['sid']
    job_id = request.path_params['job_id']

    def _offset(name: str) -> int:
        raw = request.query_params.get(name, '0')
        try:
            return int(raw)
        except (TypeError, ValueError):
            raise HTTPException(
                status_code=400,
                detail=f"{name} must be an integer; got {raw!r}") from None

    so = _offset('stdout_offset')
    se = _offset('stderr_offset')

    status = await self._forward(sid, PSIJSession.get_job_status,
                                 job_id=job_id,
                                 stdout_offset=so,
                                 stderr_offset=se)

    # If we cancelled this job because its tunnel setup failed,
    # override the underlying CANCELLED state with FAILED + the
    # actual reason. Operator-initiated cancels (no entry in
    # ``_failure_reasons``) keep their natural CANCELLED state.
    err = self._failure_reasons.get(job_id)
    if err:
        status['state'] = 'FAILED'
        status['error'] = err
    return status
[docs]
async def list_jobs(self, request: Request) -> dict:
    """Route handler: list the jobs of the session named in the URL.

    Returns:
        The dict produced by ``PSIJSession.list_jobs``. Jobs with an
        entry in ``_failure_reasons`` get the same override that
        ``get_job_status`` applies (``state`` set to ``'FAILED'`` plus
        an ``error`` field), so both views report the same state.
    """
    sid = request.path_params['sid']
    listing = await self._forward(sid, PSIJSession.list_jobs)

    if self._failure_reasons and isinstance(listing, dict):
        for entry in listing.get('jobs') or []:
            reason = self._failure_reasons.get(entry.get('job_id'))
            if reason:
                entry['state'] = 'FAILED'
                entry['error'] = reason
    return listing
[docs]
async def cancel_job(self, request: Request) -> dict:
    """Route handler: cancel one job of the session named in the URL.

    Forwards the ``job_id`` path parameter to ``PSIJSession.cancel_job``
    and returns its result.
    """
    params = request.path_params
    target_sid = params['sid']
    target_job = params['job_id']
    return await self._forward(target_sid, PSIJSession.cancel_job,
                               job_id=target_job)
# ─────────────────────────────────────────────────────────────────────────
# Endpoint-job submission with optional reverse SSH tunnel
# ─────────────────────────────────────────────────────────────────────────
[docs]
async def submit_tunneled(self, request: Request) -> dict:
    """Submit a job that starts a new Endpoint service on a compute node.

    The job *must* pass ``-n``/``--name <endpoint_name>`` in its arguments so
    the child endpoint service can register under the correct name.

    Tunnel direction is selected by the ``tunnel`` field:

    * ``'none'`` — no SSH tunnel; child connects directly to the bridge.
    * ``'forward'`` — child opens its own outbound ``ssh -L`` back to
      this login node (compute → login). We inject
      ``--tunnel forward`` and ``--tunnel-via <login>`` into the
      child's argv. The child writes the rendezvous file itself; the
      parent watcher only observes job state.
    * ``'reverse'`` — *parent* (this plugin) opens ``ssh -R`` to the
      compute node once the job reaches RUNNING and writes the
      rendezvous file with the remote port allocated by the
      compute-side sshd. We inject only ``--tunnel reverse`` so the
      child waits for the rendezvous file.

    Request body JSON fields:

    - ``job_spec`` (dict) — PsiJ job specification.
    - ``executor`` (str) — PsiJ executor name (default: ``"local"``).
    - ``tunnel`` (str) — One of ``'none'``, ``'forward'``, ``'reverse'``
      (default: ``'none'``). Boolean values are *not* accepted.

    Returns:
        JSON with ``job_id``, ``native_id``, and ``endpoint_name``.

    Raises:
        400 if ``tunnel`` is not one of the three string values.
        422 if ``-n``/``--name`` is missing from ``job_spec.arguments``,
            or if the endpoint name contains a path separator.
        409 if a tunnel watcher for the same endpoint name is already active.
    """
    sid = request.path_params['sid']
    data = await request.json()
    job_spec = data.get('job_spec', {})
    executor = data.get('executor', 'local')
    tunnel = data.get('tunnel', 'none')

    if tunnel not in ('none', 'forward', 'reverse'):
        raise HTTPException(
            status_code=400,
            detail=f"tunnel must be one of 'none' / 'forward' / 'reverse'; "
                   f"got {tunnel!r}")

    # --- resolve endpoint name from arguments ---
    # The name is the argument that directly follows the first -n/--name.
    args = list(job_spec.get('arguments') or [])
    endpoint_name = None
    for flag, value in zip(args, args[1:]):
        if flag in ('-n', '--name'):
            endpoint_name = value
            break
    if not endpoint_name:
        raise HTTPException(
            status_code=422,
            detail="submit_tunneled requires -n/--name <endpoint_name> in job_spec.arguments")

    # The name comes from the request body and is interpolated into
    # rendezvous file names that are unlinked below; a path separator
    # in it would let a caller point those paths outside the relay dir.
    name_text = str(endpoint_name)
    separators = [c for c in ('/', os.sep, os.altsep, '\x00') if c]
    if any(c in name_text for c in separators):
        raise HTTPException(
            status_code=422,
            detail=f"endpoint name {endpoint_name!r} must not contain "
                   f"path separators")

    # --- guard against duplicate watchers ---
    existing = self._watchers.get(endpoint_name)
    if existing and not existing.done():
        raise HTTPException(
            status_code=409,
            detail=f"Tunnel watcher already active for endpoint '{endpoint_name}'")

    # --- prepare rendezvous + inject child-side flags ---
    relay_file: 'pathlib.Path | None' = None
    if tunnel != 'none':
        relay_file = _relay_dir() / f'{endpoint_name}.port'
        relay_file.unlink(missing_ok=True)  # remove stale file from previous run
        pid_file = _relay_dir() / f'{endpoint_name}.pid'
        pid_file.unlink(missing_ok=True)
        req_file = _relay_dir() / f'{endpoint_name}.req'
        req_file.unlink(missing_ok=True)

        if '--tunnel' not in args:
            args.extend(['--tunnel', tunnel])
        if tunnel == 'forward' and '--tunnel-via' not in args:
            # Forward mode: child needs to know which login host to ssh to.
            args.extend(['--tunnel-via', socket.gethostname()])

        # Work on a copy so the caller's dict is not modified.
        job_spec = dict(job_spec)
        job_spec['arguments'] = args

    resp = await self._forward(sid, PSIJSession.submit_job,
                               job_spec_dict=job_spec,
                               executor_name=executor)

    if tunnel != 'none' and relay_file is not None:
        native_id = resp.get('native_id')
        job_id = resp.get('job_id')
        log.info("[psij] submit_tunneled mode=%s: endpoint=%s job_id=%s "
                 "native_id=%s -- watcher started",
                 tunnel, endpoint_name, job_id, native_id)
        task = asyncio.create_task(
            self._tunnel_watcher(endpoint_name, native_id, job_id,
                                 relay_file, tunnel))
        self._watchers[endpoint_name] = task

    # Augment response with endpoint_name for caller convenience
    resp['endpoint_name'] = endpoint_name
    return resp
[docs]
async def tunnel_status(self, request: Request) -> dict:
    """Return the current tunnel status for a named endpoint.

    Path param: ``endpoint_name``

    Returns a JSON object with fields:

    - ``endpoint_name`` — echoed back.
    - ``status`` — one of ``"pending"``, ``"active"``, ``"failed"``,
      ``"done"``, or ``"no_tunnel"``.
    - ``port`` — allocated tunnel port (int) once the child endpoint
      has published it, else null.
    - ``pid`` — SSH process PID on the compute node (read from
      the pid rendezvous file) once active, else null.
    """
    def _int_from(path):
        # Integer content of a rendezvous file; None when the file is
        # absent, unreadable, or does not hold an integer.
        if not path.exists():
            return None
        try:
            return int(path.read_text().strip())
        except (ValueError, OSError):
            return None

    endpoint_name = request.path_params['endpoint_name']
    port = _int_from(_relay_dir() / f'{endpoint_name}.port')
    pid = _int_from(_relay_dir() / f'{endpoint_name}.pid')

    watcher = self._watchers.get(endpoint_name)
    if watcher is None:
        status = 'no_tunnel'
    elif port is not None:
        # Relay file present → child endpoint successfully published its port.
        # The SSH process lives on the compute node and is not observable
        # from here, so ``active`` is terminal from the login's point of
        # view.
        status = 'active'
    elif watcher.done():
        # Watcher finished without a port file → the job terminated or
        # the child never published. Report as failed.
        status = 'failed'
    else:
        # Watcher still running, waiting for the child to publish.
        status = 'pending'

    return {'endpoint_name': endpoint_name,
            'status': status,
            'port': port,
            'pid': pid}
# ─────────────────────────────────────────────────────────────────────────
# Internal tunnel helpers
# ─────────────────────────────────────────────────────────────────────────
async def _tunnel_watcher(self, endpoint_name: str, native_id,
job_id: 'str | None',
relay_file: 'pathlib.Path',
mode: str) -> None:
"""Watch a tunneled-job's progress; behaviour depends on *mode*.
**forward** (compute → login): the child opens its own ``ssh -L``
and writes the rendezvous file. This watcher only observes the
job state. If the job goes terminal before the file appears,
the failure already manifests as the job's natural ``FAILED``
state — we do nothing (no parent-side cancel needed).
**reverse** (login → compute): once the job reaches ``RUNNING``
we look up the compute hostname via ``BatchSystem.job_nodes()``
and spawn ``ssh -R`` from this side. On any spawn failure (or
the SSH process dying before/after writing the rendezvous file)
we record a reason in ``_failure_reasons[job_id]`` and call
``cancel_job`` so the now-useless allocation is released; the
client then sees the cancel as ``FAILED`` (with our reason) via
:meth:`get_job_status`.
Args:
endpoint_name: Logical name of the child endpoint service.
native_id: Native scheduler job ID (SLURM/PBS/...).
job_id: PsiJ job-id (key in ``_failure_reasons``).
relay_file: Shared-filesystem file the child reads
regardless of who writes it.
mode: ``'forward'`` or ``'reverse'``. ``'none'``
callers don't reach here.
"""
from .batch_system import (detect_batch_system, STATE_CANCELLED,
STATE_UNKNOWN, TERMINAL_STATES)
from . import tunnel as _tunnel
batch = detect_batch_system()
log.info("[psij] Watcher started mode=%s for endpoint '%s' "
"(job=%s native=%s, backend=%s) — relay file %s",
mode, endpoint_name, job_id, native_id, batch.name, relay_file)
# In reverse mode this watcher *will* spawn an SSH process and
# is responsible for tearing it down.
ssh_proc = None
# Bridge URL/port for the reverse spawn — same value the child
# would resolve, so we can hand it to OpenSSH's -R spec.
bridge_host = 'localhost'
bridge_port = 8000
if mode == 'reverse':
from urllib.parse import urlparse
bridge_url = getattr(self._app.state, 'bridge_url', '') or ''
parsed = urlparse(bridge_url)
bridge_host = parsed.hostname or 'localhost'
bridge_port = parsed.port or (443 if parsed.scheme in ('https', 'wss') else 8000)
last_state = None
seen_known = False
unknown_streak = 0
try:
for attempt in range(300): # up to ~10 min (2s × 300)
await asyncio.sleep(2)
# Both modes: rendezvous file appearing is the success signal.
if relay_file.exists():
try:
port = int(relay_file.read_text().strip())
except (ValueError, OSError):
port = None
log.info("[psij] endpoint '%s' tunnel active on port %s "
"(mode=%s)", endpoint_name, port, mode)
if mode == 'reverse':
# Continue polling so we can tear ssh_proc down
# cleanly when the job ends.
await self._await_reverse_teardown(
endpoint_name, native_id, ssh_proc, batch)
return
# Reverse-mode side-channel: spawn ssh -R as soon as
# the job has been allocated a compute host.
state = await asyncio.to_thread(batch.job_state, native_id)
# Reverse-mode spawn: gate on the child's .req file
# ONLY. We deliberately do NOT require state ==
# RUNNING: SLURM state polling is unreliable on some
# clusters (squeue can return UNKNOWN/empty for the
# entire lifetime of a short job; observed on ODO
# 2026-05-11) and is a stale-at-best proxy for
# "child is ready". The .req file is authoritative
# — it can only have been produced by the running
# child, on its own compute node, after
# socket.gethostname() returned. State polling
# below still aborts the watcher when the job
# reaches a TERMINAL state without producing .req.
if mode == 'reverse' and ssh_proc is None:
req_file = relay_file.with_suffix('.req')
# NFSv3 caches negative lookups (file-doesn't-exist)
# for tens of seconds. After the parent's first
# `req_file.exists()` returns False, the cached
# ENOENT keeps returning False even after the
# child writes .req on the shared FS — observed
# on ODO 2026-05-11 17:10: .req appeared at +21s,
# parent kept seeing False for the whole 16s
# window the file was on disk. A readdir on the
# parent dir invalidates the negative-lookup
# cache by forcing fresh directory attributes.
try:
dir_contents = set(
os.listdir(str(req_file.parent)))
except OSError:
dir_contents = set()
if req_file.name in dir_contents:
try:
import json as _json
compute_host = _json.loads(
req_file.read_text()).get('hostname')
except (ValueError, OSError) as exc:
await self._fail_tunnel(
endpoint_name, job_id, native_id,
f"reverse SSH: .req file unreadable: {exc}")
return
if not compute_host:
await self._fail_tunnel(
endpoint_name, job_id, native_id,
"reverse SSH: .req file has no 'hostname' field")
return
log.info("[psij] reverse: child .req says hostname=%s "
"for job %s, spawning ssh -R to %s:%s",
compute_host, native_id, bridge_host, bridge_port)
# Retry the spawn for up to ~30s. Some
# sites' compute-node sshd refuses logins
# from the login node for a short window
# after the job is registered (e.g.
# pam_slurm_adopt rejects until the job's
# cgroup is fully established). The retry
# is transport-agnostic: any spawn failure
# gets a fresh attempt 1s later. Bail out
# immediately if the job goes terminal in
# the meantime — no point retrying once the
# allocation is gone.
last_exc = None
for spawn_attempt in range(30):
try:
ssh_proc, port = await asyncio.to_thread(
_tunnel.spawn_reverse_tunnel,
compute_host, bridge_host, bridge_port,
endpoint_name)
break
except Exception as exc:
last_exc = exc
if (await asyncio.to_thread(
batch.job_state, native_id)
in TERMINAL_STATES):
await self._fail_tunnel(
endpoint_name, job_id, native_id,
f"reverse SSH spawn failed and job "
f"went terminal: {exc}")
return
if spawn_attempt == 0:
log.info("[psij] reverse: first ssh -R "
"spawn rejected (likely "
"pam_slurm_adopt race), "
"retrying: %s", exc)
else:
log.debug("[psij] reverse: ssh -R spawn "
"retry %d/30: %s",
spawn_attempt + 1, exc)
await asyncio.sleep(1)
else:
await self._fail_tunnel(
endpoint_name, job_id, native_id,
f"reverse SSH spawn failed after 30 "
f"attempts: {last_exc}")
return
self._tunnel_procs[endpoint_name] = ssh_proc
if state == STATE_UNKNOWN:
unknown_streak += 1
else:
seen_known = True
unknown_streak = 0
if state != last_state or attempt % 30 == 0:
log.info("[psij] watcher endpoint=%s job=%s mode=%s "
"state=%r (attempt %d/300)",
endpoint_name, native_id, mode,
state or '(unknown)', attempt)
last_state = state
if state in TERMINAL_STATES:
log.warning("[psij] Job %s ended with state %s — "
"aborting watch (relay file %s never appeared)",
native_id, state, relay_file)
if mode == 'reverse' and ssh_proc is not None:
# We had spawned SSH but the rendezvous file never
# appeared. Treat as tunnel failure.
await self._fail_tunnel(
endpoint_name, job_id, native_id,
f"reverse SSH spawned but rendezvous file never "
f"appeared (job {state})", spawn_proc=ssh_proc)
elif mode == 'reverse' and state != STATE_CANCELLED:
# Job hit a terminal state before the child
# could write .req — surface that as a tunnel
# failure so the client gets a clear error.
# CANCELLED is left alone: that's a deliberate
# operator-initiated outcome and shouldn't be
# converted to FAILED via _fail_tunnel.
await self._fail_tunnel(
endpoint_name, job_id, native_id,
f"child never wrote .req (job ended {state} "
f"before reverse tunnel could be set up)")
return
if seen_known and unknown_streak >= UNKNOWN_TOLERANCE:
log.warning("[psij] Job %s vanished from queue "
"(state=UNKNOWN x %d) — aborting watch "
"(relay file %s never appeared)",
native_id, unknown_streak, relay_file)
if mode == 'reverse' and ssh_proc is not None:
await self._fail_tunnel(
endpoint_name, job_id, native_id,
f"reverse SSH spawned but job vanished "
f"(UNKNOWN x {unknown_streak})",
spawn_proc=ssh_proc)
return
log.warning("[psij] Watcher for endpoint '%s' timed out waiting for "
"tunnel port file %s", endpoint_name, relay_file)
if mode == 'reverse':
await self._fail_tunnel(
endpoint_name, job_id, native_id,
"tunnel watcher timed out before rendezvous file appeared",
spawn_proc=ssh_proc)
finally:
if ssh_proc is not None and ssh_proc.poll() is None:
_tunnel.cleanup_tunnel(ssh_proc, endpoint_name)
self._tunnel_procs.pop(endpoint_name, None)
async def _await_reverse_teardown(self, endpoint_name: str, native_id,
                                  ssh_proc, batch) -> None:
    """Keep an active reverse tunnel alive until its job is gone.

    Polls every 5 seconds and returns as soon as one of these holds:

    * the scheduler reports a state in ``TERMINAL_STATES``;
    * the scheduler has reported ``STATE_UNKNOWN`` for
      ``UNKNOWN_TOLERANCE`` consecutive polls;
    * the ``ssh -R`` process has exited on its own.

    On every exit path (normal return, exception, task cancellation)
    the SSH process is handed to ``cleanup_tunnel``, the endpoint is
    removed from ``self._tunnel_procs`` and the endpoint's ``.req``
    file is deleted.

    Args:
        endpoint_name: Key into ``self._tunnel_procs``; also the stem of
            the ``.req`` file in the relay directory.
        native_id: Scheduler job id passed to ``batch.job_state``.
        ssh_proc: Handle of the spawned ``ssh -R`` process; must offer
            ``poll()`` and ``returncode`` (presumably a
            ``subprocess.Popen`` — confirm against ``tunnel``).
        batch: Object exposing a blocking ``job_state(native_id)``.
    """
    from .batch_system import TERMINAL_STATES, STATE_UNKNOWN
    from . import tunnel as _tunnel
    unknown_streak = 0
    try:
        while True:
            await asyncio.sleep(5)
            # job_state is blocking, so run it off the event loop.
            state = await asyncio.to_thread(batch.job_state, native_id)
            # A single UNKNOWN answer is not proof the job is gone:
            # only a run of UNKNOWN_TOLERANCE consecutive UNKNOWN polls
            # is treated as "job left the queue".
            if state == STATE_UNKNOWN:
                unknown_streak += 1
            else:
                unknown_streak = 0
            if (state in TERMINAL_STATES
                    or unknown_streak >= UNKNOWN_TOLERANCE):
                log.info("[psij] reverse: job %s reached %s — "
                         "tearing down ssh -R for endpoint %s",
                         native_id, state, endpoint_name)
                return
            if ssh_proc.poll() is not None:
                log.warning("[psij] reverse: ssh -R for endpoint %s exited "
                            "(rc=%s) while job %s still running",
                            endpoint_name, ssh_proc.returncode, native_id)
                return
    finally:
        # Nested so that a failing cleanup_tunnel cannot leave a stale
        # registry entry or a stale .req file behind.
        try:
            _tunnel.cleanup_tunnel(ssh_proc, endpoint_name)
        finally:
            self._tunnel_procs.pop(endpoint_name, None)
            (_relay_dir() / f'{endpoint_name}.req').unlink(missing_ok=True)
async def _fail_tunnel(self, endpoint_name: str, job_id: 'str | None',
                       native_id, reason: str, spawn_proc=None) -> None:
    """Record a tunnel failure, clean up its artefacts, cancel the job.

    Recorded reason surfaces via ``get_job_status`` as a synthesised
    ``state='FAILED'`` plus an ``error`` field — see the override
    in :meth:`get_job_status`.

    Cleanup is best-effort: a failure while tearing down the SSH
    process or removing the ``.req`` file is logged and does not stop
    the job from being cancelled.

    Args:
        endpoint_name: Key into ``self._tunnel_procs``; also the stem of
            the ``.req`` file in the relay directory.
        job_id: Id under which the reason is stored in
            ``self._failure_reasons`` and which is passed to
            ``_dispatch_cancel``. ``None`` skips both.
        native_id: Accepted for call-site symmetry; not used here.
        reason: Human-readable failure description.
        spawn_proc: SSH process handle to tear down, if one was spawned.
    """
    from . import tunnel as _tunnel
    log.error("[psij] tunnel failed for endpoint '%s' (job %s): %s",
              endpoint_name, job_id, reason)
    if job_id:
        self._failure_reasons[job_id] = reason
    if spawn_proc is not None:
        try:
            _tunnel.cleanup_tunnel(spawn_proc, endpoint_name)
        except Exception as exc:
            # Must not abort here: the job still has to be cancelled.
            log.warning("[psij] tunnel cleanup for endpoint '%s' "
                        "raised: %s", endpoint_name, exc)
        self._tunnel_procs.pop(endpoint_name, None)
    # Clean up the child's .req rendezvous file on every failure
    # path, not just when an ssh proc was spawned. Otherwise an
    # UNKNOWN-streak / terminal-state abort that fires before we
    # ever entered the spawn branch leaves stale .req on disk,
    # which the next watcher run might pick up after submit-time
    # cleanup if the user re-submits very quickly.
    try:
        (_relay_dir() / f'{endpoint_name}.req').unlink(missing_ok=True)
    except OSError as exc:
        # missing_ok only covers a missing file; other OS errors still
        # raise and must not prevent the cancel below.
        log.warning("[psij] could not remove .req file for endpoint "
                    "'%s': %s", endpoint_name, exc)
    if job_id is not None:
        try:
            # Release the allocation via the owning session's
            # cancel_job. Fire-and-forget — the failure reason has
            # already been recorded above.
            await self._dispatch_cancel(str(job_id))
        except Exception as exc:
            log.warning("[psij] cancel after tunnel failure raised: %s",
                        exc)
async def _dispatch_cancel(self, job_id: str) -> None:
    """Cancel *job_id* on whichever live session submitted it.

    Watchers run outside any request handler, so the HTTP cancel route
    is not available to them. Instead the registered sessions are
    searched for the first ``PSIJSession`` whose ``_jobs`` mapping
    contains *job_id*, and that session's ``cancel_job`` is awaited.
    If no session owns the job, a warning is logged and nothing else
    happens.
    """
    # Search a snapshot so the registry may change while we look.
    owner = next(
        (candidate for candidate in list(self._sessions.values())
         if isinstance(candidate, PSIJSession)
         and job_id in getattr(candidate, '_jobs', {})),
        None)
    if owner is None:
        log.warning("[psij] _dispatch_cancel: no session owns job %s", job_id)
        return
    await owner.cancel_job(job_id)
async def _cleanup_watchers(self) -> None:
    """Cancel all watcher tasks and tear down open reverse SSH tunnels.

    Intended for plugin shutdown. Every task in ``self._watchers`` is
    cancelled and the registry emptied. Then, for each entry in
    ``self._tunnel_procs``, the SSH process is handed to
    ``cleanup_tunnel`` and the endpoint's ``.req`` file is removed.

    Teardown is best-effort per endpoint: a failure for one endpoint
    is logged and the remaining endpoints are still processed, so the
    registry is always emptied at the end.
    """
    from . import tunnel as _tunnel
    # Snapshot before iterating: the registry is cleared right after.
    for task in list(self._watchers.values()):
        task.cancel()
    self._watchers.clear()
    for endpoint_name, proc in list(self._tunnel_procs.items()):
        try:
            _tunnel.cleanup_tunnel(proc, endpoint_name)
        except Exception as exc:
            log.warning("[psij] tunnel cleanup for endpoint '%s' "
                        "raised: %s", endpoint_name, exc)
        try:
            (_relay_dir() / f'{endpoint_name}.req').unlink(missing_ok=True)
        except OSError as exc:
            # missing_ok only covers a missing file.
            log.warning("[psij] could not remove .req file for endpoint "
                        "'%s': %s", endpoint_name, exc)
    self._tunnel_procs.clear()