← comparison · math_jngen.h
1#pragma once23// This header is named 'math_jngen.h' and not 'math.h' because in the latter4// case it will replace the standard 'math.h' if you set jngen folder as the5// include path.67#include "array.h"8#include "common.h"9#include "random.h"1011#include <algorithm>12#include <cmath>13#include <iterator>14#include <limits>15#include <type_traits>16#include <unordered_set>17#include <vector>1819namespace jngen {2021namespace detail {2223inline int multiply(int x, int y, int mod) {24    return static_cast<long long>(x) * y % mod;25}2627inline long long multiply(long long x, long long y, long long mod) {28#if defined(__SIZEOF_INT128__)29    return static_cast<__int128>(x) * y % mod;30#else31    long long res = 0;32    while (y) {33        if (y&1) {34            res = (static_cast<unsigned long long>(res) + x) % mod;35        }36        x = (static_cast<unsigned long long>(x) + x) % mod;37        y >>= 1;38    }39    return res;40#endif41}4243inline int power(int x, int k, int mod) {44    int res = 1;45    while (k) {46        if (k&1) {47            res = multiply(res, x, mod);48        }49        x = multiply(x, x, mod);50        k >>= 1;51    }52    return res;53}5455inline long long power(long long x, long long k, long long mod) {56    long long res = 1;57    while (k) {58        if (k&1) {59            res = multiply(res, x, mod);60        }61        x = multiply(x, x, mod);62        k >>= 1;63    }64    return res;65}6667template<typename I>68bool millerRabinTest(I n, const std::vector<I>& witnesses) {69    static_assert(70        std::is_same<I, int>::value || std::is_same<I, long long>::value,71        "millerRabinTest<int/long long> only is supported");7273    if (n == 1) {74        return false;75    }7677    constexpr int LIMIT = 10000;7879    if (n <= LIMIT) {80        for (int i = 2; i*i <= n; ++i) {81            if (n%i == 0) {82                return false;83            }84        }85        return true;86    }8788    int r = 0;89    I d = n - 1;90    while (d % 2 == 0) {91        ++r;92        d /= 2;93    }9495    for (I a: witnesses) {96        I x = power(a, d, n);97        if (x == 1 || x == n - 1) {98            continue;99        }100101        bool composite = true;102        for (int i = 0; i < r - 1; ++i) {103            x = multiply(x, x, n);104            if (x == 1) {105                return false;106            }107            if (x == n - 1) {108                i = r;109                composite = false;110                continue;111            }112        }113        if (composite) {114            return false;115        }116    }117    return true;118}119120} // namespace detail121122inline bool isPrime(long long n) {123    const static std::vector<int> INT_WITNESSES{2, 7, 61};124    const static std::vector<long long> LONG_LONG_WITNESSES125        {2, 3, 5, 7, 11, 13, 17, 19, 23};126    // todo: experiment with base127    // 2, 325, 9375, 28178, 450775, 9780504, and 1795265022128    // (guaranteed for all integers < 2^64)129130    // first strong pseudoprime to i64 bases is 3825123056546413051 ~= 3.8e18131    ensure(n > 0, "isPrime() is undefined for negative numbers");132    ensure(133        n <= static_cast<long long>(3.8e18),134        "isPrime() supports only numbers not greater than 3.8 * 10^18");135136    if (n < std::numeric_limits<int>::max()) {137        return detail::millerRabinTest<int>(n, INT_WITNESSES);138    } else {139        return detail::millerRabinTest<long long>(n, LONG_LONG_WITNESSES);140    }141}142143class MathRandom {144public:145    MathRandom() {146        static bool created = false;147        ensure(!created, "jngen::MathRandom should be created only once");148        created = true;149    }150151    static long long randomPrime(long long n) {152        ensure(n > 2, format("There are no primes below %lld", n));153        return randomPrime(2, n - 1);154    }155156    static long long randomPrime(long long l, long long r) {157        ensure(l <= r);158        std::unordered_set<long long> used;159        while (static_cast<long long>(used.size()) < r - l + 1) {160            long long x = rnd.next(l, r);161            if (used.count(x)) {162                continue;163            }164            used.insert(x);165            if (isPrime(x)) {166                return x;167            }168        }169        ensure(170            false,171            format(172                "There are no primes between %lld and %lld",173                l, r)174        );175    }176177    static long long nextPrime(long long n) {178        while (!isPrime(n)) {179            ++n;180        }181        return n;182    }183184    static long long previousPrime(long long n) {185        ensure(n >= 2, format("There are no primes less or equal to %lld", n));186        while (!isPrime(n)) {187            --n;188        }189        return n;190    }191192    static Array partition(193            int n,194            int numParts,195            int minSize = 0,196            int maxSize = -1)197    {198        auto res = partition(199            static_cast<long long>(n),200            numParts,201            static_cast<long long>(minSize),202            static_cast<long long>(maxSize));203        return Array(res.begin(), res.end());204    }205206    static Array64 partition(207            long long n,208            int numParts,209            long long minSize = 0,210            long long maxSize = -1)211    {212        if (maxSize == -1) {213            maxSize = n;214        }215216        ensure(n >= 0);217        ensure(numParts >= 0);218        ensure(numParts * minSize <= n, "minSize is too large");219        ensure(numParts * maxSize >= n, "maxSize is too small");220        ensure(minSize <= maxSize);221222        n -= minSize * numParts;223224        auto delimiters = Array64::random(225                numParts - 1, 0, n).sorted();226        delimiters.insert(delimiters.begin(), 0);227        delimiters.push_back(n);228229        Array64 partition(numParts);230        for (long long i = 0; i < numParts; ++i) {231            partition[i] = delimiters[i + 1] - delimiters[i];232        }233        partition.sort().reverse();234235        long long remaining = 0;236237        long long localMax = maxSize - minSize;238        for (auto& x: partition) {239            if (x > localMax) {240                remaining += x - localMax;241                x = localMax;242            }243244            x += minSize;245        }246247        // Here we try to distribute the remaining part in some even manner248        // between remaining slots. Looks like crap anyway, need a smarter way.249250        for (int divisor: { 2, 1 }) {251            partition.shuffle();252            for (auto& x: partition) {253                if (x < maxSize) {254                    long long add = std::min(255                            remaining, (maxSize - x) / divisor);256                    x += add;257                    remaining -= add;258                }259            }260        }261262        ensure(remaining == 0, "maxSize is too small");263264        return partition;265    }266267    template<typename T>268    TArray<TArray<T>> partition(269            TArray<T> elements,270            int numParts,271            int minSize = 0,272            int maxSize = -1)273    {274        return partition(275            std::move(elements),276            partition(277                static_cast<int>(elements.size()),278                numParts,279                minSize,280                maxSize));281    }282283    template<typename T>284    TArray<TArray<T>> partition(TArray<T> elements, const Array& sizes) {285        size_t total = std::accumulate(sizes.begin(), sizes.end(), size_t(0));286        ensure(total == elements.size(), "sum(sizes) != elements.size()");287        elements.shuffle();288        TArray<TArray<T>> res;289        auto it = elements.begin();290        for (int size: sizes) {291            res.emplace_back();292            std::copy(it, it + size, std::back_inserter(res.back()));293            it += size;294        }295296        return res;297    }298};299300JNGEN_EXTERN MathRandom rndm;301302} // namespace jngen303304using jngen::isPrime;305306using jngen::rndm;