-- |
-- Module      : Data.Memory.PtrMethods
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Low-level operations (xor, copy, fill, reverse, comparison) on raw memory
-- accessed through pointers.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Memory.PtrMethods
    ( memCreateTemporary
    , memXor
    , memXorWith
    , memCopy
    , memSet
    , memReverse
    , memEqual
    , memConstEqual
    , memCompare
    ) where

import           Data.Memory.Internal.Imports
import           Foreign.Ptr              (Ptr, plusPtr)
import           Foreign.Storable         (peek, poke, peekByteOff)
import           Foreign.C.Types
import           Foreign.Marshal.Alloc    (allocaBytesAligned)
import           Data.Bits                ((.|.), xor)

-- | Run an action with a freshly allocated temporary buffer of @size@ bytes.
--
-- The buffer is 8-byte aligned and is owned by 'allocaBytesAligned': it is
-- released once the action finishes, so the pointer must not escape it.
memCreateTemporary :: Int -> (Ptr Word8 -> IO a) -> IO a
memCreateTemporary size = allocaBytesAligned size alignment
  where
    -- Alignment (in bytes) requested for the temporary buffer.
    alignment :: Int
    alignment = 8

-- | XOR two byte buffers into a destination buffer.
--
-- For every @i@ in @[0, n)@: @d[i] = s1[i] `xor` s2[i]@.
--
-- Neither @s1@ nor @s2@ is modified unless @d@ points to one of them.
-- Exact in-place use (@d == s1@ or @d == s2@) works because each byte pair
-- is read before the corresponding destination byte is written.
-- A non-positive @n@ is a no-op.
memXor :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> Int -> IO ()
memXor d s1 s2 n
    -- Guard with '<=' instead of matching on 0: a negative count would
    -- otherwise recurse forever, walking off the end of every buffer.
    | n <= 0    = return ()
    | otherwise = do
        (xor <$> peek s1 <*> peek s2) >>= poke d
        memXor (d `plusPtr` 1) (s1 `plusPtr` 1) (s2 `plusPtr` 1) (n - 1)

-- | XOR every byte of a source buffer with a fixed value into a destination.
--
-- For every @i@ in @[0, bytes)@: @destination[i] = source[i] `xor` v@.
--
-- When @destination == source@ the buffer is updated in place. A
-- non-positive @bytes@ is a no-op.
memXorWith :: Ptr Word8 -> Word8 -> Ptr Word8 -> Int -> IO ()
memXorWith destination !v source bytes
    | destination == source = loopInplace source bytes
    | otherwise             = loop destination source bytes
  where
    -- Copying variant: read from @s@, write to @d@.
    -- The '<=' guard makes a negative count terminate instead of looping
    -- forever past the end of both buffers.
    loop :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
    loop !d !s !n
        | n <= 0    = return ()
        | otherwise = do
            peek s >>= poke d . xor v
            loop (d `plusPtr` 1) (s `plusPtr` 1) (n - 1)

    -- In-place variant: read and write through the same pointer.
    loopInplace :: Ptr Word8 -> Int -> IO ()
    loopInplace !s !n
        | n <= 0    = return ()
        | otherwise = do
            peek s >>= poke s . xor v
            loopInplace (s `plusPtr` 1) (n - 1)

-- | Copy @n@ bytes from @src@ to @dst@.
--
-- Thin wrapper over C's @memcpy@, so the two regions must not overlap.
-- A non-positive @n@ is a no-op. Without the guard, a negative count would
-- be converted to a huge 'CSize' and overrun both buffers. Skipping the call
-- for zero bytes also avoids handing @memcpy@ possibly-null pointers.
memCopy :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
memCopy dst src n
    | n <= 0    = return ()
    | otherwise = c_memcpy dst src (fromIntegral n)
{-# INLINE memCopy #-}

-- | Fill @n@ bytes starting at @start@ with the byte value @v@.
--
-- Delegates to C's @memset@. A non-positive @n@ is a no-op. Without the
-- guard, a negative count would wrap to a huge 'CSize' and write far past
-- the buffer.
memSet :: Ptr Word8 -> Word8 -> Int -> IO ()
memSet start v n
    | n <= 0    = return ()
    | otherwise = c_memset start v (fromIntegral n)
{-# INLINE memSet #-}

-- | Write the first @n@ bytes of @s@ into @d@ in reverse order, so that
-- @d[i] = s[n - 1 - i]@ for every @i@ in @[0, n)@.
--
-- The source and destination regions should not overlap. A non-positive
-- @n@ does nothing.
memReverse :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
memReverse d s n = go 0
  where
    -- Walk the destination forwards while reading the source backwards.
    go :: Int -> IO ()
    go i
        | i >= n    = return ()
        | otherwise = do
            byte <- peekByteOff s (n - 1 - i) :: IO Word8
            poke (d `plusPtr` i) byte
            go (i + 1)

-- | Check whether the first @n@ bytes of two buffers are equal.
--
-- Stops at the first differing byte, so the running time depends on the
-- data. Use 'memConstEqual' when that matters. A non-positive @n@ compares
-- nothing and yields 'True'.
memEqual :: Ptr Word8 -> Ptr Word8 -> Int -> IO Bool
memEqual p1 p2 n = loop 0
  where
    loop :: Int -> IO Bool
    loop i
        -- '>=' rather than '==': with a negative @n@ an equality test would
        -- never fire and the loop would read without bound.
        | i >= n    = return True
        | otherwise = do
            e <- (==) <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8)
            if e then loop (i + 1) else return False

-- | Lexicographically compare the first @n@ bytes of two buffers, treating
-- each byte as an unsigned 'Word8'.
--
-- Returns the ordering of the first differing byte pair, or 'EQ' if all
-- @n@ bytes match. A non-positive @n@ compares nothing and yields 'EQ'.
memCompare :: Ptr Word8 -> Ptr Word8 -> Int -> IO Ordering
memCompare p1 p2 n = loop 0
  where
    loop :: Int -> IO Ordering
    loop i
        -- '>=' so a negative @n@ terminates instead of reading without bound.
        | i >= n    = return EQ
        | otherwise = do
            e <- compare <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8)
            if e == EQ then loop (i + 1) else return e

-- | Constant-time equality test for the first @n@ bytes of two buffers.
--
-- Unlike 'memEqual', this never exits early. It always reads all @n@ bytes
-- of both buffers, ORs together the XOR of each byte pair, and inspects the
-- accumulated value only at the end. The amount of work therefore does not
-- depend on where, or whether, the buffers differ. A non-positive @n@
-- yields 'True'.
memConstEqual :: Ptr Word8 -> Ptr Word8 -> Int -> IO Bool
memConstEqual p1 p2 n = loop 0 0
  where
    loop :: Int -> Word8 -> IO Bool
    loop i !acc
        -- '>=' so a negative @n@ terminates instead of reading without bound.
        | i >= n    = return $! acc == 0
        | otherwise = do
            e <- xor <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8)
            loop (i + 1) (acc .|. e)

-- | Direct binding to the C library's @memset@.
--
-- Imported as @unsafe@: a cheaper call that must not call back into Haskell.
-- The C function's @void*@ return value is discarded (declared as @IO ()@).
-- NOTE(review): C's @memset@ takes its fill value as an @int@; this binding
-- passes a 'Word8' and relies on the C calling convention to widen it
-- (memset only uses the low byte). Confirm on the targeted platforms, or
-- consider 'CInt'.
foreign import ccall unsafe "memset"
    c_memset :: Ptr Word8 -> Word8 -> CSize -> IO ()

-- | Direct binding to the C library's @memcpy@.
--
-- Imported as @unsafe@; the @void*@ return value is discarded. Per the C
-- standard, the source and destination regions must not overlap.
foreign import ccall unsafe "memcpy"
    c_memcpy :: Ptr Word8 -> Ptr Word8 -> CSize -> IO ()