{-# OPTIONS_GHC -Wno-orphans #-}

-- | This module provides an automated attack to try and perform double
-- satisfaction on a contract.
module Cooked.Attack.DoubleSat
  ( DoubleSatDelta,
    DoubleSatLbl (..),
    doubleSatAttack,
  )
where

import Cooked.MockChain.BlockChain
import Cooked.Output
import Cooked.Pretty
import Cooked.Skeleton
import Cooked.Tweak
import Cooked.Wallet
import Data.Map (Map)
import Data.Map qualified as Map
import Optics.Core
import Plutus.Script.Utils.Value qualified as Script
import PlutusLedgerApi.V3 qualified as Api
import PlutusTx.Numeric qualified as PlutusTx

{- Note: What is a double satisfaction attack?

A double satisfaction attack consists in trying to satisfy the
requirements for what conceptually are two transactions in a single
transaction, and doing so incompletely. It succeeds whenever the
requirements of two validators overlap, but the required outputs of the
transaction are not sufficiently unique, so that both validators see
them as satisfying "their" requirement.

The mechanism is explained very well in the following analogy from the
Plutus documentation: "Suppose that two tax auditors from two
different departments come to visit you in turn to see if you’ve paid
your taxes. You come up with a clever scheme to confuse them. Your tax
liability to both departments is $10, so you make a single payment to
the tax office’s bank account for $10. When the auditors arrive, you
show them your books, containing the payment to the tax office. They
both leave satisfied."

The double satisfaction attack 'doubleSatAttack' provided by this
module works by going through the foci of some optic on the 'TxSkel'
representing the transaction from the left to the right, and adding
some extra inputs, outputs, and mints depending on each focus and the
current 'MockChainSt'ate. -}

-- | Everything a double satisfaction attempt may add to a transaction,
-- packed as a triple of:
--
-- * the extra inputs to spend, each paired with the redeemer to use,
-- * the extra outputs to create,
-- * the extra minted value.
type DoubleSatDelta =
  ( Map Api.TxOutRef TxSkelRedeemer,
    [TxSkelOut],
    TxSkelMints
  )

-- | Component-wise combination of two deltas:
--
-- * input maps are merged with a left-biased union (when both sides
--   spend the same 'Api.TxOutRef', the left redeemer is kept),
-- * output lists are concatenated, left outputs first,
-- * mints are combined with the 'Semigroup' instance of 'TxSkelMints'.
instance {-# OVERLAPPING #-} Semigroup DoubleSatDelta where
  (insL, outsL, mintsL) <> (insR, outsR, mintsR) =
    (Map.union insL insR, outsL <> outsR, mintsL <> mintsR)

-- | The neutral delta: no extra inputs, no extra outputs, and the
-- neutral element of 'TxSkelMints'.
instance {-# OVERLAPPING #-} Monoid DoubleSatDelta where
  mempty = (Map.empty, mempty, mempty)

-- | Double satisfaction attack. See the comment above for what such an
-- attack is about conceptually.
--
-- This attack consists in adding some extra constraints to a
-- transaction, and hoping that the additional minting policies or
-- validator scripts thereby involved are fooled by what's already
-- present on the transaction. Any extra value contained in new inputs
-- to the transaction is then paid to the attacker.
--
-- Concretely, all deltas produced for the foci are joined into one,
-- added to the transaction, and their net balance is computed. If that
-- balance is not strictly greater than 'mempty' (according to
-- 'Script.gt'), the tweak fails via 'failingTweak'. Otherwise the
-- balance is paid to the attacker and the transaction is labelled with
-- 'DoubleSatLbl'.
doubleSatAttack ::
  (MonadTweak m, Eq is, Is k A_Traversal) =>
  -- | How to combine modifications caused by different foci. See the
  -- comment at 'combineModsTweak', which uses the same logic.
  ([is] -> [[is]]) ->
  -- | Each focus of this optic is a potential reason to add some
  -- extra constraints.
  --
  -- As an example, one could go through the 'paysScript' outputs for
  -- validators of type @t@ with the following traversal:
  --
  -- > txSkelOutsL % itraversed % txSkelOutputToTypedValidatorP @t
  Optic' k (WithIx is) TxSkel a ->
  -- | How to change each focus, and which inputs, outputs, and mints
  -- to add, for each of the foci. There might be different options
  -- for each focus, that's why the return value is a list.
  --
  -- Continuing the example, for each of the focused 'paysScript'
  -- outputs, you might want to try adding some 'spendsScript' inputs
  -- to the transaction. Since it might be interesting to try
  -- different redeemers on these extra 'spendsScript' inputs, you can
  -- just provide a list of all the options you want to try adding for
  -- a given 'paysScript' that's already on the transaction.
  --
  -- ###################################
  --
  -- ATTENTION: If you modify the state while computing these lists,
  -- the behaviour of the 'doubleSatAttack' might be strange: Any
  -- modification of the state that happens on any call to this
  -- function will be applied to all returned transactions. For
  -- example, if you 'awaitTime' in any of these computations, the
  -- 'doubleSatAttack' will wait for all returned transactions.
  --
  -- TODO: Make this interface safer, for example by using (some kind
  -- of) an 'UtxoState' argument.
  --
  -- ###################################
  (is -> a -> m [(a, DoubleSatDelta)]) ->
  -- | The wallet of the attacker, where any surplus is paid to.
  --
  -- In the example, the extra value in the added 'spendsScript'
  -- constraints will be paid to the attacker.
  Wallet ->
  m ()
doubleSatAttack groupings optic change attacker = do
  deltas <- combineModsTweak groupings optic change
  let delta = joinDoubleSatDeltas deltas
  addDoubleSatDeltaTweak delta
  addedValue <- deltaBalance delta
  -- Only keep the attempt when the added inputs, outputs and mints
  -- leave a surplus to hand to the attacker.
  if addedValue `Script.gt` mempty
    then addOutputTweak $ paysPK attacker addedValue
    else failingTweak
  addLabelTweak DoubleSatLbl
  where
    -- Net value brought in by a delta: the value of its extra inputs
    -- (looked up in the current UTxO set), minus the value of its
    -- extra outputs, plus the value it mints.
    deltaBalance :: (MonadTweak m) => DoubleSatDelta -> m Api.Value
    deltaBalance (inputs, outputs, mints) = do
      -- 'Map.member' is a logarithmic lookup, instead of a linear
      -- 'elem' over 'Map.keys inputs' for every known UTxO.
      inValue <-
        foldMap (outputValue . snd)
          . filter ((`Map.member` inputs) . fst)
          <$> allUtxos
      pure $ inValue <> PlutusTx.negate outValue <> mintValue
      where
        outValue = foldOf (traversed % txSkelOutValueL) outputs
        mintValue = txSkelMintsValue mints

    -- Helper tweak to add a 'DoubleSatDelta' to a transaction: first
    -- the inputs, then the outputs, then the mints.
    addDoubleSatDeltaTweak :: (MonadTweak m) => DoubleSatDelta -> m ()
    addDoubleSatDeltaTweak (ins, outs, mints) =
      mapM_ (uncurry addInputTweak) (Map.toList ins)
        >> mapM_ addOutputTweak outs
        >> mapM_ addMintTweak (txSkelMintsToList mints)

    -- Join a list of 'DoubleSatDelta's into one 'DoubleSatDelta' that
    -- specifies everything that is contained in the input.
    joinDoubleSatDeltas :: [DoubleSatDelta] -> DoubleSatDelta
    joinDoubleSatDeltas = mconcat

-- | Label that 'doubleSatAttack' attaches to every transaction it
-- successfully modifies.
data DoubleSatLbl = DoubleSatLbl
  deriving (Eq, Ord, Show)

-- | 'DoubleSatLbl' always pretty-prints as the fixed text "DoubleSat".
instance PrettyCooked DoubleSatLbl where
  prettyCooked = const "DoubleSat"