Grok 10.0.5
algo-inl.h
Go to the documentation of this file.
1// Copyright 2021 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Normal include guard for target-independent parts
17#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
18#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
19
20#include <stdint.h>
21#include <string.h> // memcpy
22
23#include <algorithm> // std::sort, std::min, std::max
24#include <functional> // std::less, std::greater
25#include <thread> // NOLINT
26#include <vector>
27
28#include "hwy/base.h"
30
31// Third-party algorithms
32#define HAVE_AVX2SORT 0
33#define HAVE_IPS4O 0
34// When enabling, consider changing max_threads (required for Table 1a)
35#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1)
36#define HAVE_PDQSORT 0
37#define HAVE_SORT512 0
38#define HAVE_VXSORT 0
39
40#if HAVE_AVX2SORT
41HWY_PUSH_ATTRIBUTES("avx2,avx")
42#include "avx2sort.h" //NOLINT
44#endif
45#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O
46#include "third_party/ips4o/include/ips4o.hpp"
47#include "third_party/ips4o/include/ips4o/thread_pool.hpp"
48#endif
49#if HAVE_PDQSORT
50#include "third_party/boost/allowed/sort/sort.hpp"
51#endif
52#if HAVE_SORT512
53#include "sort512.h" //NOLINT
54#endif
55
56// vxsort is difficult to compile for multiple targets because it also uses
57// .cpp files, and we'd also have to #undef its include guards. Instead, compile
58// only for AVX2 or AVX3 depending on this macro.
59#define VXSORT_AVX3 1
60#if HAVE_VXSORT
61// inlined from vxsort_targets_enable_avx512 (must close before end of header)
62#ifdef __GNUC__
63#ifdef __clang__
64#if VXSORT_AVX3
65#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \
66 apply_to = any(function))
67#else
68#pragma clang attribute push(__attribute__((target("avx2"))), \
69 apply_to = any(function))
70#endif // VXSORT_AVX3
71
72#else
73#pragma GCC push_options
74#if VXSORT_AVX3
75#pragma GCC target("avx512f,avx512dq")
76#else
77#pragma GCC target("avx2")
78#endif // VXSORT_AVX3
79#endif
80#endif
81
82#if VXSORT_AVX3
83#include "vxsort/machine_traits.avx512.h"
84#else
85#include "vxsort/machine_traits.avx2.h"
86#endif // VXSORT_AVX3
87#include "vxsort/vxsort.h"
88#ifdef __GNUC__
89#ifdef __clang__
90#pragma clang attribute pop
91#else
92#pragma GCC pop_options
93#endif
94#endif
95#endif // HAVE_VXSORT
96
97namespace hwy {
98
100
101static inline std::vector<Dist> AllDist() {
102 return {/*Dist::kUniform8, Dist::kUniform16,*/ Dist::kUniform32};
103}
104
105static inline const char* DistName(Dist dist) {
106 switch (dist) {
107 case Dist::kUniform8:
108 return "uniform8";
109 case Dist::kUniform16:
110 return "uniform16";
111 case Dist::kUniform32:
112 return "uniform32";
113 }
114 return "unreachable";
115}
116
117template <typename T>
119 public:
120 void Notify(T value) {
121 min_ = std::min(min_, value);
122 max_ = std::max(max_, value);
123 // Converting to integer would truncate floats, multiplying to save digits
124 // risks overflow especially when casting, so instead take the sum of the
125 // bit representations as the checksum.
126 uint64_t bits = 0;
127 static_assert(sizeof(T) <= 8, "Expected a built-in type");
128 CopyBytes<sizeof(T)>(&value, &bits); // not same size
129 sum_ += bits;
130 count_ += 1;
131 }
132
133 bool operator==(const InputStats& other) const {
134 if (count_ != other.count_) {
135 HWY_ABORT("count %d vs %d\n", static_cast<int>(count_),
136 static_cast<int>(other.count_));
137 }
138
139 if (min_ != other.min_ || max_ != other.max_) {
140 HWY_ABORT("minmax %f/%f vs %f/%f\n", static_cast<double>(min_),
141 static_cast<double>(max_), static_cast<double>(other.min_),
142 static_cast<double>(other.max_));
143 }
144
145 // Sum helps detect duplicated/lost values
146 if (sum_ != other.sum_) {
147 HWY_ABORT("Sum mismatch %g %g; min %g max %g\n",
148 static_cast<double>(sum_), static_cast<double>(other.sum_),
149 static_cast<double>(min_), static_cast<double>(max_));
150 }
151
152 return true;
153 }
154
155 private:
156 T min_ = hwy::HighestValue<T>();
157 T max_ = hwy::LowestValue<T>();
158 uint64_t sum_ = 0;
159 size_t count_ = 0;
160};
161
// Sorting implementations that can be benchmarked via Run(). Third-party
// entries exist only when the matching HAVE_* macro is enabled; the built-in
// implementations always come last. The underlying type is spelled out but is
// the same `int` that enum class uses by default.
enum class Algo : int {
#if HAVE_AVX2SORT
  kSEA,  // avx2::quicksort
#endif
#if HAVE_IPS4O
  kIPS4O,  // ips4o::sort
#endif
#if HAVE_PARALLEL_IPS4O
  kParallelIPS4O,  // ips4o::parallel::sort
#endif
#if HAVE_PDQSORT
  kPDQ,  // boost::sort::pdqsort_branchless
#endif
#if HAVE_SORT512
  kSort512,
#endif
#if HAVE_VXSORT
  kVXSort,
#endif
  kStd,     // std::sort
  kVQSort,  // per-thread Sorter
  kHeap,    // detail::HeapSort via CallHeapSort
};
185
186static inline const char* AlgoName(Algo algo) {
187 switch (algo) {
188#if HAVE_AVX2SORT
189 case Algo::kSEA:
190 return "sea";
191#endif
192#if HAVE_IPS4O
193 case Algo::kIPS4O:
194 return "ips4o";
195#endif
196#if HAVE_PARALLEL_IPS4O
197 case Algo::kParallelIPS4O:
198 return "par_ips4o";
199#endif
200#if HAVE_PDQSORT
201 case Algo::kPDQ:
202 return "pdq";
203#endif
204#if HAVE_SORT512
205 case Algo::kSort512:
206 return "sort512";
207#endif
208#if HAVE_VXSORT
209 case Algo::kVXSort:
210 return "vxsort";
211#endif
212 case Algo::kStd:
213 return "std";
214 case Algo::kVQSort:
215 return "vq";
216 case Algo::kHeap:
217 return "heap";
218 }
219 return "unreachable";
220}
221
222} // namespace hwy
223#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
224
225// Per-target
226#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == \
227 defined(HWY_TARGET_TOGGLE)
228#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
229#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
230#else
231#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
232#endif
233
236#include "hwy/contrib/sort/vqsort-inl.h" // HeapSort
238
240namespace hwy {
241namespace HWY_NAMESPACE {
242
244 static HWY_INLINE uint64_t SplitMix64(uint64_t z) {
245 z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
246 z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
247 return z ^ (z >> 31);
248 }
249
250 public:
251 // Generates two vectors of 64-bit seeds via SplitMix64 and stores into
252 // `seeds`. Generating these afresh in each ChoosePivot is too expensive.
253 template <class DU64>
254 static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) {
255 seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull);
256 for (size_t i = 1; i < 2 * Lanes(du64); ++i) {
257 seeds[i] = SplitMix64(seeds[i - 1]);
258 }
259 }
260
261 // Need to pass in the state because vector cannot be class members.
262 template <class VU64>
263 static VU64 RandomBits(VU64& state0, VU64& state1) {
264 VU64 s1 = state0;
265 VU64 s0 = state1;
266 const VU64 bits = Add(s1, s0);
267 state0 = s0;
268 s1 = Xor(s1, ShiftLeft<23>(s1));
269 state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0))));
270 return bits;
271 }
272};
273
274template <class D, class VU64, HWY_IF_NOT_FLOAT_D(D)>
275Vec<D> RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) {
276 const VU64 bits = Xorshift128Plus::RandomBits(s0, s1);
277 return BitCast(d, And(bits, mask));
278}
279
280// It is important to avoid denormals, which are flushed to zero by SIMD but not
281// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD.
282template <class DF, class VU64, HWY_IF_FLOAT_D(DF)>
283Vec<DF> RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) {
284 using TF = TFromD<DF>;
285 const RebindToUnsigned<decltype(df)> du;
286 using VU = Vec<decltype(du)>;
287
288 const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask);
289
290#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to smaller types
291 using TU = MakeUnsigned<TF>;
292 const VU bits = Set(du, static_cast<TU>(GetLane(bits64) & LimitsMax<TU>()));
293#else
294 const VU bits = BitCast(du, bits64);
295#endif
296 // Avoid NaN/denormal by only generating values in [1, 2), i.e. random
297 // mantissas with the exponent taken from the representation of 1.0.
298 const VU k1 = BitCast(du, Set(df, TF{1.0}));
299 const VU mantissa_mask = Set(du, MantissaMask<TF>());
300 const VU representation = OrAnd(k1, bits, mantissa_mask);
301 return BitCast(df, representation);
302}
303
304template <class DU64>
305Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) {
306 switch (sizeof_t) {
307 case 2:
308 return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull
309 : 0xFFFFFFFFFFFFFFFFull);
310 case 4:
311 return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull
312 : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull
313 : 0xFFFFFFFFFFFFFFFFull);
314 case 8:
315 return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull
316 : (dist == Dist::kUniform16) ? 0x000000000000FFFFull
317 : 0x00000000FFFFFFFFull);
318 default:
319 HWY_ABORT("Logic error");
320 return Zero(du64);
321 }
322}
323
324template <typename T>
325InputStats<T> GenerateInput(const Dist dist, T* v, size_t num) {
327 using VU64 = Vec<decltype(du64)>;
328 const size_t N64 = Lanes(du64);
329 auto seeds = hwy::AllocateAligned<uint64_t>(2 * N64);
330 Xorshift128Plus::GenerateSeeds(du64, seeds.get());
331 VU64 s0 = Load(du64, seeds.get());
332 VU64 s1 = Load(du64, seeds.get() + N64);
333
334#if HWY_TARGET == HWY_SCALAR
335 const Sisd<T> d;
336#else
337 const Repartition<T, decltype(du64)> d;
338#endif
339 using V = Vec<decltype(d)>;
340 const size_t N = Lanes(d);
341 const VU64 mask = MaskForDist(du64, dist, sizeof(T));
342 auto buf = hwy::AllocateAligned<T>(N);
343
344 size_t i = 0;
345 for (; i + N <= num; i += N) {
346 const V values = RandomValues(d, s0, s1, mask);
347 StoreU(values, d, v + i);
348 }
349 if (i < num) {
350 const V values = RandomValues(d, s0, s1, mask);
351 StoreU(values, d, buf.get());
352 memcpy(v + i, buf.get(), (num - i) * sizeof(T));
353 }
354
355 InputStats<T> input_stats;
356 for (size_t i = 0; i < num; ++i) {
357 input_stats.Notify(v[i]);
358 }
359 return input_stats;
360}
361
365
367#if HAVE_PARALLEL_IPS4O
368 const unsigned max_threads = hwy::LimitsMax<unsigned>(); // 16 for Table 1a
369 ips4o::StdThreadPool pool{static_cast<int>(
370 HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))};
371#endif
372 std::vector<ThreadLocal> tls{1};
373};
374
375// Bridge from keys (passed to Run) to lanes as expected by HeapSort. For
376// non-128-bit keys they are the same:
377template <class Order, typename KeyType, HWY_IF_NOT_LANE_SIZE(KeyType, 16)>
378void CallHeapSort(KeyType* HWY_RESTRICT keys, const size_t num_keys) {
379 using detail::TraitsLane;
381 if (Order().IsAscending()) {
382 const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st;
383 return detail::HeapSort(st, keys, num_keys);
384 } else {
385 const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st;
386 return detail::HeapSort(st, keys, num_keys);
387 }
388}
389
390#if VQSORT_ENABLED
391template <class Order>
392void CallHeapSort(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys) {
393 using detail::SharedTraits;
394 using detail::Traits128;
395 uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
396 const size_t num_lanes = num_keys * 2;
397 if (Order().IsAscending()) {
398 const SharedTraits<Traits128<detail::OrderAscending128>> st;
399 return detail::HeapSort(st, lanes, num_lanes);
400 } else {
401 const SharedTraits<Traits128<detail::OrderDescending128>> st;
402 return detail::HeapSort(st, lanes, num_lanes);
403 }
404}
405
406template <class Order>
407void CallHeapSort(K64V64* HWY_RESTRICT keys, const size_t num_keys) {
408 using detail::SharedTraits;
409 using detail::Traits128;
410 uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
411 const size_t num_lanes = num_keys * 2;
412 if (Order().IsAscending()) {
413 const SharedTraits<Traits128<detail::OrderAscendingKV128>> st;
414 return detail::HeapSort(st, lanes, num_lanes);
415 } else {
416 const SharedTraits<Traits128<detail::OrderDescendingKV128>> st;
417 return detail::HeapSort(st, lanes, num_lanes);
418 }
419}
420#endif // VQSORT_ENABLED
421
422template <class Order, typename KeyType>
423void Run(Algo algo, KeyType* HWY_RESTRICT inout, size_t num,
424 SharedState& shared, size_t thread) {
425 const std::less<KeyType> less;
426 const std::greater<KeyType> greater;
427
428 switch (algo) {
429#if HAVE_AVX2SORT
430 case Algo::kSEA:
431 return avx2::quicksort(inout, static_cast<int>(num));
432#endif
433
434#if HAVE_IPS4O
435 case Algo::kIPS4O:
436 if (Order().IsAscending()) {
437 return ips4o::sort(inout, inout + num, less);
438 } else {
439 return ips4o::sort(inout, inout + num, greater);
440 }
441#endif
442
443#if HAVE_PARALLEL_IPS4O
444 case Algo::kParallelIPS4O:
445 if (Order().IsAscending()) {
446 return ips4o::parallel::sort(inout, inout + num, less, shared.pool);
447 } else {
448 return ips4o::parallel::sort(inout, inout + num, greater, shared.pool);
449 }
450#endif
451
452#if HAVE_SORT512
453 case Algo::kSort512:
454 HWY_ABORT("not supported");
455 // return Sort512::Sort(inout, num);
456#endif
457
458#if HAVE_PDQSORT
459 case Algo::kPDQ:
460 if (Order().IsAscending()) {
461 return boost::sort::pdqsort_branchless(inout, inout + num, less);
462 } else {
463 return boost::sort::pdqsort_branchless(inout, inout + num, greater);
464 }
465#endif
466
467#if HAVE_VXSORT
468 case Algo::kVXSort: {
469#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \
470 (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2)
471 fprintf(stderr, "Do not call for target %s\n",
473 return;
474#else
475#if VXSORT_AVX3
476 vxsort::vxsort<KeyType, vxsort::AVX512> vx;
477#else
478 vxsort::vxsort<KeyType, vxsort::AVX2> vx;
479#endif
480 if (Order().IsAscending()) {
481 return vx.sort(inout, inout + num - 1);
482 } else {
483 fprintf(stderr, "Skipping VX - does not support descending order\n");
484 return;
485 }
486#endif // enabled for this target
487 }
488#endif // HAVE_VXSORT
489
490 case Algo::kStd:
491 if (Order().IsAscending()) {
492 return std::sort(inout, inout + num, less);
493 } else {
494 return std::sort(inout, inout + num, greater);
495 }
496
497 case Algo::kVQSort:
498 return shared.tls[thread].sorter(inout, num, Order());
499
500 case Algo::kHeap:
501 return CallHeapSort<Order>(inout, num);
502
503 default:
504 HWY_ABORT("Not implemented");
505 }
506}
507
508// NOLINTNEXTLINE(google-readability-namespace-comments)
509} // namespace HWY_NAMESPACE
510} // namespace hwy
512
513#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()
#define HWY_RESTRICT
Definition base.h:64
#define HWY_POP_ATTRIBUTES
Definition base.h:123
#define HWY_MIN(a, b)
Definition base.h:134
#define HWY_ABORT(format,...)
Definition base.h:188
#define HWY_INLINE
Definition base.h:70
#define HWY_PUSH_ATTRIBUTES(targets_str)
Definition base.h:122
Definition algo-inl.h:243
static void GenerateSeeds(DU64 du64, TFromD< DU64 > *HWY_RESTRICT seeds)
Definition algo-inl.h:254
static HWY_INLINE uint64_t SplitMix64(uint64_t z)
Definition algo-inl.h:244
static VU64 RandomBits(VU64 &state0, VU64 &state1)
Definition algo-inl.h:263
Definition algo-inl.h:118
T min_
Definition algo-inl.h:156
size_t count_
Definition algo-inl.h:159
T max_
Definition algo-inl.h:157
bool operator==(const InputStats &other) const
Definition algo-inl.h:133
void Notify(T value)
Definition algo-inl.h:120
uint64_t sum_
Definition algo-inl.h:158
Definition vqsort.h:41
#define HWY_TARGET
Definition detect_targets.h:380
void HeapSort(Traits st, T *HWY_RESTRICT lanes, const size_t num_lanes)
Definition vqsort-inl.h:127
d
Definition rvv-inl.h:1998
InputStats< T > GenerateInput(const Dist dist, T *v, size_t num)
Definition algo-inl.h:325
void CallHeapSort(KeyType *HWY_RESTRICT keys, const size_t num_keys)
Definition algo-inl.h:378
void Run(Algo algo, KeyType *HWY_RESTRICT inout, size_t num, SharedState &shared, size_t thread)
Definition algo-inl.h:423
HWY_API Vec128< T, N > And(const Vec128< T, N > a, const Vec128< T, N > b)
Definition arm_neon-inl.h:1949
Rebind< MakeUnsigned< TFromD< D > >, D > RebindToUnsigned
Definition ops/shared-inl.h:212
HWY_API constexpr size_t Lanes(Simd< T, N, kPow2 >)
Definition arm_sve-inl.h:243
HWY_API Vec128< T, N > Load(Simd< T, N, 0 > d, const T *HWY_RESTRICT p)
Definition arm_neon-inl.h:2753
Vec< DU64 > MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t)
Definition algo-inl.h:305
HWY_API Vec128< T, N > Xor(const Vec128< T, N > a, const Vec128< T, N > b)
Definition arm_neon-inl.h:1998
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition arm_neon-inl.h:2772
svuint16_t Set(Simd< bfloat16_t, N, kPow2 > d, bfloat16_t arg)
Definition arm_sve-inl.h:322
HWY_API Vec128< T, N > OrAnd(Vec128< T, N > o, Vec128< T, N > a1, Vec128< T, N > a2)
Definition arm_neon-inl.h:2040
HWY_API Vec128< T, N > BitCast(Simd< T, N, 0 > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition arm_neon-inl.h:997
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition arm_neon-inl.h:1020
HWY_API TFromV< V > GetLane(const V v)
Definition arm_neon-inl.h:1076
typename D::template Repartition< T > Repartition
Definition ops/shared-inl.h:218
N
Definition rvv-inl.h:1998
ScalableTag< T, -1 > SortTag
Definition contrib/sort/shared-inl.h:124
Vec< D > RandomValues(D d, VU64 &s0, VU64 &s1, const VU64 mask)
Definition algo-inl.h:275
const vfloat64m1_t v
Definition rvv-inl.h:1998
typename D::T TFromD
Definition ops/shared-inl.h:203
decltype(Zero(D())) Vec
Definition generic_ops-inl.h:40
Definition aligned_allocator.h:27
static const char * DistName(Dist dist)
Definition algo-inl.h:105
Dist
Definition algo-inl.h:99
static std::vector< Dist > AllDist()
Definition algo-inl.h:101
static const char * AlgoName(Algo algo)
Definition algo-inl.h:186
static HWY_MAYBE_UNUSED const char * TargetName(int64_t target)
Definition targets.h:85
Algo
Definition algo-inl.h:162
typename detail::Relations< T >::Unsigned MakeUnsigned
Definition base.h:593
#define HWY_NAMESPACE
Definition set_macros-inl.h:82
Definition algo-inl.h:366
std::vector< ThreadLocal > tls
Definition algo-inl.h:372
Definition ops/shared-inl.h:52
Definition algo-inl.h:362
Sorter sorter
Definition algo-inl.h:363
Definition sorting_networks-inl.h:698
Definition traits-inl.h:545
Definition base.h:309