diff --git a/api_handlers.go b/api_handlers.go index ae5350b..5770965 100644 --- a/api_handlers.go +++ b/api_handlers.go @@ -1468,23 +1468,24 @@ func (s *APIServer) handleLoadWallet(w http.ResponseWriter, r *http.Request) { // Handle chain-ahead-of-wallet (chain was reset while wallet was offline) chainHeight := s.daemon.Chain().Height() - walletHeight := wl.SyncedHeight() - if walletHeight > chainHeight { - if removed := wl.RewindToHeight(chainHeight); removed > 0 { + originalWalletHeight := wl.SyncedHeight() + if removed := rewindWalletToCanonicalTip(wl, s.daemon.Chain()); removed > 0 { + if err := wl.Save(); err != nil { + writeInternal(w, r, http.StatusInternalServerError, "internal error", err) + return + } + } + + walletHeight, walletHash := wl.SyncedBlock() + if originalWalletHeight == chainHeight && walletHeight == chainHeight && chainHeight > 0 && !walletSyncHashKnown(walletHash) { + // Legacy wallets do not know which block hash they scanned at the tip. + // Rewind one block so the next scan records canonical hash metadata. + if removed := wl.RewindToHeight(chainHeight - 1); removed > 0 { if err := wl.Save(); err != nil { writeInternal(w, r, http.StatusInternalServerError, "internal error", err) return } } - walletHeight = wl.SyncedHeight() - } - - // Conservative reorg recovery: when wallet and chain heights match exactly, - // rewind one block and rescan tip. This clears stale same-height branch data - // even for wallets that predate tip-hash sync metadata. - if walletHeight == chainHeight && chainHeight > 0 { - wl.RewindToHeight(chainHeight - 1) - walletHeight = wl.SyncedHeight() } wl.SetInputFilter(func(out *wallet.OwnedOutput) bool { @@ -1589,7 +1590,11 @@ func (s *APIServer) handleCreateWallet(w http.ResponseWriter, r *http.Request) { chainHeight := s.daemon.Chain().Height() if chainHeight > 0 { - wl.SetSyncedHeight(chainHeight) + if block := s.daemon.Chain().GetBlockByHeight(chainHeight); block != nil { + wl.SetSyncedBlock(chainHeight, block.Hash()) + } else { + wl.SetSyncedHeight(chainHeight) + } if err := wl.Save(); err != nil { writeInternal(w, r, http.StatusInternalServerError, "internal error", err) return @@ -2347,6 +2352,12 @@ func (s *APIServer) catchUpScan() { ctx := s.cli.ctx + if removed := rewindWalletToCanonicalTip(w, s.daemon.Chain()); removed > 0 { + if err := w.Save(); err != nil { + log.Printf("Warning: catchUpScan reorg save: %v", err) + } + } + walletHeight := w.SyncedHeight() chainHeight := s.daemon.Chain().Height() if walletHeight >= chainHeight { @@ -2364,11 +2375,25 @@ func (s *APIServer) catchUpScan() { if block == nil { break } + _, walletHash := w.SyncedBlock() + if walletSyncHashKnown(walletHash) && block.Header.PrevHash != walletHash { + if h <= 1 { + break + } + if removed := w.RewindToHeight(h - 2); removed > 0 { + if err := w.Save(); err != nil { + log.Printf("Warning: catchUpScan reorg save at height %d: %v", h, err) + } + } + h = w.SyncedHeight() + continue + } sc.ScanBlock(blockToScanData(block)) w.ReconcileUnconfirmedSpends(func(txID [32]byte) bool { return s.daemon.Mempool().HasTransaction(txID) }) - w.SetSyncedHeight(h) + blockHash := block.Hash() + w.SetSyncedBlock(h, blockHash) if h%100 == 0 || h == chainHeight { if err := w.Save(); err != nil { diff --git a/cli.go b/cli.go index c808154..38b93f8 100644 --- a/cli.go +++ b/cli.go @@ -490,18 +490,19 @@ func (c *CLI) recoverWalletAfterChainReset() { return } chainHeight := c.daemon.Chain().Height() - walletHeight := c.wallet.SyncedHeight() - if walletHeight > chainHeight { - removed := c.wallet.RewindToHeight(chainHeight) - if removed > 0 { - fmt.Printf(" Chain reset: removed %d orphaned outputs, rewound to height %d\n", removed, chainHeight) - if err := c.wallet.Save(); err != nil { - fmt.Printf(" Warning: failed to persist rewound wallet: %v\n", err) - } + originalWalletHeight := c.wallet.SyncedHeight() + removed := rewindWalletToCanonicalTip(c.wallet, c.daemon.Chain()) + if removed > 0 { + fmt.Printf(" Chain reset: removed %d orphaned outputs, rewound to height %d\n", removed, c.wallet.SyncedHeight()) + if err := c.wallet.Save(); err != nil { + fmt.Printf(" Warning: failed to persist rewound wallet: %v\n", err) } - } else if chainHeight > 0 && walletHeight == chainHeight { - // Conservative reorg recovery: if wallet and chain are at the same height, - // rewind one block to clear potentially stale same-height fork state. + } + + walletHeight, walletHash := c.wallet.SyncedBlock() + if chainHeight > 0 && originalWalletHeight == chainHeight && walletHeight == chainHeight && !walletSyncHashKnown(walletHash) { + // Legacy wallets do not know which block hash they scanned at the tip. + // Rewind one block so the next scan records canonical hash metadata. removed := c.wallet.RewindToHeight(chainHeight - 1) if removed > 0 { fmt.Printf(" Chain reset: removed %d orphaned outputs, rewound to height %d\n", removed, chainHeight-1) @@ -1116,20 +1117,47 @@ func (c *CLI) autoScanBlocks() { w := c.wallet c.mu.RUnlock() - if scanner == nil { + if scanner == nil || w == nil { continue } - blockData := blockToScanData(block) - scanner.ScanBlock(blockData) - w.ReconcileUnconfirmedSpends(func(txID [32]byte) bool { - return c.daemon.Mempool().HasTransaction(txID) - }) + if removed := rewindWalletToCanonicalTip(w, c.daemon.Chain()); removed > 0 { + unsaved++ + } - w.SetSyncedHeight(blockData.Height) - unsaved++ - if unsaved >= saveBatchSize { - doSave() + for { + walletHeight, walletHash := w.SyncedBlock() + chainHeight := c.daemon.Chain().Height() + if walletHeight >= chainHeight { + break + } + + nextHeight := walletHeight + 1 + canonical := c.daemon.Chain().GetBlockByHeight(nextHeight) + if canonical == nil { + break + } + if walletSyncHashKnown(walletHash) && canonical.Header.PrevHash != walletHash { + if walletHeight == 0 { + break + } + if removed := w.RewindToHeight(walletHeight - 1); removed > 0 { + unsaved++ + } + continue + } + + blockData := blockToScanData(canonical) + scanner.ScanBlock(blockData) + w.ReconcileUnconfirmedSpends(func(txID [32]byte) bool { + return c.daemon.Mempool().HasTransaction(txID) + }) + + w.SetSyncedBlock(blockData.Height, blockData.Hash) + unsaved++ + if unsaved >= saveBatchSize { + doSave() + } } } } @@ -1159,6 +1187,7 @@ func (c *CLI) watchMinedBlocks() { func blockToScanData(block *Block) *wallet.BlockData { data := &wallet.BlockData{ Height: block.Header.Height, + Hash: block.Hash(), Transactions: make([]wallet.TxData, len(block.Transactions)), } diff --git a/cli_cmd_wallet.go b/cli_cmd_wallet.go index 7a7bd34..939f3d0 100644 --- a/cli_cmd_wallet.go +++ b/cli_cmd_wallet.go @@ -711,6 +711,12 @@ func memoTextIfPrintable(b []byte) (string, bool) { } func (c *CLI) cmdSync() { + if removed := rewindWalletToCanonicalTip(c.wallet, c.daemon.Chain()); removed > 0 { + fmt.Printf(" Chain reset: removed %d orphaned outputs, rewound to height %d\n", removed, c.wallet.SyncedHeight()) + if err := c.wallet.Save(); err != nil { + fmt.Printf(" Warning: failed to persist rewound wallet: %v\n", err) + } + } chainHeight := c.daemon.Chain().Height() walletHeight := c.wallet.SyncedHeight() @@ -729,6 +735,7 @@ func (c *CLI) cmdSync() { blocks := c.daemon.Chain().GetBlocksByHeightRange(walletHeight+1, chainHeight) scannedTo := walletHeight + var scannedHash [32]byte for _, block := range blocks { if block == nil { break @@ -739,13 +746,14 @@ func (c *CLI) cmdSync() { h := block.Header.Height scannedTo = h + scannedHash = blockData.Hash if found > 0 || spent > 0 { fmt.Printf(" Block %d: +%d outputs, %d spent\n", h, found, spent) } } if scannedTo > walletHeight { - c.wallet.SetSyncedHeight(scannedTo) + c.wallet.SetSyncedBlock(scannedTo, scannedHash) fmt.Printf(" Wallet synced to height %d\n", scannedTo) } } diff --git a/daemon.go b/daemon.go index 6d65164..c815872 100644 --- a/daemon.go +++ b/daemon.go @@ -1250,8 +1250,12 @@ func (d *Daemon) SubmitBlock(block *Block) error { } d.syncMgr.BroadcastBlock(blockData) - // Notify subscribers - d.notifyBlock(block) + // Notify wallet/UI subscribers only for blocks that became main-chain. + // Side-chain blocks may be accepted for fork choice, but scanning them would + // record outputs the wallet cannot spend on the canonical chain. + if isMainChain { + d.notifyBlock(block) + } d.notifyMinedBlock(block) return nil diff --git a/wallet/scanner.go b/wallet/scanner.go index 1823b6a..119115d 100644 --- a/wallet/scanner.go +++ b/wallet/scanner.go @@ -10,6 +10,7 @@ import ( // BlockData is the minimal block info needed for scanning type BlockData struct { Height uint64 + Hash [32]byte Transactions []TxData } @@ -165,6 +166,7 @@ func (s *Scanner) ScanBlock(block *BlockData) (found int, spent int) { OneTimePubKey: out.PubKey, Commitment: out.Commitment, BlockHeight: block.Height, + BlockHash: block.Hash, IsCoinbase: tx.IsCoinbase, Spent: false, } @@ -218,7 +220,7 @@ func (s *Scanner) ScanBlocks(blocks []*BlockData) (totalFound, totalSpent int) { totalFound += found totalSpent += spent - s.wallet.SetSyncedHeight(block.Height) + s.wallet.SetSyncedBlock(block.Height, block.Hash) } return totalFound, totalSpent } @@ -268,9 +270,9 @@ func BlockToScanData(blockJSON []byte) (*BlockData, error) { TxID [32]byte `json:"tx_id"` TxPublicKey [32]byte `json:"tx_public_key"` Outputs []struct { - PublicKey [32]byte `json:"public_key"` - Commitment [32]byte `json:"commitment"` - EncryptedAmount [8]byte `json:"encrypted_amount"` + PublicKey [32]byte `json:"public_key"` + Commitment [32]byte `json:"commitment"` + EncryptedAmount [8]byte `json:"encrypted_amount"` EncryptedMemo [MemoSize]byte `json:"encrypted_memo"` } `json:"outputs"` Inputs []struct { diff --git a/wallet/sync_reorg_test.go b/wallet/sync_reorg_test.go new file mode 100644 index 0000000..b1eb543 --- /dev/null +++ b/wallet/sync_reorg_test.go @@ -0,0 +1,35 @@ +package wallet + +import "testing" + +func TestWalletRewindRestoresSyncedBlockHash(t *testing.T) { + w := &Wallet{} + hash1 := [32]byte{0x01} + hash2 := [32]byte{0x02} + + w.SetSyncedBlock(1, hash1) + w.SetSyncedBlock(2, hash2) + + w.RewindToHeight(1) + + height, hash := w.SyncedBlock() + if height != 1 { + t.Fatalf("height=%d, want 1", height) + } + if hash != hash1 { + t.Fatalf("hash=%x, want %x", hash[:], hash1[:]) + } +} + +func TestWalletSetSyncedHeightClearsUnknownHash(t *testing.T) { + w := &Wallet{} + hash := [32]byte{0x01} + + w.SetSyncedBlock(1, hash) + w.SetSyncedHeight(1) + + _, got := w.SyncedBlock() + if got != ([32]byte{}) { + t.Fatalf("hash=%x, want zero hash", got[:]) + } +} diff --git a/wallet/wallet.go b/wallet/wallet.go index fe1f2fb..d0ab33b 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -265,6 +265,7 @@ type OwnedOutput struct { OneTimePubKey [32]byte `json:"one_time_pub"` Commitment [32]byte `json:"commitment"` BlockHeight uint64 `json:"block_height"` + BlockHash [32]byte `json:"block_hash,omitempty"` IsCoinbase bool `json:"is_coinbase"` // True if from mining reward Spent bool `json:"spent"` SpentHeight uint64 `json:"spent_height,omitempty"` @@ -331,6 +332,8 @@ type WalletData struct { SendHistory []*SendRecord `json:"send_history,omitempty"` // Track outgoing transactions PendingCredits []*PendingCredit `json:"pending_credits,omitempty"` // UX-only pending credits (e.g. unconfirmed change) SyncedHeight uint64 `json:"synced_height"` + SyncedHash [32]byte `json:"synced_hash,omitempty"` + SyncedBlocks []SyncedBlock `json:"synced_blocks,omitempty"` CreatedAt int64 `json:"created_at"` } @@ -410,6 +413,12 @@ type PendingCredit struct { AddedAt int64 `json:"added_at"` } +// SyncedBlock records the canonical block hash that a wallet scanned at a height. +type SyncedBlock struct { + Height uint64 `json:"height"` + Hash [32]byte `json:"hash"` +} + type reservedOutpoint struct { TxID [32]byte OutputIndex int @@ -1313,11 +1322,29 @@ func (w *Wallet) SyncedHeight() uint64 { return w.data.SyncedHeight } -// SetSyncedHeight updates the sync height +// SyncedBlock returns the last canonical block scanned by the wallet. +func (w *Wallet) SyncedBlock() (uint64, [32]byte) { + w.mu.RLock() + defer w.mu.RUnlock() + return w.data.SyncedHeight, w.data.SyncedHash +} + +// SetSyncedHeight updates the sync height without block-hash metadata. func (w *Wallet) SetSyncedHeight(height uint64) { w.mu.Lock() defer w.mu.Unlock() w.data.SyncedHeight = height + w.data.SyncedHash = [32]byte{} + w.pruneSyncedBlocksLocked(height) +} + +// SetSyncedBlock updates the wallet sync point to a canonical block hash. +func (w *Wallet) SetSyncedBlock(height uint64, hash [32]byte) { + w.mu.Lock() + defer w.mu.Unlock() + w.data.SyncedHeight = height + w.data.SyncedHash = hash + w.recordSyncedBlockLocked(height, hash) } // RewindToHeight removes outputs from blocks above the given height @@ -1344,9 +1371,54 @@ func (w *Wallet) RewindToHeight(height uint64) int { if w.data.SyncedHeight > height { w.data.SyncedHeight = height } + w.pruneSyncedBlocksLocked(height) + w.data.SyncedHash = w.syncedHashAtLocked(w.data.SyncedHeight) return removed } +func (w *Wallet) recordSyncedBlockLocked(height uint64, hash [32]byte) { + if hash == ([32]byte{}) { + w.pruneSyncedBlocksLocked(height) + return + } + for i := range w.data.SyncedBlocks { + if w.data.SyncedBlocks[i].Height == height { + w.data.SyncedBlocks[i].Hash = hash + w.pruneSyncedBlocksLocked(height) + return + } + } + w.data.SyncedBlocks = append(w.data.SyncedBlocks, SyncedBlock{Height: height, Hash: hash}) + w.pruneSyncedBlocksLocked(height) +} + +func (w *Wallet) pruneSyncedBlocksLocked(maxHeight uint64) { + if len(w.data.SyncedBlocks) == 0 { + return + } + kept := w.data.SyncedBlocks[:0] + for _, block := range w.data.SyncedBlocks { + if block.Height <= maxHeight { + kept = append(kept, block) + } + } + if len(kept) == 0 { + w.data.SyncedBlocks = nil + return + } + w.data.SyncedBlocks = kept +} + +func (w *Wallet) syncedHashAtLocked(height uint64) [32]byte { + for i := len(w.data.SyncedBlocks) - 1; i >= 0; i-- { + block := w.data.SyncedBlocks[i] + if block.Height == height { + return block.Hash + } + } + return [32]byte{} +} + // OutputCount returns total output count func (w *Wallet) OutputCount() (total, unspent int) { w.mu.RLock() diff --git a/wallet_reorg_prevention_test.go b/wallet_reorg_prevention_test.go new file mode 100644 index 0000000..aad3bbc --- /dev/null +++ b/wallet_reorg_prevention_test.go @@ -0,0 +1,90 @@ +package main + +import ( + "testing" + + "blocknet/wallet" +) + +func TestRewindWalletToCanonicalTipRewindsKnownForkedTip(t *testing.T) { + chain, storage, cleanup := mustCreateTestChain(t) + defer cleanup() + mustAddGenesisBlock(t, chain) + + genesis := chain.GetBlockByHeight(0) + if genesis == nil { + t.Fatal("expected genesis block") + } + + forked := makeOutputOnlyTestBlock( + 1, + genesis.Hash(), + genesis.Header.Timestamp+BlockIntervalSec, + nil, + ) + canonical := makeOutputOnlyTestBlock( + 1, + genesis.Hash(), + genesis.Header.Timestamp+BlockIntervalSec+1, + nil, + ) + commitMainChainBlockForTest(t, chain, storage, canonical, MinDifficulty) + + w, err := wallet.NewWallet(t.TempDir()+"/wallet.dat", []byte("correct-password"), defaultWalletConfig()) + if err != nil { + t.Fatalf("NewWallet: %v", err) + } + forkedHash := forked.Hash() + w.AddOutput(&wallet.OwnedOutput{ + TxID: [32]byte{0x01}, + OutputIndex: 0, + Amount: 500, + BlockHeight: 1, + BlockHash: forkedHash, + }) + w.SetSyncedBlock(1, forkedHash) + + removed := rewindWalletToCanonicalTip(w, chain) + if removed != 1 { + t.Fatalf("removed=%d, want 1", removed) + } + if got := w.SyncedHeight(); got != 0 { + t.Fatalf("SyncedHeight=%d, want 0", got) + } + total, unspent := w.OutputCount() + if total != 0 || unspent != 0 { + t.Fatalf("OutputCount total=%d unspent=%d, want 0/0", total, unspent) + } +} + +func TestRewindWalletToCanonicalTipKeepsMatchingTip(t *testing.T) { + chain, storage, cleanup := mustCreateTestChain(t) + defer cleanup() + mustAddGenesisBlock(t, chain) + + genesis := chain.GetBlockByHeight(0) + if genesis == nil { + t.Fatal("expected genesis block") + } + canonical := makeOutputOnlyTestBlock( + 1, + genesis.Hash(), + genesis.Header.Timestamp+BlockIntervalSec, + nil, + ) + commitMainChainBlockForTest(t, chain, storage, canonical, MinDifficulty) + + w, err := wallet.NewWallet(t.TempDir()+"/wallet.dat", []byte("correct-password"), defaultWalletConfig()) + if err != nil { + t.Fatalf("NewWallet: %v", err) + } + canonicalHash := canonical.Hash() + w.SetSyncedBlock(1, canonicalHash) + + if removed := rewindWalletToCanonicalTip(w, chain); removed != 0 { + t.Fatalf("removed=%d, want 0", removed) + } + if got := w.SyncedHeight(); got != 1 { + t.Fatalf("SyncedHeight=%d, want 1", got) + } +} diff --git a/wallet_sync.go b/wallet_sync.go new file mode 100644 index 0000000..872401e --- /dev/null +++ b/wallet_sync.go @@ -0,0 +1,39 @@ +package main + +import "blocknet/wallet" + +func walletSyncHashKnown(hash [32]byte) bool { + return hash != ([32]byte{}) +} + +func rewindWalletToCanonicalTip(w *wallet.Wallet, chain *Chain) int { + if w == nil || chain == nil { + return 0 + } + + removed := 0 + for { + chainHeight := chain.Height() + walletHeight, walletHash := w.SyncedBlock() + if walletHeight == 0 { + return removed + } + if walletHeight > chainHeight { + removed += w.RewindToHeight(chainHeight) + continue + } + if !walletSyncHashKnown(walletHash) { + return removed + } + + block := chain.GetBlockByHeight(walletHeight) + if block == nil { + removed += w.RewindToHeight(walletHeight - 1) + continue + } + if block.Hash() == walletHash { + return removed + } + removed += w.RewindToHeight(walletHeight - 1) + } +}