From 25a3b9b0f52e61e0189d6e7e727a0ffd2b1e39fa Mon Sep 17 00:00:00 2001 From: Ava Chow Date: Mon, 22 Jan 2024 17:07:50 -0500 Subject: [PATCH] descriptors: Have GetPubKey fill origins directly Instead of having ExpandHelper fill in the origins in the FlatSigningProvider output, have GetPubKey do it by itself. This reduces the extra variables needed in order to track and set origins in ExpandHelper. Also changes GetPubKey to return a std::optional rather than using a bool and output parameters. --- src/script/descriptor.cpp | 72 ++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/src/script/descriptor.cpp b/src/script/descriptor.cpp index a0caf4149c..3abd85dd7f 100644 --- a/src/script/descriptor.cpp +++ b/src/script/descriptor.cpp @@ -174,22 +174,20 @@ public: * Used by the Miniscript descriptors to check for duplicate keys in the script. */ bool operator<(PubkeyProvider& other) const { - CPubKey a, b; - SigningProvider dummy; - KeyOriginInfo dummy_info; + FlatSigningProvider dummy; - GetPubKey(0, dummy, a, dummy_info); - other.GetPubKey(0, dummy, b, dummy_info); + std::optional a = GetPubKey(0, dummy, dummy); + std::optional b = other.GetPubKey(0, dummy, dummy); return a < b; } - /** Derive a public key. + /** Derive a public key and put it into out. * read_cache is the cache to read keys from (if not nullptr) * write_cache is the cache to write keys to (if not nullptr) * Caches are not exclusive but this is not tested. Currently we use them exclusively */ - virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0; + virtual std::optional GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0; /** Whether this represent multiple public keys at different positions. */ virtual bool IsRange() const = 0; @@ -240,12 +238,15 @@ class OriginPubkeyProvider final : public PubkeyProvider public: OriginPubkeyProvider(uint32_t exp_index, KeyOriginInfo info, std::unique_ptr provider, bool apostrophe) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)), m_apostrophe(apostrophe) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override + std::optional GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { - if (!m_provider->GetPubKey(pos, arg, key, info, read_cache, write_cache)) return false; - std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), info.fingerprint); - info.path.insert(info.path.begin(), m_origin.path.begin(), m_origin.path.end()); - return true; + std::optional pub = m_provider->GetPubKey(pos, arg, out, read_cache, write_cache); + if (!pub) return std::nullopt; + auto& [pubkey, suborigin] = out.origins[pub->GetID()]; + Assert(pubkey == *pub); // m_provider must have a valid origin by this point. + std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), suborigin.fingerprint); + suborigin.path.insert(suborigin.path.begin(), m_origin.path.begin(), m_origin.path.end()); + return pub; } bool IsRange() const override { return m_provider->IsRange(); } size_t GetSize() const override { return m_provider->GetSize(); } @@ -298,13 +299,13 @@ class ConstPubkeyProvider final : public PubkeyProvider public: ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey, bool xonly) : PubkeyProvider(exp_index), m_pubkey(pubkey), m_xonly(xonly) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override + std::optional GetPubKey(int pos, const SigningProvider&, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { - key = m_pubkey; - info.path.clear(); + KeyOriginInfo info; CKeyID keyid = m_pubkey.GetID(); std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint); - return true; + out.origins.emplace(keyid, std::make_pair(m_pubkey, info)); + return m_pubkey; } bool IsRange() const override { return false; } size_t GetSize() const override { return m_pubkey.size(); } @@ -394,7 +395,7 @@ public: BIP32PubkeyProvider(uint32_t exp_index, const CExtPubKey& extkey, KeyPath path, DeriveType derive, bool apostrophe) : PubkeyProvider(exp_index), m_root_extkey(extkey), m_path(std::move(path)), m_derive(derive), m_apostrophe(apostrophe) {} bool IsRange() const override { return m_derive != DeriveType::NO; } size_t GetSize() const override { return 33; } - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key_out, KeyOriginInfo& final_info_out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override + std::optional GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { KeyOriginInfo info; CKeyID keyid = m_root_extkey.pubkey.GetID(); @@ -410,16 +411,16 @@ public: bool der = true; if (read_cache) { if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, final_extkey)) { - if (m_derive == DeriveType::HARDENED) return false; + if (m_derive == DeriveType::HARDENED) return std::nullopt; // Try to get the derivation parent - if (!read_cache->GetCachedParentExtPubKey(m_expr_index, parent_extkey)) return false; + if (!read_cache->GetCachedParentExtPubKey(m_expr_index, parent_extkey)) return std::nullopt; final_extkey = parent_extkey; if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); } } else if (IsHardened()) { CExtKey xprv; CExtKey lh_xprv; - if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return false; + if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return std::nullopt; parent_extkey = xprv.Neuter(); if (m_derive == DeriveType::UNHARDENED) der = xprv.Derive(xprv, pos); if (m_derive == DeriveType::HARDENED) der = xprv.Derive(xprv, pos | 0x80000000UL); @@ -429,16 +430,15 @@ public: } } else { for (auto entry : m_path) { - if (!parent_extkey.Derive(parent_extkey, entry)) return false; + if (!parent_extkey.Derive(parent_extkey, entry)) return std::nullopt; } final_extkey = parent_extkey; if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); assert(m_derive != DeriveType::HARDENED); } - if (!der) return false; + if (!der) return std::nullopt; - final_info_out = info; - key_out = final_extkey.pubkey; + out.origins.emplace(final_extkey.pubkey.GetID(), std::make_pair(final_extkey.pubkey, info)); if (write_cache) { // Only cache parent if there is any unhardened derivation @@ -448,12 +448,12 @@ public: if (last_hardened_extkey.pubkey.IsValid()) { write_cache->CacheLastHardenedExtPubKey(m_expr_index, last_hardened_extkey); } - } else if (final_info_out.path.size() > 0) { + } else if (info.path.size() > 0) { write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey); } } - return true; + return final_extkey.pubkey; } std::string ToString(StringType type, bool normalized) const { @@ -696,16 +696,17 @@ public: // NOLINTNEXTLINE(misc-no-recursion) bool ExpandHelper(int pos, const SigningProvider& arg, const DescriptorCache* read_cache, std::vector& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache) const { - std::vector> entries; - entries.reserve(m_pubkey_args.size()); + FlatSigningProvider subprovider; + std::vector pubkeys; + pubkeys.reserve(m_pubkey_args.size()); - // Construct temporary data in `entries`, `subscripts`, and `subprovider` to avoid producing output in case of failure. + // Construct temporary data in `pubkeys`, `subscripts`, and `subprovider` to avoid producing output in case of failure. for (const auto& p : m_pubkey_args) { - entries.emplace_back(); - if (!p->GetPubKey(pos, arg, entries.back().first, entries.back().second, read_cache, write_cache)) return false; + std::optional pubkey = p->GetPubKey(pos, arg, subprovider, read_cache, write_cache); + if (!pubkey) return false; + pubkeys.push_back(pubkey.value()); } std::vector subscripts; - FlatSigningProvider subprovider; for (const auto& subarg : m_subdescriptor_args) { std::vector outscripts; if (!subarg->ExpandHelper(pos, arg, read_cache, outscripts, subprovider, write_cache)) return false; @@ -714,13 +715,6 @@ public: } out.Merge(std::move(subprovider)); - std::vector pubkeys; - pubkeys.reserve(entries.size()); - for (auto& entry : entries) { - pubkeys.push_back(entry.first); - out.origins.emplace(entry.first.GetID(), std::make_pair(CPubKey(entry.first), std::move(entry.second))); - } - output_scripts = MakeScripts(pubkeys, std::span{subscripts}, out); return true; }