{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Network.Connection
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : portable
--
-- Simple connection abstraction
--
module Network.Connection
    (
    -- * Type for a connection
      Connection
    , connectionID
    , ConnectionParams(..)
    , TLSSettings(..)
    , ProxySettings(..)
    , SockSettings

    -- * Exceptions
    , LineTooLong(..)
    , HostNotResolved(..)
    , HostCannotConnect(..)

    -- * Library initialization
    , initConnectionContext
    , ConnectionContext

    -- * Connection operation
    , connectFromHandle
    , connectFromSocket
    , connectTo
    , connectionClose

    -- * Sending and receiving data
    , connectionGet
    , connectionGetExact
    , connectionGetChunk
    , connectionGetChunk'
    , connectionGetLine
    , connectionWaitForInput
    , connectionPut

    -- * TLS related operation
    , connectionSetSecure
    , connectionIsSecure
    , connectionSessionManager
    ) where

import Control.Concurrent.MVar
import Control.Monad (join)
import qualified Control.Exception as E
import qualified System.IO.Error as E (mkIOError, eofErrorType)

import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLS

import System.X509 (getSystemCertificateStore)

import Network.Socks5 (defaultSocksConf, socksConnectWithSocket, SocksAddress(..), SocksHostAddress(..))
import Network.Socket
import qualified Network.Socket.ByteString as N

import Data.Tuple (swap)
import Data.Default.Class
import Data.Data
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as L

import System.Environment
import System.Timeout
import System.IO
import qualified Data.Map as M

import Network.Connection.Types

-- | Mutable map from TLS session ID to session data, guarded by an 'MVar'.
-- This is the state that 'connectionSessionManager' reads and updates.
type Manager = MVar (M.Map TLS.SessionID TLS.SessionData)

-- | Exception raised when the user-specified limit for the length of a line
-- is reached in 'connectionGetLine'.
data LineTooLong = LineTooLong
    deriving (Show, Typeable)

-- | Exception raised when there's no resolution for a specific host.
--
-- The 'String' is the host name that was being resolved.
data HostNotResolved = HostNotResolved String
    deriving (Show, Typeable)

-- | Exception raised when the connect failed.
--
-- Carries the host name and the 'E.IOException's collected from the
-- individual connection attempts.
data HostCannotConnect = HostCannotConnect String [E.IOException]
    deriving (Show, Typeable)

-- Make the error types above usable as exceptions (e.g. with 'E.throwIO');
-- the class's default methods are used as-is.
instance E.Exception LineTooLong
instance E.Exception HostNotResolved
instance E.Exception HostCannotConnect

-- | Build a 'TLS.SessionManager' whose callbacks are backed by the session
-- map held in the given 'Manager'.
--
-- * resume looks the session up and leaves it in the map;
-- * establish inserts (or overwrites) the session under its ID;
-- * invalidate deletes the session;
-- * with tls >= 1.5.0, resume-only-once returns the session and removes it
--   within a single 'modifyMVar' step.
connectionSessionManager :: Manager -> TLS.SessionManager
connectionSessionManager mvar = TLS.SessionManager
    { TLS.sessionResume     = \sessionID -> withMVar mvar (return . M.lookup sessionID)
    , TLS.sessionEstablish  = \sessionID sessionData ->
          -- Force the updated map so that unevaluated inserts do not pile up
          -- inside the MVar.
          modifyMVar_ mvar (\sessions -> return $! M.insert sessionID sessionData sessions)
    , TLS.sessionInvalidate = \sessionID ->
          modifyMVar_ mvar (\sessions -> return $! M.delete sessionID sessions)
#if MIN_VERSION_tls(1,5,0)
    , TLS.sessionResumeOnlyOnce = \sessionID ->
          -- updateLookupWithKey hands back the value it deletes; swap puts the
          -- new map first, as modifyMVar expects.
          modifyMVar mvar (pure . swap . M.updateLookupWithKey (\_ _ -> Nothing) sessionID)
#endif
    }

-- | Initialize the library with the parameters shared between connections.
--
-- This loads the system certificate store and wraps it in a
-- 'ConnectionContext'.
initConnectionContext :: IO ConnectionContext
initConnectionContext = ConnectionContext <$> getSystemCertificateStore

-- | Create the final TLS 'TLS.ClientParams' according to the destination and
-- the 'TLSSettings'.
--
-- * For 'TLSSettingsSimple', default client parameters are built for the
--   destination, using the default cipher suite and the certificate store of
--   the 'ConnectionContext'. When certificate validation is disabled in the
--   settings, a validation cache that accepts every certificate is installed.
-- * For 'TLSSettings', the caller-supplied parameters are kept, and only the
--   server identification is overwritten with the destination.
makeTLSParams :: ConnectionContext -> ConnectionID -> TLSSettings -> TLS.ClientParams
makeTLSParams cg cid settings =
    case settings of
        TLSSettingsSimple {} ->
            (TLS.defaultParamsClient hostName portString)
                { TLS.clientSupported = def { TLS.supportedCiphers = TLS.ciphersuite_default }
                , TLS.clientShared    = def
                    { TLS.sharedCAStore         = globalCertificateStore cg
                    , TLS.sharedValidationCache = validationCache
                    -- , TLS.sharedSessionManager  = connectionSessionManager
                    }
                }
        TLSSettings p ->
            p { TLS.clientServerIdentification = (hostName, portString) }
  where
    hostName   = fst cid
    -- The server identification carries the port as a ByteString.
    portString = BC.pack $ show $ snd cid
    -- Only evaluated for 'TLSSettingsSimple', where the record field exists.
    validationCache
        | settingDisableCertificateValidation settings =
            TLS.ValidationCache (\_ _ _ -> return TLS.ValidationCachePass)
                                (\_ _ _ -> return ())
        | otherwise = def

-- | Read the current backend of a connection and run the given action on it.
--
-- The backend is obtained with 'readMVar', so the MVar is not held while the
-- action runs.
withBackend :: (ConnectionBackend -> IO a) -> Connection -> IO a
withBackend f conn = readMVar (connectionBackend conn) >>= f

-- | Allocate a 'Connection' for the given identifier and backend.
--
-- The backend is stored in a fresh MVar, and the read buffer starts out as
-- @Just B.empty@.
connectionNew :: ConnectionID -> ConnectionBackend -> IO Connection
connectionNew cid backend =
    Connection <$> newMVar backend
               <*> newMVar (Just B.empty)
               <*> pure cid

-- | Use an already established handle to create a connection object.
--
-- If the TLS settings are set, it will do the handshake with the server.
-- The SOCKS settings have no impact here, as the handle is already established.
connectFromHandle :: ConnectionContext
                  -> Handle
                  -> ConnectionParams
                  -> IO Connection
connectFromHandle cg h p = withSecurity (connectionUseSecure p)
    where withSecurity Nothing            = connectionNew cid $ ConnectionStream h
          withSecurity (Just tlsSettings) =
              tlsEstablish h (makeTLSParams cg cid tlsSettings) >>= connectionNew cid . ConnectionTLS
          -- The connection is identified by the (hostname, port) of the parameters.
          cid = (connectionHostname p, connectionPort p)

-- | Use an already established socket to create a connection object.
--
-- If the TLS settings are set, it will do the handshake with the server.
-- The SOCKS settings have no impact here, as the socket is already established.
connectFromSocket :: ConnectionContext
                  -> Socket
                  -> ConnectionParams
                  -> IO Connection
connectFromSocket cg sock p = withSecurity (connectionUseSecure p)
    where withSecurity Nothing            = connectionNew cid $ ConnectionSocket sock
          withSecurity (Just tlsSettings) =
              tlsEstablish sock (makeTLSParams cg cid tlsSettings) >>= connectionNew cid . ConnectionTLS
          -- The connection is identified by the (hostname, port) of the parameters.
          cid = (connectionHostname p, connectionPort p)

-- | Connect to a destination using the given parameters.
--
-- A socket is opened either directly, to the configured proxy, or through a
-- SOCKS server, and is then handed to 'connectFromSocket'. If that last step
-- fails, the socket is closed before the exception propagates.
--
-- Throws 'HostNotResolved' when name resolution yields no address, and
-- 'HostCannotConnect' when every resolved address failed to connect.
connectTo :: ConnectionContext -- ^ The global context of this connection.
          -> ConnectionParams  -- ^ The parameters for this connection (where to connect, and such).
          -> IO Connection     -- ^ The new established connection on success.
connectTo cg cParams = do
    let conFct = doConnect (connectionUseSocks cParams)
                           (connectionHostname cParams)
                           (connectionPort cParams)
    E.bracketOnError conFct (close . fst) $ \(h, _) ->
        connectFromSocket cg h cParams
  where
    -- Connect to the SOCKS server, then ask it to reach the destination.
    -- The SOCKS negotiation runs under its own bracketOnError: it happens
    -- inside the acquire action of the outer bracket, whose release handler
    -- would not run if the negotiation threw, leaving the socket open.
    sockConnect :: String -> PortNumber -> String -> PortNumber -> IO (Socket, SockAddr)
    sockConnect sockHost sockPort h p =
        E.bracketOnError (resolve' sockHost sockPort) (close . fst) $ \(sockServ, servAddr) -> do
            let sockConf = defaultSocksConf servAddr
                destAddr = SocksAddress (SocksAddrDomainName $ BC.pack h) p
            (dest, _) <- socksConnectWithSocket sockServ sockConf destAddr
            case dest of
                SocksAddrIPV4 h4 -> return (sockServ, SockAddrInet p h4)
                SocksAddrIPV6 h6 -> return (sockServ, SockAddrInet6 p 0 h6 0)
                SocksAddrDomainName _ -> error "internal error: socks connect return a resolved address as domain name"

    -- Pick the way to reach the destination according to the proxy settings.
    doConnect :: Maybe ProxySettings -> String -> PortNumber -> IO (Socket, SockAddr)
    doConnect proxy h p =
        case proxy of
            Nothing                 -> resolve' h p
            Just (OtherProxy proxyHost proxyPort) -> resolve' proxyHost proxyPort
            Just (SockSettingsSimple sockHost sockPort) ->
                sockConnect sockHost sockPort h p
            Just (SockSettingsEnvironment envName) -> do
                -- if we can't get the environment variable or that the string cannot be parsed
                -- we connect directly.
                let name = maybe "SOCKS_SERVER" id envName
                evar <- E.try (getEnv name)
                case evar of
                    Left (_ :: E.IOException) -> resolve' h p
                    Right var                 ->
                        case parseSocks var of
                            Nothing                   -> resolve' h p
                            Just (sockHost, sockPort) -> sockConnect sockHost sockPort h p

    -- Try to parse "host:port" or "host"
    -- if port is omitted then the default SOCKS port (1080) is assumed
    parseSocks :: String -> Maybe (String, PortNumber)
    parseSocks s =
        case break (== ':') s of
            (sHost, "")        -> Just (sHost, 1080)
            (sHost, ':':portS) ->
                case reads portS of
                    [(sPort,"")] -> Just (sHost, sPort)
                    _            -> Nothing
            _                  -> Nothing

    -- Try to resolve the host/port into an address (zero to many of them), then
    -- try to connect from the first address to the last, returning the first one that
    -- succeed
    resolve' :: String -> PortNumber -> IO (Socket, SockAddr)
    resolve' host port = do
        let hints = defaultHints { addrFlags = [AI_ADDRCONFIG], addrSocketType = Stream }
        addrs <- getAddrInfo (Just hints) (Just host) (Just $ show port)
        firstSuccessful $ map tryToConnect addrs
      where
        -- Open a socket for one address; it is closed again if connect fails.
        tryToConnect :: AddrInfo -> IO (Socket, SockAddr)
        tryToConnect addr =
            E.bracketOnError
                (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
                close
                (\sock -> connect sock (addrAddress addr) >> return (sock, addrAddress addr))

        -- Run the actions in order until one succeeds. IOExceptions of failed
        -- attempts are accumulated most-recent-first.
        firstSuccessful :: [IO a] -> IO a
        firstSuccessful = go []
          where
            go :: [E.IOException] -> [IO a] -> IO a
            -- no attempt was made at all: resolution returned no address
            go []      [] = E.throwIO $ HostNotResolved host
            -- every attempt failed
            go l@(_:_) [] = E.throwIO $ HostCannotConnect host l
            go acc     (act:followingActs) = do
                er <- E.try act
                case er of
                    Left err -> go (err:acc) followingActs
                    Right r  -> return r

-- | Put a block of data in the connection.
--
-- The write is dispatched on the connection's backend: a handle is written
-- to and then flushed, a socket uses 'N.sendAll', and a TLS context is given
-- the data as a single-chunk lazy ByteString.
connectionPut :: Connection -> ByteString -> IO ()
connectionPut connection content = withBackend doWrite connection
    where doWrite (ConnectionStream h) = B.hPut h content >> hFlush h
          doWrite (ConnectionSocket s) = N.sendAll s content
          doWrite (ConnectionTLS ctx)  = TLS.sendData ctx $ L.fromChunks [content]

-- | Get exact count of bytes from a connection.
--
-- The size argument is the exact amount that must be returned to the user.
-- The call will wait until all data is available.  Hence, it behaves like
-- 'B.hGet'.
--
-- On end of input, 'connectionGetExact' will throw an 'E.isEOFError'
-- exception.
connectionGetExact :: Connection -> Int -> IO ByteString
connectionGetExact conn x = loop [] 0
  where
    -- Chunks are collected in reverse order and joined once at the end, so
    -- the bytes already received are not re-copied on every read.
    loop chunks y
      | y == x    = return (B.concat (reverse chunks))
      | otherwise = do
          next <- connectionGet conn (x - y)
          loop (next : chunks) (y + B.length next)

-- | Get some bytes from a connection.
--
-- The size argument is just the maximum that could be returned to the user.
-- The call will return as soon as there's data, even if there's less
-- than requested.  Hence, it behaves like 'B.hGetSome'.
--
-- A negative size fails with an error; a size of zero returns an empty
-- 'ByteString' without reading.
--
-- On end of input, 'connectionGet' returns an empty 'ByteString', but
-- subsequent calls will throw an 'E.isEOFError' exception.
connectionGet :: Connection -> Int -> IO ByteString
connectionGet conn size
  | size < 0  = fail "Network.Connection.connectionGet: size < 0"
  | size == 0 = return B.empty
    -- Take at most @size@ bytes; B.splitAt leaves the rest for the buffer.
  | otherwise = connectionGetChunkBase "connectionGet" conn $ B.splitAt size

-- | Get the next block of data from the connection.
--
-- The whole available chunk is returned and nothing is given back to the
-- buffer.
connectionGetChunk :: Connection -> IO ByteString
connectionGetChunk conn =
    connectionGetChunkBase "connectionGetChunk" conn $ \s -> (s, B.empty)

-- | Like 'connectionGetChunk', but return the unused portion to the buffer,
-- where it will be the next chunk read.
--
-- The callback receives the chunk and returns the result together with the
-- bytes to put back.
connectionGetChunk' :: Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunk' = connectionGetChunkBase "connectionGetChunk'"

-- | Wait for input to become available on a connection.
--
-- As with 'hWaitForInput', the timeout value is given in milliseconds.  If the
-- timeout value is less than zero, then 'connectionWaitForInput' waits
-- indefinitely.
--
-- Unlike 'hWaitForInput', this function does not do any decoding, so it
-- returns true when there is /any/ available input, not just full characters.
connectionWaitForInput :: Connection -> Int -> IO Bool
connectionWaitForInput conn timeout_ms = maybe False (const True) <$> timeout timeout_us tryGetChunk
  where -- Fetch a chunk and hand all of it back to the buffer, so no input is consumed.
        tryGetChunk = connectionGetChunkBase "connectionWaitForInput" conn $ \buf -> ((), buf)
        -- 'timeout' takes microseconds.
        timeout_us  = timeout_ms * 1000

-- | Common implementation of the chunk-reading functions.
--
-- Runs under the connection buffer's 'MVar'.  When the buffer holds data it is
-- passed to the supplied function; when the buffer is empty one more chunk is
-- fetched from the backend first.  The function returns the caller's result
-- together with the bytes to keep buffered.
--
-- An empty chunk from the backend marks end of input: the function is still
-- applied to it, but the buffer is then set to 'Nothing', and any later call
-- throws an EOF error labelled with the given location string.
connectionGetChunkBase :: String -> Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunkBase loc conn f =
    modifyMVar (connectionBuffer conn) step
  where
    -- Buffer already closed by a previous end of input.
    step Nothing = throwEOF conn loc
    step (Just pending)
      | not (B.null pending) = consume pending
      | otherwise = do
          fresh <- withBackend fetch conn
          if B.null fresh
             then finish fresh
             else consume fresh

    -- Read one chunk from whichever backend the connection currently uses.
    fetch backend = case backend of
        ConnectionTLS tlsctx  -> TLS.recvData tlsctx
        ConnectionSocket sock -> N.recv sock 1500
        ConnectionStream h    -> B.hGetSome h (16 * 1024)

    -- Apply the caller's function and keep its leftover (forced) as the buffer.
    consume bytes = case f bytes of
        (result, !rest) -> return (Just rest, result)

    -- Apply the caller's function one last time and close the buffer.
    finish bytes = case f bytes of
        (result, _rest) -> return (Nothing, result)

-- | Get the next line, using ASCII LF as the line terminator.
--
-- This throws an 'isEOFError' exception on end of input, and LineTooLong when
-- the number of bytes gathered is over the limit without a line terminator.
--
-- The actual line returned can be bigger than the limit specified, provided
-- that the last chunk returned by the underlying backend contains a LF.
-- In other words, the LineTooLong exception is only raised when more input is
-- needed and the limit has already been reached.
--
-- An end of file will be considered as a line terminator too, if line is
-- not empty.  If end of file is reached before any byte of the line has been
-- read, the 'isEOFError' exception is thrown.
connectionGetLine :: Int           -- ^ Maximum number of bytes before raising a LineTooLong exception
                  -> Connection    -- ^ Connection
                  -> IO ByteString -- ^ The received line with the LF trimmed
connectionGetLine limit conn = more (throwEOF conn loc) 0 id
  where
    loc = "connectionGetLine"
    lineTooLong = E.throwIO LineTooLong

    -- Accumulate chunks using a difference list, and concatenate them
    -- when an end-of-line indicator is reached.
    --
    -- 'eofK' is what to do if end of file is hit on the next read: initially
    -- it throws EOF (nothing gathered yet); once at least one chunk has been
    -- accumulated it returns the partial line instead.
    more eofK !currentSz !dl =
        getChunk needMore (\s -> done (dl . (s:))) eofK
      where
        needMore s
          | total > limit = lineTooLong
          | otherwise     = more (done dl') total dl'
          where
            total = currentSz + B.length s
            dl'   = dl . (s:)

    done :: ([ByteString] -> [ByteString]) -> IO ByteString
    done dl = return $! B.concat $ dl []

    -- Get another chunk, and call one of the continuations
    getChunk :: (ByteString -> IO r) -- moreK: need more input
             -> (ByteString -> IO r) -- doneK: end of line (line terminator found)
             -> IO r                 -- eofK:  end of file
             -> IO r
    getChunk moreK doneK eofK =
      join $ connectionGetChunkBase loc conn $ \s ->
        if B.null s
          then (eofK, B.empty)
          else case B.break (== 10) s of
                 (a, b)
                   | B.null b  -> (moreK a, B.empty)
                   -- b starts with the LF: drop it and keep the rest buffered.
                   | otherwise -> (doneK a, B.drop 1 b)

-- | Throw an end-of-file 'E.IOException' for the given connection.
--
-- The error location is the given function name prefixed with the module
-- name, and the error's file name field carries the connection's
-- @host:port@ identifier.
throwEOF :: Connection -> String -> IO a
throwEOF conn loc = E.throwIO eofError
  where
    eofError = E.mkIOError E.eofErrorType qualifiedLoc Nothing (Just endpoint)
    qualifiedLoc = "Network.Connection." ++ loc
    endpoint = case connectionID conn of
        (host, port) -> host ++ ":" ++ show port

-- | Close a connection.
--
-- A TLS backend first attempts to send the closing 'TLS.bye' (an
-- 'E.IOException' raised while doing so is ignored) and then always closes the
-- TLS context.  Socket and handle backends are closed directly.
connectionClose :: Connection -> IO ()
connectionClose conn = withBackend shutdown conn
  where
    shutdown backend = case backend of
        ConnectionTLS ctx     -> quietly (TLS.bye ctx) `E.finally` TLS.contextClose ctx
        ConnectionSocket sock -> close sock
        ConnectionStream h    -> hClose h

    -- Run an action, discarding any IOException it raises.
    quietly action = E.catch action swallow

    swallow :: E.IOException -> IO ()
    swallow _ = return ()

-- | Activate secure layer using the parameters specified.
--
-- This is typically used to negotiate a TLS channel on an already
-- established channel, e.g. supporting a STARTTLS command. It also
-- flushes the received buffer to prevent the application from confusing
-- data received before and after the setSecure call.
--
-- If the connection is already using TLS, nothing else happens.
connectionSetSecure :: ConnectionContext
                    -> Connection
                    -> TLSSettings
                    -> IO ()
connectionSetSecure cg connection params =
    modifyMVar_ (connectionBuffer connection) $ \pending ->
        modifyMVar (connectionBackend connection) $ \backend ->
            case backend of
                -- Already secure: keep both the backend and the buffer.
                ConnectionTLS _    -> return (backend, pending)
                ConnectionStream h -> upgrade h
                ConnectionSocket s -> upgrade s
  where
    -- Perform the TLS handshake over the raw transport, then switch the
    -- backend to TLS and reset the buffer to empty.
    upgrade :: TLS.HasBackend b => b -> IO (ConnectionBackend, Maybe ByteString)
    upgrade raw = do
        ctx <- tlsEstablish raw (makeTLSParams cg (connectionID connection) params)
        return (ConnectionTLS ctx, Just B.empty)

-- | Returns whether the connection is established securely or not, i.e.
-- whether its current backend is a TLS context.
connectionIsSecure :: Connection -> IO Bool
connectionIsSecure conn = withBackend (return . isTLS) conn
  where
    isTLS (ConnectionTLS _)    = True
    isTLS (ConnectionStream _) = False
    isTLS (ConnectionSocket _) = False

-- | Create a TLS context over the given backend with the given client
-- parameters, run the handshake on it, and return the context.
tlsEstablish :: TLS.HasBackend backend => backend -> TLS.ClientParams -> IO TLS.Context
tlsEstablish handle tlsParams =
    TLS.contextNew handle tlsParams >>= \ctx ->
        TLS.handshake ctx >> return ctx