{-# OPTIONS_GHC -Wno-orphans #-}

-- | This module provides an automated attack to try and perform double
-- satisfaction on a contract.
module Cooked.Attack.DoubleSat
  ( DoubleSatDelta,
    DoubleSatLbl (..),
    doubleSatAttack,
  )
where

import Cooked.MockChain.BlockChain
import Cooked.Output
import Cooked.Pretty
import Cooked.Skeleton
import Cooked.Tweak
import Cooked.Wallet
import Data.Map (Map)
import Data.Map qualified as Map
import Optics.Core
import PlutusLedgerApi.V1.Value qualified as Api
import PlutusLedgerApi.V3 qualified as Api
import PlutusTx.Numeric qualified as PlutusTx

{- Note: What is a double satisfaction attack?

A double satisfaction attack consists of trying to satisfy the requirements for
what conceptually are two transactions in a single transaction, and doing so
incompletely. It succeeds whenever the requirements of two validators overlap,
but the required outputs of the transaction are not sufficiently unique, so that
both validators see them as satisfying "their" requirement.

The mechanism is explained very well in the following analogy from the Plutus
documentation: "Suppose that two tax auditors from two different departments
come to visit you in turn to see if you’ve paid your taxes. You come up with a
clever scheme to confuse them. Your tax liability to both departments is $10, so
you make a single payment to the tax office’s bank account for $10. When the
auditors arrive, you show them your books, containing the payment to the tax
office. They both leave satisfied."

The double satisfaction attack 'doubleSatAttack' provided by this module works
by going through the foci of some optic on the 'TxSkel' representing the
transaction from the left to the right, and adding some extra inputs, outputs,
and mints depending on each focus and the current 'MockChainSt'ate. -}

-- | Everything a double satisfaction attempt may add to a transaction, as a
-- triplet of:
--
-- (1) extra inputs, given as a map from each spent 'Api.TxOutRef' to the
--     'TxSkelRedeemer' to spend it with,
--
-- (2) extra outputs, and
--
-- (3) extra mints.
type DoubleSatDelta =
  ( Map Api.TxOutRef TxSkelRedeemer,
    [TxSkelOut],
    TxSkelMints
  )

-- | Component-wise combination of two deltas.
--
-- * Inputs are merged with the left-biased 'Map.union': when both deltas spend
--   the same 'Api.TxOutRef', the redeemer of the left delta is kept.
--
-- * Outputs are concatenated, left ones first.
--
-- * Mints are combined with the 'Semigroup' instance of 'TxSkelMints'.
--
-- This coincides with the component-wise 'Semigroup' instance that @base@
-- provides for triples; it is spelled out here to make the merge policy
-- explicit.
instance {-# OVERLAPPING #-} Semigroup DoubleSatDelta where
  (insL, outsL, mintsL) <> (insR, outsR, mintsR) =
    let mergedIns = Map.union insL insR
        mergedOuts = outsL ++ outsR
        mergedMints = mintsL <> mintsR
     in (mergedIns, mergedOuts, mergedMints)

-- | The neutral delta: no extra inputs, no extra outputs, and no extra mints.
instance {-# OVERLAPPING #-} Monoid DoubleSatDelta where
  mempty = (Map.empty, mempty, mempty)

-- | Double satisfaction attack. See the note "What is a double satisfaction
-- attack?" above for what such an attack is about conceptually.
--
-- This attack consists in adding some extra constraints to a transaction, and
-- hoping that the additional minting policies or validator scripts thereby
-- involved are fooled by what's already present on the transaction. Any extra
-- value contained in new inputs to the transaction is then paid to the
-- attacker.
--
-- Concretely, the attack:
--
-- (1) computes a list of 'DoubleSatDelta's from the foci of the optic, using
--     'combineModsTweak' with the given groupings and change function;
--
-- (2) merges that list into a single delta with 'mconcat' and adds its
--     inputs, outputs and mints to the transaction;
--
-- (3) computes the balance of the merged delta (value of the added inputs,
--     minus value of the added outputs, plus the minted value). If that
--     balance is greater than 'mempty' according to 'Api.gt', it is paid to
--     the attacker; otherwise the tweak fails with 'failingTweak';
--
-- (4) labels the modified transaction with 'DoubleSatLbl'.
doubleSatAttack ::
  (MonadTweak m, Eq is, Is k A_Traversal) =>
  -- | How to combine modifications caused by different foci. See the comment
  -- at 'combineModsTweak', which uses the same logic.
  ([is] -> [[is]]) ->
  -- | Each focus of this optic is a potential reason to add some extra
  -- constraints.
  Optic' k (WithIx is) TxSkel a ->
  -- | How to change each focus, and which inputs, outputs, and mints to add,
  -- for each of the foci. There might be different options for each focus,
  -- that's why the return value is a list.
  --
  -- For example, for each of the focused script outputs, you might want to
  -- try adding some script inputs to the transaction. Since it might be
  -- interesting to try different redeemers on these extra script inputs, you
  -- can just provide a list of all the options you want to try adding for a
  -- given script output that's already on the transaction.
  --
  -- ###################################
  --
  -- ATTENTION: If you modify the state while computing these lists, the
  -- behaviour of the 'doubleSatAttack' might be strange: Any modification of
  -- the state that happens on any call to this function will be applied to all
  -- returned transactions. For example, if you
  -- 'Cooked.MockChain.BlockChain.awaitSlot' in any of these computations, the
  -- 'doubleSatAttack' will wait for all returned transactions.
  --
  -- TODO: Make this interface safer, for example by using (some kind of) an
  -- 'Cooked.MockChain.UtxoState.UtxoState' argument.
  --
  -- ###################################
  (is -> a -> m [(a, DoubleSatDelta)]) ->
  -- | The wallet of the attacker, to which any surplus is paid.
  --
  -- For example, the extra value in an added script input will be paid to the
  -- attacker.
  Wallet ->
  m ()
doubleSatAttack groupings optic change attacker = do
  deltas <- combineModsTweak groupings optic change
  let delta = joinDoubleSatDeltas deltas
  addDoubleSatDeltaTweak delta
  addedValue <- deltaBalance delta
  -- Only keep this modification when the added constraints leave a surplus
  -- the attacker can pocket; otherwise there is nothing to gain and the
  -- tweak fails.
  if addedValue `Api.gt` mempty
    then addOutputTweak $ attacker `receives` Value addedValue
    else failingTweak
  addLabelTweak DoubleSatLbl
  where
    -- Balance of a delta: the value of its inputs (looked up in the current
    -- UTxO set; references not found there contribute nothing), minus the
    -- value of its outputs, plus the value of its mints.
    deltaBalance :: (MonadTweak m) => DoubleSatDelta -> m Api.Value
    deltaBalance (inputs, outputs, mints) = do
      -- 'Map.member' is a logarithmic lookup in the input map, instead of a
      -- linear scan of its keys for every known UTxO.
      inValue <-
        foldMap (outputValue . snd) . filter ((`Map.member` inputs) . fst)
          <$> allUtxos
      return $ inValue <> PlutusTx.negate outValue <> mintValue
      where
        outValue = foldOf (traversed % txSkelOutValueL % txSkelOutValueContentL) outputs
        mintValue = txSkelMintsValue mints

    -- Helper tweak to add a 'DoubleSatDelta' to a transaction: every input
    -- (with its redeemer), then every output, then every mint.
    addDoubleSatDeltaTweak :: (MonadTweak m) => DoubleSatDelta -> m ()
    addDoubleSatDeltaTweak (ins, outs, mints) =
      mapM_ (uncurry addInputTweak) (Map.toList ins)
        >> mapM_ addOutputTweak outs
        >> mapM_ addMintTweak (txSkelMintsToList mints)

    -- Join a list of 'DoubleSatDelta's into one 'DoubleSatDelta' that
    -- specifies everything that is contained in the input, using the
    -- 'Monoid' instance of 'DoubleSatDelta' (left-biased on shared inputs).
    joinDoubleSatDeltas :: [DoubleSatDelta] -> DoubleSatDelta
    joinDoubleSatDeltas = mconcat

-- | Marker label attached to every 'TxSkel' that 'doubleSatAttack' has
-- successfully modified, so that such transactions can be recognised later.
data DoubleSatLbl = DoubleSatLbl
  deriving (Show, Eq, Ord)

-- | The label is always rendered as the fixed text @DoubleSat@; the argument
-- is ignored.
instance PrettyCooked DoubleSatLbl where
  prettyCooked = const "DoubleSat"