diff --git a/src/rpc/client.cpp b/src/rpc/client.cpp index 3267d1d2ea..b1dfdab49f 100644 --- a/src/rpc/client.cpp +++ b/src/rpc/client.cpp @@ -146,6 +146,7 @@ static const CRPCConvertParam vRPCConvertParams[] = { "fundrawtransaction", 1, "lockUnspents"}, { "fundrawtransaction", 1, "fee_rate"}, { "fundrawtransaction", 1, "feeRate"}, + { "fundrawtransaction", 1, "segwit_inputs_only"}, { "fundrawtransaction", 1, "subtractFeeFromOutputs"}, { "fundrawtransaction", 1, "input_weights"}, { "fundrawtransaction", 1, "conf_target"}, diff --git a/src/wallet/coincontrol.h b/src/wallet/coincontrol.h index 71593e236f..c6bc878465 100644 --- a/src/wallet/coincontrol.h +++ b/src/wallet/coincontrol.h @@ -34,6 +34,8 @@ public: std::optional m_change_type; //! If false, only safe inputs will be used bool m_include_unsafe_inputs = false; + //! If true, only segwit inputs are selected + bool m_segwit_inputs_only = false; //! If true, the selection process can add extra unselected inputs from the wallet //! while requires all selected inputs be used bool m_allow_other_inputs = true; diff --git a/src/wallet/rpc/spend.cpp b/src/wallet/rpc/spend.cpp index 9bed876988..a4f270ce6f 100644 --- a/src/wallet/rpc/spend.cpp +++ b/src/wallet/rpc/spend.cpp @@ -626,6 +626,7 @@ void FundTransaction(CWallet& wallet, CMutableTransaction& tx, CAmount& fee_out, {"minconf", UniValueType(UniValue::VNUM)}, {"maxconf", UniValueType(UniValue::VNUM)}, {"input_weights", UniValueType(UniValue::VARR)}, + {"segwit_inputs_only", UniValueType(UniValue::VBOOL)}, }, true, true); @@ -633,6 +634,10 @@ void FundTransaction(CWallet& wallet, CMutableTransaction& tx, CAmount& fee_out, coinControl.m_allow_other_inputs = options["add_inputs"].get_bool(); } + if (options.exists("segwit_inputs_only")) { + coinControl.m_segwit_inputs_only = options["segwit_inputs_only"].get_bool(); + } + if (options.exists("changeAddress") || options.exists("change_address")) { const std::string change_address_str = (options.exists("change_address") ? options["change_address"] : options["changeAddress"]).get_str(); CTxDestination dest = DecodeDestination(change_address_str); @@ -891,6 +896,7 @@ RPCHelpMan fundrawtransaction() }, }, }, + {"segwit_inputs_only", RPCArg::Type::BOOL, RPCArg::Default{false}, "Whether to only use segwit inputs for transaction."}, }, FundTxDoc()), RPCArgOptions{ diff --git a/src/wallet/spend.cpp b/src/wallet/spend.cpp index 32df3c8c0e..5e192037e8 100644 --- a/src/wallet/spend.cpp +++ b/src/wallet/spend.cpp @@ -315,6 +315,7 @@ CoinsResult AvailableCoins(const CWallet& wallet, const int min_depth = {coinControl ? coinControl->m_min_depth : DEFAULT_MIN_DEPTH}; const int max_depth = {coinControl ? coinControl->m_max_depth : DEFAULT_MAX_DEPTH}; const bool only_safe = {coinControl ? !coinControl->m_include_unsafe_inputs : true}; + const bool segwit_inputs_only = {coinControl ? coinControl->m_segwit_inputs_only : false}; const bool can_grind_r = wallet.CanGrindR(); std::vector outpoints; @@ -408,6 +409,10 @@ CoinsResult AvailableCoins(const CWallet& wallet, std::unique_ptr provider = wallet.GetSolvingProvider(output.scriptPubKey); + if (segwit_inputs_only && !IsSegWitOutput(*provider, wtx.tx->vout[i].scriptPubKey)) { + continue; + } + int input_bytes = CalculateMaximumSignedInputSize(output, COutPoint(), provider.get(), can_grind_r, coinControl); // Because CalculateMaximumSignedInputSize infers a solvable descriptor to get the satisfaction size, // it is safe to assume that this input is solvable if input_bytes is greater than -1. diff --git a/test/functional/wallet_fundrawtransaction.py b/test/functional/wallet_fundrawtransaction.py index 77611649ac..4e20e4f575 100755 --- a/test/functional/wallet_fundrawtransaction.py +++ b/test/functional/wallet_fundrawtransaction.py @@ -8,6 +8,7 @@ from decimal import Decimal from itertools import product from math import ceil +from test_framework.blocktools import COINBASE_MATURITY from test_framework.descriptors import descsum_create from test_framework.messages import ( @@ -136,6 +137,7 @@ class RawTransactionsTest(BitcoinTestFramework): self.test_locked_wallet() self.test_many_inputs_fee() self.test_many_inputs_send() + self.test_witness_only() self.test_op_return() self.test_watchonly() self.test_all_watched_funds() @@ -188,6 +190,52 @@ class RawTransactionsTest(BitcoinTestFramework): dec_tx = self.nodes[2].decoderawtransaction(rawtxfund['hex']) assert len(dec_tx['vin']) > 0 #test that we have enough inputs + def check_witness_inputs(self, vins): + for vin in vins: + # check vin is a segwit input + utxo = self.nodes[2].gettxout(vin['txid'], vin['vout']) + info = self.nodes[2].getaddressinfo(utxo['scriptPubKey']['address']) + if not (info['iswitness'] or info['embedded']['iswitness']): + return False + + return True + + def test_witness_only(self): + self.log.info("Testing fundrawtxn with witness inputs only") + + self.generate(self.nodes[0], COINBASE_MATURITY + 10) + self.nodes[2].sendall(recipients=[self.nodes[0].getnewaddress()]) + + output_types = ['legacy', 'p2sh-segwit', 'bech32'] + if self.options.descriptors: + output_types.append('bech32m') + # Create coins + for _ in range(10): + for output_type in output_types: + self.nodes[0].sendtoaddress(self.nodes[2].getnewaddress(address_type=output_type), 1) + + self.generate(self.nodes[0], 1) + + inputs = [ ] + target_addr = self.nodes[2].getnewaddress() + segwit_balance = (len(output_types) - 1) * 10 + + # make sure legacy inputs are not accepted in witness only mode if no witness inputs are found + # trying to spend more than segwit total should fail + outputs = { target_addr : segwit_balance + Decimal('0.00000001') } + rawtx = self.nodes[2].createrawtransaction(inputs, outputs) + assert_raises_rpc_error(-4, "Insufficient funds", self.nodes[2].fundrawtransaction, rawtx, {'segwit_inputs_only': True, 'subtractFeeFromOutputs': [0]}) + + # make sure all inputs are of type witness + outputs = { target_addr : segwit_balance } + rawtx = self.nodes[2].createrawtransaction(inputs, outputs) + rawtxfund = self.nodes[2].fundrawtransaction(rawtx, {'segwit_inputs_only': True, 'subtractFeeFromOutputs': [0]}) + dec_tx = self.nodes[2].decoderawtransaction(rawtxfund['hex']) + + assert len(dec_tx['vin']) > 0 + assert(self.check_witness_inputs(dec_tx['vin'])) + + def test_simple_two_coins(self): self.log.info("Test fundrawtxn with 2 coins") inputs = [ ]