6060#include "aes_icm_ext.h"
6161#endif
6262
63+ #include <stddef.h>
64+ #include <string.h>
6365#include <limits.h>
6466#ifdef HAVE_NETINET_IN_H
6567#include <netinet/in.h>
6668#elif defined(HAVE_WINSOCK2_H )
6769#include <winsock2.h>
6870#endif
6971
72+ #if defined(__SSE2__ )
73+ #include <emmintrin.h>
74+ #if defined(_MSC_VER )
75+ #include <intrin.h>
76+ #endif
77+ #endif
78+
7079/* the debug module for srtp */
7180srtp_debug_module_t mod_srtp = {
7281 0 , /* debugging is off by default */
@@ -79,6 +88,17 @@ srtp_debug_module_t mod_srtp = {
7988#define uint32s_in_rtcp_header 2
8089#define octets_in_rtp_extn_hdr 4
8190
91+ #ifndef SRTP_NO_STREAM_LIST
92+ static inline uint32_t srtp_stream_list_size (srtp_stream_list_t list );
93+ static srtp_err_status_t srtp_stream_list_reserve (srtp_stream_list_t list ,
94+ uint32_t new_capacity );
95+ static uint32_t srtp_stream_list_find (srtp_stream_list_t list ,
96+ uint32_t ssrc );
97+ static inline srtp_stream_t srtp_stream_list_get_at (srtp_stream_list_t list ,
98+ uint32_t pos );
99+ static void srtp_stream_list_remove_at (srtp_stream_list_t list , uint32_t pos );
100+ #endif // SRTP_NO_STREAM_LIST
101+
82102static srtp_err_status_t srtp_validate_rtp_header (void * rtp_hdr ,
83103 int * pkt_octet_len )
84104{
@@ -3030,18 +3050,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc)
30303050{
30313051 srtp_stream_ctx_t * stream ;
30323052 srtp_err_status_t status ;
3053+ #if !defined(SRTP_NO_STREAM_LIST )
3054+ uint32_t pos ;
3055+ #endif
30333056
30343057 /* sanity check arguments */
3035- if (session == NULL )
3058+ if (session == NULL ) {
30363059 return srtp_err_status_bad_param ;
3060+ }
30373061
30383062 /* find and remove stream from the list */
3063+ #if !defined(SRTP_NO_STREAM_LIST )
3064+ pos = srtp_stream_list_find (session -> stream_list , ssrc );
3065+ if (pos >= srtp_stream_list_size (session -> stream_list ))
3066+ return srtp_err_status_no_ctx ;
3067+
3068+ stream = srtp_stream_list_get_at (session -> stream_list , pos );
3069+ srtp_stream_list_remove_at (session -> stream_list , pos );
3070+ #else
30393071 stream = srtp_stream_list_get (session -> stream_list , ssrc );
30403072 if (stream == NULL ) {
30413073 return srtp_err_status_no_ctx ;
30423074 }
30433075
30443076 srtp_stream_list_remove (session -> stream_list , stream );
3077+ #endif
30453078
30463079 /* deallocate the stream */
30473080 status = srtp_stream_dealloc (stream , session -> stream_template );
@@ -4840,11 +4873,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session,
48404873
48414874#ifndef SRTP_NO_STREAM_LIST
48424875
4843- /* in the default implementation, we have an intrusive doubly-linked list */
48444876typedef struct srtp_stream_list_ctx_t_ {
4845- /* a stub stream that just holds pointers to the beginning and end of the
4846- * list */
4847- srtp_stream_ctx_t data ;
4877+ uint32_t * ssrcs ;
4878+ srtp_stream_ctx_t * * streams ;
4879+ uint32_t size ;
4880+ uint32_t capacity ;
48484881} srtp_stream_list_ctx_t_ ;
48494882
48504883srtp_err_status_t srtp_stream_list_alloc (srtp_stream_list_t * list_ptr )
@@ -4855,73 +4888,204 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
48554888 return srtp_err_status_alloc_fail ;
48564889 }
48574890
4858- list -> data .next = NULL ;
4859- list -> data .prev = NULL ;
4860-
48614891 * list_ptr = list ;
48624892 return srtp_err_status_ok ;
48634893}
48644894
48654895srtp_err_status_t srtp_stream_list_dealloc (srtp_stream_list_t list )
48664896{
48674897 /* list must be empty */
4868- if (list -> data . next ) {
4898+ if (list -> size != 0u ) {
48694899 return srtp_err_status_fail ;
48704900 }
4901+ srtp_crypto_free (list -> streams );
4902+ srtp_crypto_free (list -> ssrcs );
48714903 srtp_crypto_free (list );
48724904 return srtp_err_status_ok ;
48734905}
48744906
4907+ static inline uint32_t srtp_stream_list_size (srtp_stream_list_t list )
4908+ {
4909+ return list -> size ;
4910+ }
4911+
4912+ static srtp_err_status_t srtp_stream_list_reserve (srtp_stream_list_t list ,
4913+ uint32_t new_capacity )
4914+ {
4915+ if (new_capacity > list -> capacity ) {
4916+ uint32_t * ssrcs ;
4917+ srtp_stream_ctx_t * * stream_ptrs ;
4918+
4919+ if (new_capacity > (UINT32_MAX - 15u ))
4920+ return srtp_err_status_alloc_fail ;
4921+
4922+ new_capacity = (new_capacity + 15u ) & ~((uint32_t )15u );
4923+
4924+ ssrcs = (uint32_t * )srtp_crypto_alloc ((size_t )new_capacity *
4925+ sizeof (uint32_t ));
4926+ if (!ssrcs )
4927+ return srtp_err_status_alloc_fail ;
4928+ stream_ptrs = (srtp_stream_ctx_t * * )srtp_crypto_alloc (
4929+ (size_t )new_capacity * sizeof (srtp_stream_ctx_t * ));
4930+ if (!stream_ptrs ) {
4931+ srtp_crypto_free (ssrcs );
4932+ return srtp_err_status_alloc_fail ;
4933+ }
4934+
4935+ if (list -> size > 0u ) {
4936+ memcpy (ssrcs , list -> ssrcs , (size_t )list -> size * sizeof (uint32_t ));
4937+ memcpy (stream_ptrs , list -> streams ,
4938+ (size_t )list -> size * sizeof (srtp_stream_ctx_t * ));
4939+ }
4940+
4941+ srtp_crypto_free (list -> ssrcs );
4942+ srtp_crypto_free (list -> streams );
4943+ list -> streams = stream_ptrs ;
4944+ list -> ssrcs = ssrcs ;
4945+
4946+ list -> capacity = new_capacity ;
4947+ }
4948+
4949+ return srtp_err_status_ok ;
4950+ }
4951+
48754952srtp_err_status_t srtp_stream_list_insert (srtp_stream_list_t list ,
48764953 srtp_stream_t stream )
48774954{
4878- /* insert at the head of the list */
4879- stream -> next = list -> data . next ;
4880- if (stream -> next != NULL ) {
4881- stream -> next -> prev = stream ;
4882- }
4883- list -> data . next = stream ;
4884- stream -> prev = & ( list -> data ) ;
4955+ uint32_t pos ;
4956+ srtp_err_status_t status = srtp_stream_list_reserve ( list , list -> size + 1u ) ;
4957+ if (status )
4958+ return status ;
4959+ pos = list -> size ++ ;
4960+ list -> ssrcs [ pos ] = stream -> ssrc ;
4961+ list -> streams [ pos ] = stream ;
48854962
48864963 return srtp_err_status_ok ;
48874964}
48884965
4889- srtp_stream_t srtp_stream_list_get (srtp_stream_list_t list , uint32_t ssrc )
4966+ static uint32_t srtp_stream_list_find (srtp_stream_list_t list , uint32_t ssrc )
48904967{
4891- /* walk down list until ssrc is found */
4892- srtp_stream_t stream = list -> data .next ;
4893- while (stream != NULL ) {
4894- if (stream -> ssrc == ssrc ) {
4895- return stream ;
4968+ #if defined(__SSE2__ )
4969+ const uint32_t * const ssrcs = list -> ssrcs ;
4970+ const __m128i mm_ssrc = _mm_set1_epi32 (ssrc );
4971+ uint32_t pos = 0u , n = (list -> size + 7u ) & ~(uint32_t )(7u );
4972+ for (uint32_t m = n & ~(uint32_t )(15u ); pos < m ; pos += 16u ) {
4973+ __m128i mm1 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos ));
4974+ __m128i mm2 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 4u ));
4975+ __m128i mm3 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 8u ));
4976+ __m128i mm4 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 12u ));
4977+ mm1 = _mm_cmpeq_epi32 (mm1 , mm_ssrc );
4978+ mm2 = _mm_cmpeq_epi32 (mm2 , mm_ssrc );
4979+ mm3 = _mm_cmpeq_epi32 (mm3 , mm_ssrc );
4980+ mm4 = _mm_cmpeq_epi32 (mm4 , mm_ssrc );
4981+ mm1 = _mm_packs_epi32 (mm1 , mm2 );
4982+ mm3 = _mm_packs_epi32 (mm3 , mm4 );
4983+ mm1 = _mm_packs_epi16 (mm1 , mm3 );
4984+ uint32_t mask = _mm_movemask_epi8 (mm1 );
4985+ if (mask ) {
4986+ #if defined(_MSC_VER )
4987+ unsigned long bit_pos ;
4988+ _BitScanForward (& bit_pos , mask );
4989+ pos += bit_pos ;
4990+ #else
4991+ pos += __builtin_ctz (mask );
4992+ #endif
4993+
4994+ goto done ;
4995+ }
4996+ }
4997+
4998+ if (pos < n ) {
4999+ __m128i mm1 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos ));
5000+ __m128i mm2 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 4u ));
5001+ mm1 = _mm_cmpeq_epi32 (mm1 , mm_ssrc );
5002+ mm2 = _mm_cmpeq_epi32 (mm2 , mm_ssrc );
5003+ mm1 = _mm_packs_epi32 (mm1 , mm2 );
5004+
5005+ uint32_t mask = _mm_movemask_epi8 (mm1 );
5006+ if (mask ) {
5007+ #if defined(_MSC_VER )
5008+ unsigned long bit_pos ;
5009+ _BitScanForward (& bit_pos , mask );
5010+ pos += bit_pos / 2u ;
5011+ #else
5012+ pos += __builtin_ctz (mask ) / 2u ;
5013+ #endif
5014+ goto done ;
48965015 }
4897- stream = stream -> next ;
5016+
5017+ pos += 8u ;
5018+ }
5019+
5020+ done :
5021+ return pos ;
5022+ #else
5023+ /* walk down list until ssrc is found */
5024+ uint32_t pos = 0u , n = list -> size ;
5025+ for (; pos < n ; ++ pos ) {
5026+ if (list -> ssrcs [pos ] == ssrc )
5027+ break ;
48985028 }
48995029
5030+ return pos ;
5031+ #endif
5032+ }
5033+
5034+ static inline srtp_stream_t srtp_stream_list_get_at (srtp_stream_list_t list ,
5035+ uint32_t pos )
5036+ {
5037+ return list -> streams [pos ];
5038+ }
5039+
5040+ srtp_stream_t srtp_stream_list_get (srtp_stream_list_t list , uint32_t ssrc )
5041+ {
5042+ uint32_t pos = srtp_stream_list_find (list , ssrc );
5043+ if (pos < list -> size )
5044+ return list -> streams [pos ];
5045+
49005046 /* we haven't found our ssrc, so return a null */
49015047 return NULL ;
49025048}
49035049
4904- void srtp_stream_list_remove (srtp_stream_list_t list ,
4905- srtp_stream_t stream_to_remove )
5050+ static void srtp_stream_list_remove_at (srtp_stream_list_t list , uint32_t pos )
49065051{
4907- ( void ) list ;
5052+ uint32_t tail_size , last_pos ;
49085053
4909- stream_to_remove -> prev -> next = stream_to_remove -> next ;
4910- if (stream_to_remove -> next != NULL ) {
4911- stream_to_remove -> next -> prev = stream_to_remove -> prev ;
5054+ last_pos = -- list -> size ;
5055+ tail_size = last_pos - pos ;
5056+ if (tail_size > 0u ) {
5057+ memmove (list -> streams + pos , list -> streams + pos + 1 ,
5058+ (size_t )tail_size * sizeof (* list -> streams ));
5059+ memmove (list -> ssrcs + pos , list -> ssrcs + pos + 1 ,
5060+ (size_t )tail_size * sizeof (* list -> ssrcs ));
49125061 }
5062+
5063+ list -> streams [last_pos ] = NULL ;
5064+ list -> ssrcs [last_pos ] = 0u ;
5065+ }
5066+
5067+ void srtp_stream_list_remove (srtp_stream_list_t list ,
5068+ srtp_stream_t stream_to_remove )
5069+ {
5070+ uint32_t pos = srtp_stream_list_find (list , stream_to_remove -> ssrc );
5071+ if (pos < list -> size )
5072+ srtp_stream_list_remove_at (list , pos );
49135073}
49145074
49155075void srtp_stream_list_for_each (srtp_stream_list_t list ,
49165076 int (* callback )(srtp_stream_t , void * ),
49175077 void * data )
49185078{
4919- srtp_stream_t stream = list -> data .next ;
4920- while (stream != NULL ) {
4921- srtp_stream_t tmp = stream ;
4922- stream = stream -> next ;
4923- if (callback (tmp , data ))
5079+ uint32_t size = list -> size ;
5080+ for (uint32_t i = 0u ; i < size ;) {
5081+ if (callback (list -> streams [i ], data ))
49245082 break ;
5083+
5084+ /* check if the callback removed the current element */
5085+ if (size == list -> size )
5086+ ++ i ;
5087+ else
5088+ size = list -> size ;
49255089 }
49265090}
49275091
0 commit comments