feat(time-based-sharding): updates plugin

This commit is contained in:
nuno maduro
2026-04-14 08:34:41 -07:00
parent 13c322bab3
commit 4ac14b2528
7 changed files with 455 additions and 3 deletions

View File

@ -123,6 +123,10 @@ final readonly class Help implements HandlesArguments
'arg' => '--update-snapshots', 'arg' => '--update-snapshots',
'desc' => 'Update snapshots for tests using the "toMatchSnapshot" expectation', 'desc' => 'Update snapshots for tests using the "toMatchSnapshot" expectation',
], ],
[
'arg' => '--update-shards',
'desc' => 'Update shards.json with test timing data for time-balanced sharding',
],
], ...$content['Execution']]; ], ...$content['Execution']];
$content['Selection'] = [[ $content['Selection'] = [[

View File

@ -6,7 +6,13 @@ namespace Pest\Plugins;
use Pest\Contracts\Plugins\AddsOutput; use Pest\Contracts\Plugins\AddsOutput;
use Pest\Contracts\Plugins\HandlesArguments; use Pest\Contracts\Plugins\HandlesArguments;
use Pest\Contracts\Plugins\Terminable;
use Pest\Exceptions\InvalidOption; use Pest\Exceptions\InvalidOption;
use Pest\Subscribers\EnsureShardTimingFinished;
use Pest\Subscribers\EnsureShardTimingStarted;
use Pest\Subscribers\EnsureShardTimingsAreCollected;
use Pest\TestSuite;
use PHPUnit\Event;
use Symfony\Component\Console\Input\ArgvInput; use Symfony\Component\Console\Input\ArgvInput;
use Symfony\Component\Console\Input\InputInterface; use Symfony\Component\Console\Input\InputInterface;
use Symfony\Component\Console\Output\OutputInterface; use Symfony\Component\Console\Output\OutputInterface;
@ -15,7 +21,7 @@ use Symfony\Component\Process\Process;
/** /**
* @internal * @internal
*/ */
final class Shard implements AddsOutput, HandlesArguments final class Shard implements AddsOutput, HandlesArguments, Terminable
{ {
use Concerns\HandleArguments; use Concerns\HandleArguments;
@ -33,6 +39,30 @@ final class Shard implements AddsOutput, HandlesArguments
*/ */
private static ?array $shard = null; private static ?array $shard = null;
/**
* Whether to update the shards.json file.
*/
private static bool $updateShards = false;
/**
* Whether time-balanced sharding was used.
*/
private static bool $timeBalanced = false;
/**
* Collected timings from workers or subscribers.
*
* @var array<string, float>|null
*/
private static ?array $collectedTimings = null;
/**
* The canonical list of test classes from --list-tests.
*
* @var list<string>|null
*/
private static ?array $knownTests = null;
/** /**
* Creates a new Plugin instance. * Creates a new Plugin instance.
*/ */
@ -47,6 +77,19 @@ final class Shard implements AddsOutput, HandlesArguments
*/ */
public function handleArguments(array $arguments): array public function handleArguments(array $arguments): array
{ {
if ($this->hasArgument('--update-shards', $arguments)) {
return $this->handleUpdateShards($arguments);
}
if (Parallel::isWorker() && Parallel::getGlobal('UPDATE_SHARDS') === true) {
self::$updateShards = true;
Event\Facade::instance()->registerSubscriber(new EnsureShardTimingStarted);
Event\Facade::instance()->registerSubscriber(new EnsureShardTimingFinished);
return $arguments;
}
if (! $this->hasArgument('--shard', $arguments)) { if (! $this->hasArgument('--shard', $arguments)) {
return $arguments; return $arguments;
} }
@ -63,7 +106,21 @@ final class Shard implements AddsOutput, HandlesArguments
/** @phpstan-ignore-next-line */ /** @phpstan-ignore-next-line */
$tests = $this->allTests($arguments); $tests = $this->allTests($arguments);
$timings = self::loadShardsFile();
if ($timings !== null) {
$missingTests = array_diff($tests, array_keys($timings));
if ($missingTests !== []) {
throw new InvalidOption('The [tests/.pest/shards.json] file is out of date. Run [--update-shards] to update it.');
}
$partitions = self::partitionByTime($tests, $timings, $total);
$testsToRun = $partitions[$index - 1] ?? [];
self::$timeBalanced = true;
} else {
$testsToRun = (array_chunk($tests, max(1, (int) ceil(count($tests) / $total))))[$index - 1] ?? []; $testsToRun = (array_chunk($tests, max(1, (int) ceil(count($tests) / $total))))[$index - 1] ?? [];
}
self::$shard = [ self::$shard = [
'index' => $index, 'index' => $index,
@ -75,6 +132,36 @@ final class Shard implements AddsOutput, HandlesArguments
return [...$arguments, '--filter', $this->buildFilterArgument($testsToRun)]; return [...$arguments, '--filter', $this->buildFilterArgument($testsToRun)];
} }
/**
* Handles the --update-shards argument.
*
* @param array<int, string> $arguments
* @return array<int, string>
*/
private function handleUpdateShards(array $arguments): array
{
if ($this->hasArgument('--shard', $arguments)) {
throw new InvalidOption('The [--update-shards] option cannot be combined with [--shard].');
}
$arguments = $this->popArgument('--update-shards', $arguments);
self::$updateShards = true;
/** @phpstan-ignore-next-line */
self::$knownTests = $this->allTests($arguments);
if ($this->hasArgument('--parallel', $arguments) || $this->hasArgument('-p', $arguments)) {
Parallel::setGlobal('UPDATE_SHARDS', true);
Parallel::setGlobal('SHARD_RUN_ID', uniqid('pest-shard-', true));
} else {
Event\Facade::instance()->registerSubscriber(new EnsureShardTimingStarted);
Event\Facade::instance()->registerSubscriber(new EnsureShardTimingFinished);
}
return $arguments;
}
/** /**
* Returns all tests that the test suite would run. * Returns all tests that the test suite would run.
* *
@ -116,6 +203,20 @@ final class Shard implements AddsOutput, HandlesArguments
*/ */
public function addOutput(int $exitCode): int public function addOutput(int $exitCode): int
{ {
if (self::$updateShards && ! Parallel::isWorker()) {
self::$collectedTimings = self::collectTimings();
$count = self::$knownTests !== null
? count(array_intersect_key(self::$collectedTimings, array_flip(self::$knownTests)))
: count(self::$collectedTimings);
$this->output->writeln(sprintf(
' <fg=gray>Shards:</> <fg=default>shards.json updated with timings for %d test class%s.</>',
$count,
$count === 1 ? '' : 'es',
));
}
if (self::$shard === null) { if (self::$shard === null) {
return $exitCode; return $exitCode;
} }
@ -128,17 +229,243 @@ final class Shard implements AddsOutput, HandlesArguments
] = self::$shard; ] = self::$shard;
$this->output->writeln(sprintf( $this->output->writeln(sprintf(
' <fg=gray>Shard:</> <fg=default>%d of %d</> — %d file%s ran, out of %d.', ' <fg=gray>Shard:</> <fg=default>%d of %d</> — %d file%s ran, out of %d%s.',
$index, $index,
$total, $total,
$testsRan, $testsRan,
$testsRan === 1 ? '' : 's', $testsRan === 1 ? '' : 's',
$testsCount, $testsCount,
self::$timeBalanced ? ' <fg=gray>(time-balanced)</>' : '',
)); ));
return $exitCode; return $exitCode;
} }
/**
* Terminates the plugin.
*/
public function terminate(): void
{
if (! self::$updateShards) {
return;
}
if (Parallel::isWorker()) {
self::writeWorkerTimings();
return;
}
$timings = self::$collectedTimings ?? self::collectTimings();
if ($timings === []) {
return;
}
self::writeTimings($timings);
}
/**
* Collects timings from subscribers or worker temp files.
*
* @return array<string, float>
*/
private static function collectTimings(): array
{
$runId = Parallel::getGlobal('SHARD_RUN_ID');
if (is_string($runId)) {
return self::readWorkerTimings($runId);
}
return EnsureShardTimingsAreCollected::timings();
}
/**
* Writes the current worker's timing data to a temp file.
*/
private static function writeWorkerTimings(): void
{
$timings = EnsureShardTimingsAreCollected::timings();
if ($timings === []) {
return;
}
$runId = Parallel::getGlobal('SHARD_RUN_ID');
if (! is_string($runId)) {
return;
}
$path = sys_get_temp_dir().DIRECTORY_SEPARATOR.$runId.'-'.getmypid().'.json';
file_put_contents($path, json_encode($timings, JSON_THROW_ON_ERROR));
}
/**
* Reads and merges timing data from all worker temp files.
*
* @return array<string, float>
*/
private static function readWorkerTimings(string $runId): array
{
$pattern = sys_get_temp_dir().DIRECTORY_SEPARATOR.$runId.'-*.json';
$files = glob($pattern);
if ($files === false || $files === []) {
return [];
}
$merged = [];
foreach ($files as $file) {
$contents = file_get_contents($file);
if ($contents === false) {
continue;
}
$timings = json_decode($contents, true);
if (is_array($timings)) {
$merged = array_merge($merged, $timings);
}
unlink($file);
}
return $merged;
}
/**
* Returns the path to shards.json.
*/
private static function shardsPath(): string
{
$testSuite = TestSuite::getInstance();
return implode(DIRECTORY_SEPARATOR, [$testSuite->rootPath, $testSuite->testPath, '.pest', 'shards.json']);
}
/**
* Loads the timings from shards.json.
*
* @return array<string, float>|null
*/
private static function loadShardsFile(): ?array
{
$path = self::shardsPath();
if (! file_exists($path)) {
return null;
}
$contents = file_get_contents($path);
if ($contents === false) {
return null;
}
$data = json_decode($contents, true);
if (! is_array($data) || ! isset($data['timings']) || ! is_array($data['timings'])) {
return null;
}
/** @var array<string, float> */
return $data['timings'];
}
/**
* Partitions tests across shards using the LPT (Longest Processing Time) algorithm.
*
* @param list<string> $tests
* @param array<string, float> $timings
* @return list<list<string>>
*/
private static function partitionByTime(array $tests, array $timings, int $total): array
{
$knownTimings = array_filter(
array_map(fn (string $test): ?float => $timings[$test] ?? null, $tests),
fn (?float $t): bool => $t !== null,
);
$median = $knownTimings !== [] ? self::median(array_values($knownTimings)) : 1.0;
$testsWithTimings = array_map(
fn (string $test): array => ['test' => $test, 'time' => $timings[$test] ?? $median],
$tests,
);
usort($testsWithTimings, fn (array $a, array $b): int => $b['time'] <=> $a['time']);
/** @var list<list<string>> */
$bins = array_fill(0, $total, []);
/** @var non-empty-list<float> */
$binTimes = array_fill(0, $total, 0.0);
foreach ($testsWithTimings as $item) {
$minIndex = array_search(min($binTimes), $binTimes, strict: true);
assert(is_int($minIndex));
$bins[$minIndex][] = $item['test'];
$binTimes[$minIndex] += $item['time'];
}
return $bins;
}
/**
* Calculates the median of an array of floats.
*
* @param list<float> $values
*/
private static function median(array $values): float
{
sort($values);
$count = count($values);
$middle = (int) floor($count / 2);
if ($count % 2 === 0) {
return ($values[$middle - 1] + $values[$middle]) / 2;
}
return $values[$middle];
}
/**
* Writes the timings to shards.json.
*
* @param array<string, float> $timings
*/
private static function writeTimings(array $timings): void
{
$path = self::shardsPath();
$directory = dirname($path);
if (! is_dir($directory)) {
mkdir($directory, 0755, true);
}
if (self::$knownTests !== null) {
$knownSet = array_flip(self::$knownTests);
$timings = array_intersect_key($timings, $knownSet);
}
ksort($timings);
$canonical = self::$knownTests ?? array_keys($timings);
sort($canonical);
file_put_contents($path, json_encode([
'timings' => $timings,
'checksum' => md5(implode("\n", $canonical)),
'updated_at' => date('c'),
], JSON_PRETTY_PRINT | JSON_UNESCAPED_SLASHES)."\n");
}
/** /**
* Returns the shard information. * Returns the shard information.
* *

View File

@ -0,0 +1,22 @@
<?php
declare(strict_types=1);
namespace Pest\Subscribers;
use PHPUnit\Event\TestSuite\Finished;
use PHPUnit\Event\TestSuite\FinishedSubscriber;
/**
* @internal
*/
final class EnsureShardTimingFinished implements FinishedSubscriber
{
/**
* Runs the subscriber.
*/
public function notify(Finished $event): void
{
EnsureShardTimingsAreCollected::finished($event);
}
}

View File

@ -0,0 +1,22 @@
<?php
declare(strict_types=1);
namespace Pest\Subscribers;
use PHPUnit\Event\TestSuite\Started;
use PHPUnit\Event\TestSuite\StartedSubscriber;
/**
* @internal
*/
final class EnsureShardTimingStarted implements StartedSubscriber
{
/**
* Runs the subscriber.
*/
public function notify(Started $event): void
{
EnsureShardTimingsAreCollected::started($event);
}
}

View File

@ -0,0 +1,75 @@
<?php
declare(strict_types=1);
namespace Pest\Subscribers;
use PHPUnit\Event\Telemetry\HRTime;
use PHPUnit\Event\TestSuite\Finished;
use PHPUnit\Event\TestSuite\Started;
/**
* @internal
*/
final class EnsureShardTimingsAreCollected
{
/**
* The start times for each test class.
*
* @var array<string, HRTime>
*/
private static array $startTimes = [];
/**
* The collected timings for each test class.
*
* @var array<string, float>
*/
private static array $timings = [];
/**
* Records the start time for a test suite.
*/
public static function started(Started $event): void
{
if (! $event->testSuite()->isForTestClass()) {
return;
}
$name = preg_replace('/^P\\\\/', '', $event->testSuite()->name());
if (is_string($name)) {
self::$startTimes[$name] = $event->telemetryInfo()->time();
}
}
/**
* Records the duration for a test suite.
*/
public static function finished(Finished $event): void
{
if (! $event->testSuite()->isForTestClass()) {
return;
}
$name = preg_replace('/^P\\\\/', '', $event->testSuite()->name());
if (! is_string($name) || ! isset(self::$startTimes[$name])) {
return;
}
$duration = $event->telemetryInfo()->time()->duration(self::$startTimes[$name]);
self::$timings[$name] = round($duration->asFloat(), 4);
}
/**
* Returns the collected timings.
*
* @return array<string, float>
*/
public static function timings(): array
{
return self::$timings;
}
}

View File

@ -49,6 +49,7 @@
EXECUTION OPTIONS: EXECUTION OPTIONS:
--parallel ........................................... Run tests in parallel --parallel ........................................... Run tests in parallel
--update-snapshots Update snapshots for tests using the "toMatchSnapshot" expectation --update-snapshots Update snapshots for tests using the "toMatchSnapshot" expectation
--update-shards Update shards.json with test timing data for time-balanced sharding
--globals-backup ................. Backup and restore $GLOBALS for each test --globals-backup ................. Backup and restore $GLOBALS for each test
--static-backup ......... Backup and restore static properties for each test --static-backup ......... Backup and restore static properties for each test
--strict-coverage ................... Be strict about code coverage metadata --strict-coverage ................... Be strict about code coverage metadata

View File

@ -17,6 +17,7 @@ arch()->preset()->security()->ignoring([
'eval', 'eval',
'str_shuffle', 'str_shuffle',
'exec', 'exec',
'md5',
'unserialize', 'unserialize',
'extract', 'extract',
'assert', 'assert',