module Network.WebSockets.Hybi13
( headerVersions
, finishRequest
, finishResponse
, encodeMessages
, decodeMessages
, createRequest
, encodeFrame
) where
import qualified Blaze.ByteString.Builder as B
import Control.Applicative (pure, (<$>))
import Control.Exception (throw)
import Control.Monad (liftM)
import Data.Attoparsec (anyWord8)
import qualified Data.Attoparsec as A
import Data.Binary.Get (getWord16be,
getWord64be, runGet)
import Data.Bits ((.&.), (.|.))
import Data.ByteString (ByteString)
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 ()
import qualified Data.ByteString.Lazy as BL
import Data.Digest.Pure.SHA (bytestringDigest, sha1)
import Data.Int (Int64)
import Data.IORef
import Data.Monoid (mappend, mconcat,
mempty)
import Data.Tuple (swap)
import System.Entropy as R
import qualified System.IO.Streams as Streams
import qualified System.IO.Streams.Attoparsec as Streams
import System.Random (RandomGen, newStdGen)
import Network.WebSockets.Http
import Network.WebSockets.Hybi13.Demultiplex
import Network.WebSockets.Hybi13.Mask
import Network.WebSockets.Types
-- | Values of @Sec-WebSocket-Version@ handled by this module (only
-- protocol version 13).
headerVersions :: [ByteString]
headerVersions = "13" : []
-- | Build the server's 101 handshake reply for a client request.
--
-- The @Sec-WebSocket-Accept@ value is the Base64 encoding of 'hashKey'
-- applied to the client's @Sec-WebSocket-Key@ header.
finishRequest :: RequestHead
              -> Response
finishRequest reqHttp =
    acceptValue `seq` response101 [("Sec-WebSocket-Accept", acceptValue)] ""
  where
    clientKey   = getRequestHeader reqHttp "Sec-WebSocket-Key"
    digest      = hashKey clientKey
    -- Force the key, then the digest, then the encoding, before the response
    -- is built, so any failure in the header lookup surfaces here.
    acceptValue = clientKey `seq` digest `seq` B64.encode digest
-- | Validate the server's handshake response against the request we sent.
--
-- Throws 'MalformedResponse' when the status code is not 101, or when the
-- @Sec-WebSocket-Accept@ header differs from the value expected for our
-- @Sec-WebSocket-Key@. Otherwise returns the response with an empty body.
finishResponse :: RequestHead
               -> ResponseHead
               -> Response
finishResponse request response
    | responseCode response /= 101 =
        malformed "Wrong response status or message."
    | acceptHeader /= expectedAccept =
        malformed "Challenge and response hashes do not match."
    | otherwise =
        Response response ""
  where
    malformed      = throw . MalformedResponse response
    sentKey        = getRequestHeader request "Sec-WebSocket-Key"
    acceptHeader   = getResponseHeader response "Sec-WebSocket-Accept"
    expectedAccept = B64.encode (hashKey sentKey)
-- | Encode one message as a single unfragmented frame followed by a flush.
--
-- Client connections mask the frame with a key from 'randomMask', which
-- also yields the updated generator. Server connections send unmasked
-- frames and return the generator unchanged.
encodeMessage :: RandomGen g => ConnectionType -> g -> Message -> (g, B.Builder)
encodeMessage conType gen msg =
    (gen', encodeFrame mask frame `mappend` B.flush)
  where
    (mask, gen') = case conType of
        ServerConnection -> (Nothing, gen)
        ClientConnection -> randomMask gen

    -- FIN set, all RSV bits clear.
    frame = uncurry (Frame True False False False) (kindAndPayload msg)

    kindAndPayload (ControlMessage (Close pl)) = (CloseFrame, pl)
    kindAndPayload (ControlMessage (Ping pl))  = (PingFrame, pl)
    kindAndPayload (ControlMessage (Pong pl))  = (PongFrame, pl)
    kindAndPayload (DataMessage (Text pl))     = (TextFrame, pl)
    kindAndPayload (DataMessage (Binary pl))   = (BinaryFrame, pl)
-- | Wrap a 'B.Builder' output stream as a stream of 'Message's.
--
-- Each message is encoded with 'encodeMessage', using a random generator
-- kept in an 'IORef'. The encoded builder is then written to @bStream@.
-- The returned stream is wrapped with 'Streams.lockingOutputStream'.
-- An end-of-stream ('Nothing') write is ignored and not forwarded.
encodeMessages :: ConnectionType
               -> Streams.OutputStream B.Builder
               -> IO (Streams.OutputStream Message)
encodeMessages conType bStream = do
    genRef <- newStdGen >>= newIORef
    rawOut <- Streams.makeOutputStream (push genRef)
    Streams.lockingOutputStream rawOut
  where
    push :: RandomGen g => IORef g -> Maybe Message -> IO ()
    push genRef = maybe (return ()) $ \msg -> do
        builder <- atomicModifyIORef genRef (\g -> encodeMessage conType g msg)
        Streams.write (Just builder) bStream
-- | Serialise a 'Frame' in the RFC 6455 wire format.
--
-- Output layout:
--
-- * the FIN/RSV/opcode byte;
-- * the MASK bit plus the 7-bit length;
-- * an optional 16- or 64-bit big-endian extended length;
-- * the masking key, if one is given;
-- * the payload, passed through 'maskPayload'.
encodeFrame :: Mask -> Frame -> B.Builder
encodeFrame mask f = mconcat
    [ B.fromWord8 header0
    , B.fromWord8 header1
    , extendedLength
    , maskKey
    , B.fromLazyByteString (maskPayload mask payload)
    ]
  where
    payload = framePayload f

    -- Yield bitValue when the flag is set, otherwise zero.
    flagBit isSet bitValue = if isSet then bitValue else 0x00

    header0 = flagBit (frameFin f)  0x80
          .|. flagBit (frameRsv1 f) 0x40
          .|. flagBit (frameRsv2 f) 0x20
          .|. flagBit (frameRsv3 f) 0x10
          .|. opcodeOf (frameType f)

    opcodeOf ContinuationFrame = 0x00
    opcodeOf TextFrame         = 0x01
    opcodeOf BinaryFrame       = 0x02
    opcodeOf CloseFrame        = 0x08
    opcodeOf PingFrame         = 0x09
    opcodeOf PongFrame         = 0x0a

    (maskBit, maskKey) =
        maybe (0x00, mempty) (\m -> (0x80, B.fromByteString m)) mask

    header1 = maskBit .|. shortLength

    -- Lengths below 126 fit in the 7-bit field. Otherwise 126 or 127
    -- signals a 16-bit or 64-bit extended length.
    payloadLen = BL.length payload
    (shortLength, extendedLength)
        | payloadLen < 126     = (fromIntegral payloadLen, mempty)
        | payloadLen < 0x10000 = (126, B.fromWord16be (fromIntegral payloadLen))
        | otherwise            = (127, B.fromWord64be (fromIntegral payloadLen))
-- | Turn a raw byte stream into a stream of decoded 'Message's.
--
-- Frames are parsed one at a time with 'parseFrame' and fed to
-- 'demultiplex', whose state is kept in an 'IORef'. Frames that do not
-- complete a message are absorbed, and reading continues until one does.
-- The resulting stream never yields 'Nothing' itself. Failures in
-- 'Streams.parseFromStream' propagate to the reader.
decodeMessages :: Streams.InputStream ByteString
               -> IO (Streams.InputStream Message)
decodeMessages bsStream =
    newIORef emptyDemultiplexState >>= Streams.makeInputStream . pull
  where
    pull stateRef = loop
      where
        loop = do
            frame  <- Streams.parseFromStream parseFrame bsStream
            result <- atomicModifyIORef stateRef (swap . flip demultiplex frame)
            case result of
                Nothing  -> loop
                Just msg -> return (Just msg)
-- | Attoparsec parser for a single RFC 6455 frame.
--
-- The parser reads, in order:
--
-- * the FIN/RSV1-3/opcode byte;
-- * the MASK/payload-length byte;
-- * the optional 16- or 64-bit extended length;
-- * the optional 4-byte masking key;
-- * the payload, unmasked with 'maskPayload' before it is stored in the
--   resulting 'Frame'.
parseFrame :: A.Parser Frame
parseFrame = do
    byte0 <- anyWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10
        opcode = byte0 .&. 0x0f

    let ft = case opcode of
            0x00 -> ContinuationFrame
            0x01 -> TextFrame
            0x02 -> BinaryFrame
            0x08 -> CloseFrame
            0x09 -> PingFrame
            0x0a -> PongFrame
            -- NOTE(review): this raises an ErrorCall (once ft is forced)
            -- instead of failing the parser via 'fail'. Kept so callers see
            -- the same exception type.
            _    -> error "Unknown opcode"

    byte1 <- anyWord8
    let mask    = byte1 .&. 0x80 == 0x80
        lenflag = fromIntegral (byte1 .&. 0x7f)

    -- 126 and 127 are escape values: the real length follows as a
    -- big-endian 16-bit or 64-bit integer respectively.
    len <- case lenflag of
        126 -> fromIntegral . runGet' getWord16be <$> A.take 2
        127 -> fromIntegral . runGet' getWord64be <$> A.take 8
        _   -> return lenflag

    masker <- maskPayload <$> if mask then Just <$> A.take 4 else pure Nothing

    chunks <- take64 len
    return $ Frame fin rsv1 rsv2 rsv3 ft (masker $ BL.fromChunks chunks)
  where
    runGet' g = runGet g . BL.fromChunks . return

    -- 'A.take' takes an Int, but frame lengths are 64-bit. Read the payload
    -- in pieces of at most maxBound :: Int bytes, recursing on whatever
    -- remains.
    take64 :: Int64 -> A.Parser [ByteString]
    take64 n
        | n <= 0    = return []
        | otherwise = do
            let n' = min intMax n
            chunk <- A.take (fromIntegral n')
            (chunk :) <$> take64 (n - n')
      where
        intMax :: Int64
        intMax = fromIntegral (maxBound :: Int)
-- | SHA-1 digest of a handshake key concatenated with the fixed WebSocket
-- GUID, returned as raw (not Base64-encoded) bytes.
hashKey :: ByteString -> ByteString
hashKey key = strictify digest
  where
    digest    = bytestringDigest . sha1 . lazify $ key `mappend` guid
    guid      = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    lazify    = BL.fromChunks . (: [])
    strictify = mconcat . BL.toChunks
-- | Build a client handshake request.
--
-- A fresh @Sec-WebSocket-Key@ is generated from 16 bytes of 'getEntropy'
-- output, Base64-encoded. The standard upgrade headers come first,
-- followed by @customHeaders@.
createRequest :: ByteString
              -> ByteString
              -> Bool
              -> Headers
              -> IO RequestHead
createRequest hostname path secure customHeaders =
    liftM buildHead (getEntropy 16)
  where
    buildHead nonce =
        RequestHead path (standardHeaders (B64.encode nonce) ++ customHeaders) secure

    standardHeaders wsKey =
        [ ("Host"                 , hostname     )
        , ("Connection"           , "Upgrade"    )
        , ("Upgrade"              , "websocket"  )
        , ("Sec-WebSocket-Key"    , wsKey        )
        , ("Sec-WebSocket-Version", versionNumber)
        ]

    versionNumber = head headerVersions