Coverage for app/backend/src/tests/pytest_split/algorithms.py: 73% (83 statements)
coverage.py v7.13.2, created at 2026-02-03 06:18 +0000

# Vendored from https://github.com/jerry-git/pytest-split
#
# Copyright (c) 2024 Jerry Pussinen
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from __future__ import annotations

import enum
import heapq
from abc import ABC, abstractmethod
from operator import itemgetter
from typing import TYPE_CHECKING, NamedTuple

if TYPE_CHECKING:
    from _pytest import nodes


class TestGroup(NamedTuple):
    selected: list[nodes.Item]
    deselected: list[nodes.Item]
    duration: float


class AlgorithmBase(ABC):
    """Abstract base class for the algorithm implementations."""

    @abstractmethod
    def __call__(self, splits: int, items: list[nodes.Item], durations: dict[str, float]) -> list[TestGroup]:
        pass

    def __hash__(self) -> int:
        return hash(self.__class__.__name__)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, AlgorithmBase):
            return NotImplemented
        return self.__class__.__name__ == other.__class__.__name__

class LeastDurationAlgorithm(AlgorithmBase):
    """
    Split tests into groups by runtime.

    The algorithm walks the test items in descending order of duration and assigns each test
    to the group with the smallest duration sum so far.

    Since the sort is stable, ties are broken by maintaining the original order of the items.
    It is therefore important that the order of items is identical on all nodes that use this
    plugin. Due to issue #25 this might not always be the case.

    :param splits: How many groups to split the tests into.
    :param items: Test items passed down by Pytest.
    :param durations: Our cached test runtimes. Assumed to contain timings only for the relevant tests.
    :return: List of TestGroup
    """
    def __call__(self, splits: int, items: list[nodes.Item], durations: dict[str, float]) -> list[TestGroup]:
        items_with_durations = _get_items_with_durations(items, durations)

        # add the index of each item in the original list
        items_with_durations_indexed = [(*tup, i) for i, tup in enumerate(items_with_durations)]

        # sort by name to ensure the starting order is always the same
        items_with_durations_indexed = sorted(items_with_durations_indexed, key=lambda tup: str(tup[0]))

        # sort in descending order of duration
        sorted_items_with_durations = sorted(items_with_durations_indexed, key=lambda tup: tup[1], reverse=True)

        selected: list[list[tuple[nodes.Item, int]]] = [[] for _ in range(splits)]
        deselected: list[list[nodes.Item]] = [[] for _ in range(splits)]
        duration: list[float] = [0 for _ in range(splits)]

        # create a heap of the form (summed_durations, group_index)
        heap: list[tuple[float, int]] = [(0, i) for i in range(splits)]
        heapq.heapify(heap)
        for item, item_duration, original_index in sorted_items_with_durations:
            # get the group with the smallest duration sum
            summed_durations, group_idx = heapq.heappop(heap)
            new_group_durations = summed_durations + item_duration

            # store the assignment
            selected[group_idx].append((item, original_index))
            duration[group_idx] = new_group_durations
            for i in range(splits):
                if i != group_idx:
                    deselected[i].append(item)

            # store the new duration - in case of ties the heap sorts by group_idx
            heapq.heappush(heap, (new_group_durations, group_idx))

        groups = []
        for i in range(splits):
            # sort the items by their original index to maintain relative ordering;
            # we don't care about the order of deselected items
            s = [item for item, original_index in sorted(selected[i], key=lambda tup: tup[1])]
            group = TestGroup(selected=s, deselected=deselected[i], duration=duration[i])
            groups.append(group)
        return groups
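
# Illustration with hypothetical values (not taken from the vendored project): for splits=2
# and cached durations {"t_a": 5.0, "t_b": 3.0, "t_c": 1.0}, the greedy heap assignment above
# processes the items in descending duration order and always pops the least-loaded group, so
# "t_a" ends up alone in one group (5.0s) while "t_b" and "t_c" share the other (4.0s).
# A runnable sketch driving both algorithms is at the bottom of this module.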

class DurationBasedChunksAlgorithm(AlgorithmBase):
    """
    Split tests into groups by runtime.

    Ensures tests are split into non-overlapping, contiguous groups. The original list of test
    items is split by finding boundary indices i_0, i_1, i_2, ... and creating
    group_1 = items[0:i_0], group_2 = items[i_0:i_1], group_3 = items[i_1:i_2], ...

    :param splits: How many groups to split the tests into.
    :param items: Test items passed down by Pytest.
    :param durations: Our cached test runtimes. Assumed to contain timings only for the relevant tests.
    :return: List of TestGroup
    """
    def __call__(self, splits: int, items: list[nodes.Item], durations: dict[str, float]) -> list[TestGroup]:
        items_with_durations = _get_items_with_durations(items, durations)
        time_per_group = sum(map(itemgetter(1), items_with_durations)) / splits

        selected: list[list[nodes.Item]] = [[] for i in range(splits)]
        deselected: list[list[nodes.Item]] = [[] for i in range(splits)]
        duration: list[float] = [0 for i in range(splits)]

        group_idx = 0
        for item, item_duration in items_with_durations:
            # move on to the next group once the current one has reached its time budget
            if duration[group_idx] >= time_per_group:
                group_idx += 1

            selected[group_idx].append(item)
            for i in range(splits):
                if i != group_idx:
                    deselected[i].append(item)
            duration[group_idx] += item_duration

        return [TestGroup(selected=selected[i], deselected=deselected[i], duration=duration[i]) for i in range(splits)]
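
# Illustration with hypothetical values (not taken from the vendored project): for splits=2
# and durations {"t_a": 5.0, "t_b": 3.0, "t_c": 1.0} with the items in that order,
# time_per_group is 4.5s, so the chunking above cuts after "t_a" (5.0 >= 4.5):
# group 0 = ["t_a"] (5.0s) and group 1 = ["t_b", "t_c"] (4.0s).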

def _get_items_with_durations(items: list[nodes.Item], durations: dict[str, float]) -> list[tuple[nodes.Item, float]]:
    durations = _remove_irrelevant_durations(items, durations)
    avg_duration_per_test = _get_avg_duration_per_test(durations)
    items_with_durations = [(item, durations.get(item.nodeid, avg_duration_per_test)) for item in items]
    return items_with_durations


def _get_avg_duration_per_test(durations: dict[str, float]) -> float:
    if durations:
        avg_duration_per_test = sum(durations.values()) / len(durations)
    else:
        # If there are no durations, give every test the same arbitrary value
        avg_duration_per_test = 1
    return avg_duration_per_test


def _remove_irrelevant_durations(items: list[nodes.Item], durations: dict[str, float]) -> dict[str, float]:
    # Filtering down durations to relevant ones ensures the avg isn't skewed by irrelevant data
    test_ids = [item.nodeid for item in items]
    durations = {name: durations[name] for name in test_ids if name in durations}
    return durations
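
# Illustration with hypothetical values (not taken from the vendored project): if the cache
# holds {"t_a": 5.0, "t_b": 3.0} and a collected item "t_c" has no recorded runtime,
# _get_items_with_durations assigns "t_c" the 4.0s average of the known durations, so tests
# without cached timings still get a reasonable weight when splitting.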

class Algorithms(enum.Enum):
    duration_based_chunks = DurationBasedChunksAlgorithm()
    least_duration = LeastDurationAlgorithm()

    @staticmethod
    def names() -> list[str]:
        return [x.name for x in Algorithms]
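

# Minimal runnable sketch (not part of the vendored pytest-split source) showing how the two
# algorithms might be driven. _FakeItem is a hypothetical stand-in for _pytest.nodes.Item that
# only provides the nodeid attribute this module actually reads.
if __name__ == "__main__":

    class _FakeItem(NamedTuple):
        nodeid: str

    fake_items = [
        _FakeItem("tests/test_a.py::test_slow"),
        _FakeItem("tests/test_b.py::test_mid"),
        _FakeItem("tests/test_c.py::test_fast"),
    ]
    fake_durations = {
        "tests/test_a.py::test_slow": 5.0,
        "tests/test_b.py::test_mid": 3.0,
        "tests/test_c.py::test_fast": 1.0,
    }

    for algorithm in Algorithms:
        groups = algorithm.value(splits=2, items=fake_items, durations=fake_durations)
        print(algorithm.name)
        for idx, group in enumerate(groups):
            print(f"  group {idx}: {[i.nodeid for i in group.selected]} ({group.duration:.1f}s)")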