1#pragma once23// This header is named 'math_jngen.h' and not 'math.h' because in the latter4// case it will replace the standard 'math.h' if you set jngen folder as the5// include path.67#include "array.h"8#include "common.h"9#include "random.h"1011#include <algorithm>12#include <cmath>13#include <iterator>14#include <limits>15#include <type_traits>16#include <unordered_set>17#include <vector>1819namespace jngen {2021namespace detail {2223inline int multiply(int x, int y, int mod) {24 return static_cast<long long>(x) * y % mod;25}2627inline long long multiply(long long x, long long y, long long mod) {28#if defined(__SIZEOF_INT128__)29 return static_cast<__int128>(x) * y % mod;30#else31 long long res = 0;32 while (y) {33 if (y&1) {34 res = (static_cast<unsigned long long>(res) + x) % mod;35 }36 x = (static_cast<unsigned long long>(x) + x) % mod;37 y >>= 1;38 }39 return res;40#endif41}4243inline int power(int x, int k, int mod) {44 int res = 1;45 while (k) {46 if (k&1) {47 res = multiply(res, x, mod);48 }49 x = multiply(x, x, mod);50 k >>= 1;51 }52 return res;53}5455inline long long power(long long x, long long k, long long mod) {56 long long res = 1;57 while (k) {58 if (k&1) {59 res = multiply(res, x, mod);60 }61 x = multiply(x, x, mod);62 k >>= 1;63 }64 return res;65}6667template<typename I>68bool millerRabinTest(I n, const std::vector<I>& witnesses) {69 static_assert(70 std::is_same<I, int>::value || std::is_same<I, long long>::value,71 "millerRabinTest<int/long long> only is supported");7273 if (n == 1) {74 return false;75 }7677 constexpr int LIMIT = 10000;7879 if (n <= LIMIT) {80 for (int i = 2; i*i <= n; ++i) {81 if (n%i == 0) {82 return false;83 }84 }85 return true;86 }8788 int r = 0;89 I d = n - 1;90 while (d % 2 == 0) {91 ++r;92 d /= 2;93 }9495 for (I a: witnesses) {96 I x = power(a, d, n);97 if (x == 1 || x == n - 1) {98 continue;99 }100101 bool composite = true;102 for (int i = 0; i < r - 1; ++i) {103 x = multiply(x, x, n);104 if (x == 1) {105 return false;106 }107 if (x == n - 1) {108 i = r;109 composite = false;110 continue;111 }112 }113 if (composite) {114 return false;115 }116 }117 return true;118}119120} // namespace detail121122inline bool isPrime(long long n) {123 const static std::vector<int> INT_WITNESSES{2, 7, 61};124 const static std::vector<long long> LONG_LONG_WITNESSES125 {2, 3, 5, 7, 11, 13, 17, 19, 23};126 // todo: experiment with base127 // 2, 325, 9375, 28178, 450775, 9780504, and 1795265022128 // (guaranteed for all integers < 2^64)129130 // first strong pseudoprime to i64 bases is 3825123056546413051 ~= 3.8e18131 ensure(n > 0, "isPrime() is undefined for negative numbers");132 ensure(133 n <= static_cast<long long>(3.8e18),134 "isPrime() supports only numbers not greater than 3.8 * 10^18");135136 if (n < std::numeric_limits<int>::max()) {137 return detail::millerRabinTest<int>(n, INT_WITNESSES);138 } else {139 return detail::millerRabinTest<long long>(n, LONG_LONG_WITNESSES);140 }141}142143class MathRandom {144public:145 MathRandom() {146 static bool created = false;147 ensure(!created, "jngen::MathRandom should be created only once");148 created = true;149 }150151 static long long randomPrime(long long n) {152 ensure(n > 2, format("There are no primes below %lld", n));153 return randomPrime(2, n - 1);154 }155156 static long long randomPrime(long long l, long long r) {157 ensure(l <= r);158 std::unordered_set<long long> used;159 while (static_cast<long long>(used.size()) < r - l + 1) {160 long long x = rnd.next(l, r);161 if (used.count(x)) {162 continue;163 }164 used.insert(x);165 if (isPrime(x)) {166 return x;167 }168 }169 ensure(170 false,171 format(172 "There are no primes between %lld and %lld",173 l, r)174 );175 }176177 static long long nextPrime(long long n) {178 while (!isPrime(n)) {179 ++n;180 }181 return n;182 }183184 static long long previousPrime(long long n) {185 ensure(n >= 2, format("There are no primes less or equal to %lld", n));186 while (!isPrime(n)) {187 --n;188 }189 return n;190 }191192 static Array partition(193 int n,194 int numParts,195 int minSize = 0,196 int maxSize = -1)197 {198 auto res = partition(199 static_cast<long long>(n),200 numParts,201 static_cast<long long>(minSize),202 static_cast<long long>(maxSize));203 return Array(res.begin(), res.end());204 }205206 static Array64 partition(207 long long n,208 int numParts,209 long long minSize = 0,210 long long maxSize = -1)211 {212 if (maxSize == -1) {213 maxSize = n;214 }215216 ensure(n >= 0);217 ensure(numParts >= 0);218 ensure(numParts * minSize <= n, "minSize is too large");219 ensure(numParts * maxSize >= n, "maxSize is too small");220 ensure(minSize <= maxSize);221222 n -= minSize * numParts;223224 auto delimiters = Array64::random(225 numParts - 1, 0, n).sorted();226 delimiters.insert(delimiters.begin(), 0);227 delimiters.push_back(n);228229 Array64 partition(numParts);230 for (long long i = 0; i < numParts; ++i) {231 partition[i] = delimiters[i + 1] - delimiters[i];232 }233 partition.sort().reverse();234235 long long remaining = 0;236237 long long localMax = maxSize - minSize;238 for (auto& x: partition) {239 if (x > localMax) {240 remaining += x - localMax;241 x = localMax;242 }243244 x += minSize;245 }246247 // Here we try to distribute the remaining part in some even manner248 // between remaining slots. Looks like crap anyway, need a smarter way.249250 for (int divisor: { 2, 1 }) {251 partition.shuffle();252 for (auto& x: partition) {253 if (x < maxSize) {254 long long add = std::min(255 remaining, (maxSize - x) / divisor);256 x += add;257 remaining -= add;258 }259 }260 }261262 ensure(remaining == 0, "maxSize is too small");263264 return partition;265 }266267 template<typename T>268 TArray<TArray<T>> partition(269 TArray<T> elements,270 int numParts,271 int minSize = 0,272 int maxSize = -1)273 {274 return partition(275 std::move(elements),276 partition(277 static_cast<int>(elements.size()),278 numParts,279 minSize,280 maxSize));281 }282283 template<typename T>284 TArray<TArray<T>> partition(TArray<T> elements, const Array& sizes) {285 size_t total = std::accumulate(sizes.begin(), sizes.end(), size_t(0));286 ensure(total == elements.size(), "sum(sizes) != elements.size()");287 elements.shuffle();288 TArray<TArray<T>> res;289 auto it = elements.begin();290 for (int size: sizes) {291 res.emplace_back();292 std::copy(it, it + size, std::back_inserter(res.back()));293 it += size;294 }295296 return res;297 }298};299300JNGEN_EXTERN MathRandom rndm;301302} // namespace jngen303304using jngen::isPrime;305306using jngen::rndm;