Skip to content

Commit 2ccb4aa

Browse files
Merge pull request #575 from kazu-yamamoto/fix-win-fds
CmsgIdFd -> CmsgIdFds
2 parents 04b1943 + aeab895 commit 2ccb4aa

File tree

3 files changed

+37
-11
lines changed

3 files changed

+37
-11
lines changed

Network/Socket/Win32/Cmsg.hsc

+32-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
21
{-# LANGUAGE AllowAmbiguousTypes #-}
32
{-# LANGUAGE CPP #-}
3+
{-# LANGUAGE FlexibleInstances #-}
44
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
5+
{-# LANGUAGE OverloadedStrings #-}
56
{-# LANGUAGE PatternSynonyms #-}
67
{-# LANGUAGE RecordWildCards #-}
78
{-# LANGUAGE ScopedTypeVariables #-}
@@ -70,8 +71,8 @@ pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO)
7071
-- | Control message ID for POSIX file-descriptor passing.
7172
--
7273
-- Not supported on Windows; use WSADuplicateSocket instead
73-
pattern CmsgIdFd :: CmsgId
74-
pattern CmsgIdFd = CmsgId (-1) (-1)
74+
pattern CmsgIdFds :: CmsgId
75+
pattern CmsgIdFds = CmsgId (-1) (-1)
7576

7677
----------------------------------------------------------------
7778

@@ -91,11 +92,13 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
9192
----------------------------------------------------------------
9293

9394
-- | A class to encode and decode control message.
94-
class Storable a => ControlMessage a where
95+
class ControlMessage a where
9596
controlMessageId :: CmsgId
97+
encodeCmsg :: a -> Cmsg
98+
decodeCmsg :: Cmsg -> Maybe a
9699

97-
encodeCmsg :: forall a. ControlMessage a => a -> Cmsg
98-
encodeCmsg x = unsafeDupablePerformIO $ do
100+
encodeStorableCmsg :: forall a. (ControlMessage a, Storable a) => a -> Cmsg
101+
encodeStorableCmsg x = unsafeDupablePerformIO $ do
99102
bs <- create siz $ \p0 -> do
100103
let p = castPtr p0
101104
poke p x
@@ -104,8 +107,8 @@ encodeCmsg x = unsafeDupablePerformIO $ do
104107
where
105108
siz = sizeOf x
106109

107-
decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
108-
decodeCmsg (Cmsg cmsid (PS fptr off len))
110+
decodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
111+
decodeStorableCmsg (Cmsg cmsid (PS fptr off len))
109112
| cid /= cmsid = Nothing
110113
| len < siz = Nothing
111114
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
@@ -122,6 +125,8 @@ newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable)
122125

123126
instance ControlMessage IPv4TTL where
124127
controlMessageId = CmsgIdIPv4TTL
128+
decodeCmsg = decodeStorableCmsg
129+
encodeCmsg = encodeStorableCmsg
125130

126131
----------------------------------------------------------------
127132

@@ -130,6 +135,8 @@ newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable)
130135

131136
instance ControlMessage IPv6HopLimit where
132137
controlMessageId = CmsgIdIPv6HopLimit
138+
encodeCmsg = encodeStorableCmsg
139+
decodeCmsg = decodeStorableCmsg
133140

134141
----------------------------------------------------------------
135142

@@ -138,6 +145,8 @@ newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable)
138145

139146
instance ControlMessage IPv4TOS where
140147
controlMessageId = CmsgIdIPv4TOS
148+
encodeCmsg = encodeStorableCmsg
149+
decodeCmsg = decodeStorableCmsg
141150

142151
----------------------------------------------------------------
143152

@@ -146,6 +155,8 @@ newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable)
146155

147156
instance ControlMessage IPv6TClass where
148157
controlMessageId = CmsgIdIPv6TClass
158+
encodeCmsg = encodeStorableCmsg
159+
decodeCmsg = decodeStorableCmsg
149160

150161
----------------------------------------------------------------
151162

@@ -158,6 +169,8 @@ instance Show IPv4PktInfo where
158169

159170
instance ControlMessage IPv4PktInfo where
160171
controlMessageId = CmsgIdIPv4PktInfo
172+
encodeCmsg = encodeStorableCmsg
173+
decodeCmsg = decodeStorableCmsg
161174

162175
instance Storable IPv4PktInfo where
163176
sizeOf ~_ = #{size IN_PKTINFO}
@@ -180,6 +193,8 @@ instance Show IPv6PktInfo where
180193

181194
instance ControlMessage IPv6PktInfo where
182195
controlMessageId = CmsgIdIPv6PktInfo
196+
decodeCmsg = decodeStorableCmsg
197+
encodeCmsg = encodeStorableCmsg
183198

184199
instance Storable IPv6PktInfo where
185200
sizeOf ~_ = #{size IN6_PKTINFO}
@@ -192,8 +207,14 @@ instance Storable IPv6PktInfo where
192207
n :: ULONG <- (#peek IN6_PKTINFO, ipi6_ifindex) p
193208
return $ IPv6PktInfo (fromIntegral n) ha6
194209

195-
instance ControlMessage Fd where
196-
controlMessageId = CmsgIdFd
210+
----------------------------------------------------------------
211+
212+
instance ControlMessage [Fd] where
213+
controlMessageId = CmsgIdFds
214+
encodeCmsg = \_ -> Cmsg CmsgIdFds ""
215+
decodeCmsg = \_ -> Just []
216+
217+
----------------------------------------------------------------
197218

198219
cmsgIdBijection :: Bijection CmsgId String
199220
cmsgIdBijection =
@@ -204,7 +225,7 @@ cmsgIdBijection =
204225
, (CmsgIdIPv6TClass, "CmsgIdIPv6TClass")
205226
, (CmsgIdIPv4PktInfo, "CmsgIdIPv4PktInfo")
206227
, (CmsgIdIPv6PktInfo, "CmsgIdIPv6PktInfo")
207-
, (CmsgIdFd, "CmsgIdFd")
228+
, (CmsgIdFds, "CmsgIdFds")
208229
]
209230

210231
instance Show CmsgId where

network.cabal

+3
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,6 @@ test-suite spec
207207

208208
if impl(ghc >=8)
209209
default-extensions: Strict StrictData
210+
211+
if os(windows)
212+
cpp-options: -D_WIN32

tests/Network/SocketSpec.hs

+2
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,11 @@ spec = do
379379
let msgid = CmsgId (-300) (-300) in
380380
show msgid `shouldBe` "CmsgId (-300) (-300)"
381381

382+
#if !defined(_WIN32)
382383
describe "bijective encodeCmsg-decodeCmsg roundtrip equality" $ do
383384
it "holds for [Fd]" $ forAll genFds $
384385
\x -> (decodeCmsg . encodeCmsg $ x) == Just (x :: [Fd])
386+
#endif
385387

386388
describe "bijective read-show roundtrip equality" $ do
387389
it "holds for Family" $ forAll familyGen $

0 commit comments

Comments
 (0)