random: add RandomMixin::randbits with compile-known bits

In many cases, it is known at compile time how many bits are requested from
randbits. Provide a variant of randbits that accepts this number as a template,
to make sure the compiler can make use of this knowledge. This is used immediately
in rand32() and randbool(), and a few further call sites.
This commit is contained in:
Pieter Wuille 2024-03-10 12:38:14 -04:00
parent 21ce9d8658
commit ddb7d26cfd
6 changed files with 52 additions and 9 deletions

View File

@ -776,7 +776,7 @@ std::pair<CAddress, NodeSeconds> AddrManImpl::Select_(bool new_only, std::option
const AddrInfo& info{it_found->second}; const AddrInfo& info{it_found->second};
// With probability GetChance() * chance_factor, return the entry. // With probability GetChance() * chance_factor, return the entry.
if (insecure_rand.randbits(30) < chance_factor * info.GetChance() * (1 << 30)) { if (insecure_rand.randbits<30>() < chance_factor * info.GetChance() * (1 << 30)) {
LogPrint(BCLog::ADDRMAN, "Selected %s from %s\n", info.ToStringAddrPort(), search_tried ? "tried" : "new"); LogPrint(BCLog::ADDRMAN, "Selected %s from %s\n", info.ToStringAddrPort(), search_tried ? "tried" : "new");
return {info, info.m_last_try}; return {info, info.m_last_try};
} }

View File

@ -741,6 +741,6 @@ void RandomInit()
std::chrono::microseconds GetExponentialRand(std::chrono::microseconds now, std::chrono::seconds average_interval) std::chrono::microseconds GetExponentialRand(std::chrono::microseconds now, std::chrono::seconds average_interval)
{ {
double unscaled = -std::log1p(GetRand(uint64_t{1} << 48) * -0.0000000000000035527136788 /* -1/2^48 */); double unscaled = -std::log1p(FastRandomContext().randbits<48>() * -0.0000000000000035527136788 /* -1/2^48 */);
return now + std::chrono::duration_cast<std::chrono::microseconds>(unscaled * average_interval + 0.5us); return now + std::chrono::duration_cast<std::chrono::microseconds>(unscaled * average_interval + 0.5us);
} }

View File

@ -223,6 +223,30 @@ public:
return ret & ((uint64_t{1} << bits) - 1); return ret & ((uint64_t{1} << bits) - 1);
} }
/** Same as above, but with compile-time fixed bits count. */
template<int Bits>
uint64_t randbits() noexcept
{
static_assert(Bits >= 0 && Bits <= 64);
if constexpr (Bits == 64) {
return Impl().rand64();
} else {
uint64_t ret;
if (Bits <= bitbuf_size) {
ret = bitbuf;
bitbuf >>= Bits;
bitbuf_size -= Bits;
} else {
uint64_t gen = Impl().rand64();
ret = (gen << bitbuf_size) | bitbuf;
bitbuf = gen >> (Bits - bitbuf_size);
bitbuf_size = 64 + bitbuf_size - Bits;
}
constexpr uint64_t MASK = (uint64_t{1} << Bits) - 1;
return ret & MASK;
}
}
/** Generate a random integer in the range [0..range). /** Generate a random integer in the range [0..range).
* Precondition: range > 0. * Precondition: range > 0.
*/ */
@ -247,7 +271,7 @@ public:
} }
/** Generate a random 32-bit integer. */ /** Generate a random 32-bit integer. */
uint32_t rand32() noexcept { return Impl().randbits(32); } uint32_t rand32() noexcept { return Impl().template randbits<32>(); }
/** generate a random uint256. */ /** generate a random uint256. */
uint256 rand256() noexcept uint256 rand256() noexcept
@ -258,7 +282,7 @@ public:
} }
/** Generate a random boolean. */ /** Generate a random boolean. */
bool randbool() noexcept { return Impl().randbits(1); } bool randbool() noexcept { return Impl().template randbits<1>(); }
/** Return the time point advanced by a uniform random duration. */ /** Return the time point advanced by a uniform random duration. */
template <typename Tp> template <typename Tp>

View File

@ -1195,7 +1195,7 @@ BOOST_AUTO_TEST_CASE(muhash_tests)
uint256 res; uint256 res;
int table[4]; int table[4];
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
table[i] = g_insecure_rand_ctx.randbits(3); table[i] = g_insecure_rand_ctx.randbits<3>();
} }
for (int order = 0; order < 4; ++order) { for (int order = 0; order < 4; ++order) {
MuHash3072 acc; MuHash3072 acc;
@ -1215,8 +1215,8 @@ BOOST_AUTO_TEST_CASE(muhash_tests)
} }
} }
MuHash3072 x = FromInt(g_insecure_rand_ctx.randbits(4)); // x=X MuHash3072 x = FromInt(g_insecure_rand_ctx.randbits<4>()); // x=X
MuHash3072 y = FromInt(g_insecure_rand_ctx.randbits(4)); // x=X, y=Y MuHash3072 y = FromInt(g_insecure_rand_ctx.randbits<4>()); // x=X, y=Y
MuHash3072 z; // x=X, y=Y, z=1 MuHash3072 z; // x=X, y=Y, z=1
z *= x; // x=X, y=Y, z=X z *= x; // x=X, y=Y, z=X
z *= y; // x=X, y=Y, z=X*Y z *= y; // x=X, y=Y, z=X*Y

View File

@ -107,7 +107,7 @@ BOOST_AUTO_TEST_CASE(fastrandom_randbits)
BOOST_AUTO_TEST_CASE(randbits_test) BOOST_AUTO_TEST_CASE(randbits_test)
{ {
FastRandomContext ctx_lens; //!< RNG for producing the lengths requested from ctx_test. FastRandomContext ctx_lens; //!< RNG for producing the lengths requested from ctx_test.
FastRandomContext ctx_test; //!< The RNG being tested. FastRandomContext ctx_test1(true), ctx_test2(true); //!< The RNGs being tested.
int ctx_test_bitsleft{0}; //!< (Assumed value of) ctx_test::bitbuf_len int ctx_test_bitsleft{0}; //!< (Assumed value of) ctx_test::bitbuf_len
// Run the entire test 5 times. // Run the entire test 5 times.
@ -122,7 +122,25 @@ BOOST_AUTO_TEST_CASE(randbits_test)
// Decide on a number of bits to request (0 through 64, inclusive; don't use randbits/randrange). // Decide on a number of bits to request (0 through 64, inclusive; don't use randbits/randrange).
int bits = ctx_lens.rand64() % 65; int bits = ctx_lens.rand64() % 65;
// Generate that many bits. // Generate that many bits.
uint64_t gen = ctx_test.randbits(bits); uint64_t gen = ctx_test1.randbits(bits);
// For certain bits counts, also test randbits<Bits> and compare.
uint64_t gen2;
if (bits == 0) {
gen2 = ctx_test2.randbits<0>();
} else if (bits == 1) {
gen2 = ctx_test2.randbits<1>();
} else if (bits == 7) {
gen2 = ctx_test2.randbits<7>();
} else if (bits == 32) {
gen2 = ctx_test2.randbits<32>();
} else if (bits == 51) {
gen2 = ctx_test2.randbits<51>();
} else if (bits == 64) {
gen2 = ctx_test2.randbits<64>();
} else {
gen2 = ctx_test2.randbits(bits);
}
BOOST_CHECK_EQUAL(gen, gen2);
// Make sure the result is in range. // Make sure the result is in range.
if (bits < 64) BOOST_CHECK_EQUAL(gen >> bits, 0); if (bits < 64) BOOST_CHECK_EQUAL(gen >> bits, 0);
// Mark all the seen bits in the output. // Mark all the seen bits in the output.

View File

@ -77,3 +77,4 @@ shift-base:streams.h
shift-base:FormatHDKeypath shift-base:FormatHDKeypath
shift-base:xoroshiro128plusplus.h shift-base:xoroshiro128plusplus.h
shift-base:RandomMixin<*>::randbits shift-base:RandomMixin<*>::randbits
shift-base:RandomMixin<*>::randbits<*>