# Source code for experiments.harness.sharding
"""SLURM Job Array sharding and shard-merge utilities.
Enables distributing experiment trials across multiple SLURM array
tasks. Each task deterministically regenerates the full spec list,
takes its round-robin slice via :func:`shard_specs`, and saves a
partial ``.pb`` file. After all tasks complete, :func:`merge_shard_files`
combines the shards into a single result file.
"""
import sys
import time
from pathlib import Path
from typing import TYPE_CHECKING
from experiments.decode import _guess_experiment, _RESULT_TYPES
if TYPE_CHECKING:
from experiments.harness.worker import TrialSpec
# [docs]
def shard_specs(
    specs: "list[TrialSpec]",
    shard_index: int,
    num_shards: int,
) -> "list[TrialSpec]":
    """Return the slice of *specs* assigned to this shard.

    Uses round-robin (strided) assignment: shard ``k`` receives specs
    ``k, k + num_shards, k + 2 * num_shards, ...``.  Every spec is
    assigned to exactly one shard (no spec is missed or duplicated),
    and interleaving balances work across shards when specs are
    ordered by cost (e.g. ascending ``n``).

    Parameters
    ----------
    specs : list[TrialSpec]
        The full, deterministically generated spec list.
    shard_index : int
        0-based index of this shard.
    num_shards : int
        Total number of shards.

    Returns
    -------
    list[TrialSpec]
        The specs for this shard (may be empty if
        ``num_shards > len(specs)``).

    Raises
    ------
    ValueError
        If ``num_shards < 1`` or ``shard_index`` is out of range.
    """
    if num_shards <= 0:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")
    if not (0 <= shard_index < num_shards):
        raise ValueError(
            f"shard_index must be in [0, {num_shards}), got {shard_index}"
        )
    # Strided slice implements the round-robin assignment: shard k
    # gets specs k, k+K, k+2K, ... for K = num_shards.
    return specs[shard_index::num_shards]
def shard_output_path(base_path: str, shard_index: int, num_shards: int) -> str:
    """Append a shard suffix to a ``.pb`` output path.

    The suffix displays the shard number 1-based (``shard_index + 1``)
    so file names read naturally: with ``shard_index=2`` and
    ``num_shards=8``, ``"results/scaling_4_10_20.pb"`` becomes
    ``"results/scaling_4_10_20_shard3of8.pb"``.

    Parameters
    ----------
    base_path : str
        The un-sharded output path (extension is preserved).
    shard_index : int
        0-based index of this shard.
    num_shards : int
        Total number of shards.

    Returns
    -------
    str
        *base_path* with ``_shard{i+1}of{N}`` inserted before the suffix.
    """
    p = Path(base_path)
    # Path.with_stem keeps the directory and extension intact (3.9+).
    return str(p.with_stem(f"{p.stem}_shard{shard_index + 1}of{num_shards}"))
# [docs]
def merge_shard_files(shard_paths: list[Path], output_path: Path) -> None:
    """Merge sharded ``.pb`` experiment results into a single file.

    Reads each shard, concatenates all trial results, aggregates
    metadata (summed wall-clock time, total trial count), and writes
    the merged protobuf to *output_path*.  Missing shard files are
    skipped with a warning; mixed experiment types abort the merge.

    Parameters
    ----------
    shard_paths : list[Path]
        Paths to shard ``.pb`` files. All must be from the same
        experiment type.
    output_path : Path
        Destination for the merged ``.pb`` file.
    """
    if not shard_paths:
        print("Error: no shard files provided", file=sys.stderr)
        sys.exit(1)
    # Parse first shard to determine experiment type.
    # NOTE(review): assumes shard_paths[0] exists — _guess_experiment may
    # fail on a missing first shard; confirm against callers.
    experiment_name = _guess_experiment(shard_paths[0])
    cls = _RESULT_TYPES[experiment_name]
    merged_trials = []
    total_wall_clock = 0.0
    max_workers = 0
    parameters_msg = None
    num_merged = 0  # shards actually read (missing files are skipped)
    for path in shard_paths:
        if not path.exists():
            print(f"Warning: shard file not found, skipping: {path}", file=sys.stderr)
            continue
        shard_name = _guess_experiment(path)
        if shard_name != experiment_name:
            print(
                f"Error: mixed experiment types: {experiment_name} vs {shard_name} "
                f"(file: {path})",
                file=sys.stderr,
            )
            sys.exit(1)
        msg = cls()
        msg.ParseFromString(path.read_bytes())
        merged_trials.extend(msg.trials)
        total_wall_clock += msg.metadata.wall_clock_s
        max_workers = max(max_workers, msg.metadata.max_workers)
        # Parameters are assumed identical across shards; keep the first.
        if parameters_msg is None:
            parameters_msg = msg.parameters
        num_merged += 1
    if num_merged == 0:
        # Every shard was missing; writing an empty merged file would
        # silently mask the failure downstream.
        print("Error: none of the shard files could be read", file=sys.stderr)
        sys.exit(1)
    # Build merged message
    from experiments.proto import common_pb2
    merged_metadata = common_pb2.ExperimentMetadata(
        experiment_name=experiment_name,
        timestamp=time.strftime("%Y-%m-%dT%H:%M:%S"),
        wall_clock_s=total_wall_clock,
        max_workers=max_workers,
        num_trials=len(merged_trials),
    )
    merged_msg = cls(
        metadata=merged_metadata,
        parameters=parameters_msg,
        trials=merged_trials,
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(merged_msg.SerializeToString())
    # Report the count of shards actually merged, not the count requested.
    print(
        f"Merged {num_merged} shards -> {output_path} "
        f"({len(merged_trials)} trials, {total_wall_clock:.1f}s total wall-clock)"
    )