Coverage for app / backend / src / tests / pytest_split / algorithms.py: 73%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-03 06:18 +0000

1# Vendored from https://github.com/jerry-git/pytest-split 

2 

3# Copyright (c) 2024 Jerry Pussinen 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining 

6# a copy of this software and associated documentation files (the 

7# "Software"), to deal in the Software without restriction, including 

8# without limitation the rights to use, copy, modify, merge, publish, 

9# distribute, sublicense, and/or sell copies of the Software, and to 

10# permit persons to whom the Software is furnished to do so, subject to 

11# the following conditions: 

12# 

13# The above copyright notice and this permission notice shall be included 

14# in all copies or substantial portions of the Software. 

15# 

16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 

17# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 

18# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 

19# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 

20# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 

21# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 

22# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 

23from __future__ import annotations 

24 

25import enum 

26import heapq 

27from abc import ABC, abstractmethod 

28from operator import itemgetter 

29from typing import TYPE_CHECKING, NamedTuple 

30 

31if TYPE_CHECKING: 

32 from _pytest import nodes 

33 

34 

class TestGroup(NamedTuple):
    """Result of a split for one group.

    Both algorithms in this module fill ``duration`` with the sum of the
    (cached or estimated) durations of the ``selected`` items.
    """

    # Items this group should run.
    selected: list[nodes.Item]
    # Items assigned to other groups, i.e. skipped by this one.
    deselected: list[nodes.Item]
    # Summed duration of the selected items.
    duration: float

39 

40 

class AlgorithmBase(ABC):
    """Abstract interface shared by every splitting algorithm.

    Instances are callables. Equality and hashing depend only on the name of
    the concrete class, so two instances of the same algorithm are
    interchangeable.
    """

    @abstractmethod
    def __call__(self, splits: int, items: list[nodes.Item], durations: dict[str, float]) -> list[TestGroup]:
        """Partition ``items`` into ``splits`` groups, guided by ``durations``."""

    def __hash__(self) -> int:
        return hash(type(self).__name__)

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for anything that is not an algorithm.
        if isinstance(other, AlgorithmBase):
            return type(self).__name__ == type(other).__name__
        return NotImplemented

55 

56 

class LeastDurationAlgorithm(AlgorithmBase):
    """
    Split tests into groups by runtime.
    It walks the test items, starting with the test with largest duration.
    It assigns the test with the largest runtime to the group with the smallest duration sum.

    The items are first ordered by node id and then (stably) by descending duration, so ties in
    duration are broken by node id rather than by the order in which pytest collected the items.
    This keeps the assignment identical on all nodes that use this plugin even if their collection
    order differs (see issue #25), provided node ids are unique. Within each group, the selected
    items are returned in their original collection order.

    :param splits: How many groups we're splitting in.
    :param items: Test items passed down by Pytest.
    :param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
    :return:
        List of groups
    """

    def __call__(self, splits: int, items: list[nodes.Item], durations: dict[str, float]) -> list[TestGroup]:
        items_with_durations = _get_items_with_durations(items, durations)

        # Remember each item's position in the original list so each group's selected
        # items can be put back into collection order at the end.
        items_with_durations_indexed = [(*tup, i) for i, tup in enumerate(items_with_durations)]

        # Order by node id (the same key the durations cache uses) so the result no longer
        # depends on collection order. A string rendering of the item is not guaranteed to be
        # unique, which would leave ties to collection order again.
        items_with_durations_indexed.sort(key=lambda tup: tup[0].nodeid)

        # Sort in descending order of duration. The sort is stable (also with reverse=True),
        # so equal durations keep the node-id order established above.
        sorted_items_with_durations = sorted(items_with_durations_indexed, key=itemgetter(1), reverse=True)

        selected: list[list[tuple[nodes.Item, int]]] = [[] for _ in range(splits)]
        deselected: list[list[nodes.Item]] = [[] for _ in range(splits)]
        duration: list[float] = [0 for _ in range(splits)]

        # create a heap of the form (summed_durations, group_index)
        heap: list[tuple[float, int]] = [(0, i) for i in range(splits)]
        heapq.heapify(heap)
        for item, item_duration, original_index in sorted_items_with_durations:
            # get group with smallest sum (ties go to the lowest group index)
            summed_durations, group_idx = heapq.heappop(heap)
            new_group_durations = summed_durations + item_duration

            # store assignment
            selected[group_idx].append((item, original_index))
            duration[group_idx] = new_group_durations
            for i in range(splits):
                if i != group_idx:
                    deselected[i].append(item)

            # store new duration - in case of ties it sorts by the group_idx
            heapq.heappush(heap, (new_group_durations, group_idx))

        groups = []
        for i in range(splits):
            # sort the items by their original index to maintain relative ordering;
            # we don't care about the order of deselected items
            s = [item for item, _ in sorted(selected[i], key=itemgetter(1))]
            groups.append(TestGroup(selected=s, deselected=deselected[i], duration=duration[i]))
        return groups

116 

117 

class DurationBasedChunksAlgorithm(AlgorithmBase):
    """
    Split tests into groups by runtime.
    Ensures tests are split into non-overlapping groups.
    The original list of test items is split into groups by finding boundary indices i_0, i_1, i_2
    and creating group_1 = items[0:i_0], group_2 = items[i_0:i_1], group_3 = items[i_1:i_2], ...

    A group is closed once it holds at least one item and its summed duration reaches
    ``total_duration / splits``. The last group is never closed, so it absorbs whatever is left
    (for example trailing zero-duration tests, or the remainder after floating-point rounding).

    :param splits: How many groups we're splitting in.
    :param items: Test items passed down by Pytest.
    :param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
    :return: List of TestGroup
    """

    def __call__(self, splits: int, items: list[nodes.Item], durations: dict[str, float]) -> list[TestGroup]:
        items_with_durations = _get_items_with_durations(items, durations)
        time_per_group = sum(map(itemgetter(1), items_with_durations)) / splits

        selected: list[list[nodes.Item]] = [[] for _ in range(splits)]
        deselected: list[list[nodes.Item]] = [[] for _ in range(splits)]
        duration: list[float] = [0 for _ in range(splits)]

        last_group_idx = splits - 1
        group_idx = 0
        for item, item_duration in items_with_durations:
            # Advance once the current group has received its share of the runtime, but never past
            # the last group. Without that bound, trailing zero-duration items (or all-zero durations,
            # where the share is 0) push the index beyond the lists and raise IndexError.
            # Requiring a non-empty group keeps a zero share from leaving a group empty.
            if group_idx < last_group_idx and selected[group_idx] and duration[group_idx] >= time_per_group:
                group_idx += 1

            selected[group_idx].append(item)
            for i in range(splits):
                if i != group_idx:
                    deselected[i].append(item)
            duration[group_idx] += item_duration

        return [TestGroup(selected=selected[i], deselected=deselected[i], duration=duration[i]) for i in range(splits)]

151 

152 

def _get_items_with_durations(items: list[nodes.Item], durations: dict[str, float]) -> list[tuple[nodes.Item, float]]:
    """Pair each item with its cached duration, keeping the order of ``items``.

    Items without a cached timing get the mean of the cached timings that belong
    to ``items`` (or 1 when none of them has been timed).
    """
    relevant = _remove_irrelevant_durations(items, durations)
    fallback = _get_avg_duration_per_test(relevant)
    return [(item, relevant.get(item.nodeid, fallback)) for item in items]

158 

159 

def _get_avg_duration_per_test(durations: dict[str, float]) -> float:
    """Return the arithmetic mean of the values in ``durations``.

    An empty mapping yields ``1``, an arbitrary placeholder that gives every
    test the same weight when nothing has been timed.
    """
    if not durations:
        # No timings at all: every test gets the same arbitrary value.
        return 1
    return sum(durations.values()) / len(durations)

167 

168 

def _remove_irrelevant_durations(items: list[nodes.Item], durations: dict[str, float]) -> dict[str, float]:
    """Keep only the durations whose key is the node id of one of ``items``.

    The result follows the order of ``items``. Dropping unrelated entries keeps
    averages computed from it from being skewed by tests that are not running.
    """
    relevant: dict[str, float] = {}
    for item in items:
        node_id = item.nodeid
        if node_id in durations:
            relevant[node_id] = durations[node_id]
    return relevant

174 

175 

class Algorithms(enum.Enum):
    """Enumeration of the available splitting algorithms; each value is an algorithm instance."""

    duration_based_chunks = DurationBasedChunksAlgorithm()
    least_duration = LeastDurationAlgorithm()

    @staticmethod
    def names() -> list[str]:
        """Return the member names in definition order."""
        return [algorithm.name for algorithm in Algorithms]