6060#include "aes_icm_ext.h"
6161#endif
6262
63+ #include <stddef.h>
64+ #include <string.h>
6365#include <limits.h>
6466#ifdef HAVE_NETINET_IN_H
6567#include <netinet/in.h>
6668#elif defined(HAVE_WINSOCK2_H )
6769#include <winsock2.h>
6870#endif
6971
72+ #if defined(__SSE2__ )
73+ #include <emmintrin.h>
74+ #if defined(_MSC_VER )
75+ #include <intrin.h>
76+ #endif
77+ #endif
78+
7079/* the debug module for srtp */
7180srtp_debug_module_t mod_srtp = {
7281 0 , /* debugging is off by default */
@@ -79,6 +88,16 @@ srtp_debug_module_t mod_srtp = {
7988#define uint32s_in_rtcp_header 2
8089#define octets_in_rtp_extn_hdr 4
8190
91+ #ifndef SRTP_NO_STREAM_LIST
92+ static inline uint32_t srtp_stream_list_size (srtp_stream_list_t list );
93+ static srtp_err_status_t srtp_stream_list_reserve (srtp_stream_list_t list ,
94+ uint32_t new_capacity );
95+ static uint32_t srtp_stream_list_find (srtp_stream_list_t list , uint32_t ssrc );
96+ static inline srtp_stream_t srtp_stream_list_get_at (srtp_stream_list_t list ,
97+ uint32_t pos );
98+ static void srtp_stream_list_remove_at (srtp_stream_list_t list , uint32_t pos );
99+ #endif // SRTP_NO_STREAM_LIST
100+
82101static srtp_err_status_t srtp_validate_rtp_header (void * rtp_hdr ,
83102 int * pkt_octet_len )
84103{
@@ -3030,18 +3049,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc)
30303049{
30313050 srtp_stream_ctx_t * stream ;
30323051 srtp_err_status_t status ;
3052+ #if !defined(SRTP_NO_STREAM_LIST )
3053+ uint32_t pos ;
3054+ #endif
30333055
30343056 /* sanity check arguments */
3035- if (session == NULL )
3057+ if (session == NULL ) {
30363058 return srtp_err_status_bad_param ;
3059+ }
30373060
30383061 /* find and remove stream from the list */
3062+ #if !defined(SRTP_NO_STREAM_LIST )
3063+ pos = srtp_stream_list_find (session -> stream_list , ssrc );
3064+ if (pos >= srtp_stream_list_size (session -> stream_list ))
3065+ return srtp_err_status_no_ctx ;
3066+
3067+ stream = srtp_stream_list_get_at (session -> stream_list , pos );
3068+ srtp_stream_list_remove_at (session -> stream_list , pos );
3069+ #else
30393070 stream = srtp_stream_list_get (session -> stream_list , ssrc );
30403071 if (stream == NULL ) {
30413072 return srtp_err_status_no_ctx ;
30423073 }
30433074
30443075 srtp_stream_list_remove (session -> stream_list , stream );
3076+ #endif
30453077
30463078 /* deallocate the stream */
30473079 status = srtp_stream_dealloc (stream , session -> stream_template );
@@ -4840,11 +4872,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session,
48404872
48414873#ifndef SRTP_NO_STREAM_LIST
48424874
4843- /* in the default implementation, we have an intrusive doubly-linked list */
48444875typedef struct srtp_stream_list_ctx_t_ {
4845- /* a stub stream that just holds pointers to the beginning and end of the
4846- * list */
4847- srtp_stream_ctx_t data ;
4876+ uint32_t * ssrcs ;
4877+ srtp_stream_ctx_t * * streams ;
4878+ uint32_t size ;
4879+ uint32_t capacity ;
48484880} srtp_stream_list_ctx_t_ ;
48494881
48504882srtp_err_status_t srtp_stream_list_alloc (srtp_stream_list_t * list_ptr )
@@ -4855,73 +4887,204 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
48554887 return srtp_err_status_alloc_fail ;
48564888 }
48574889
4858- list -> data .next = NULL ;
4859- list -> data .prev = NULL ;
4860-
48614890 * list_ptr = list ;
48624891 return srtp_err_status_ok ;
48634892}
48644893
48654894srtp_err_status_t srtp_stream_list_dealloc (srtp_stream_list_t list )
48664895{
48674896 /* list must be empty */
4868- if (list -> data . next ) {
4897+ if (list -> size != 0u ) {
48694898 return srtp_err_status_fail ;
48704899 }
4900+ srtp_crypto_free (list -> streams );
4901+ srtp_crypto_free (list -> ssrcs );
48714902 srtp_crypto_free (list );
48724903 return srtp_err_status_ok ;
48734904}
48744905
4906+ static inline uint32_t srtp_stream_list_size (srtp_stream_list_t list )
4907+ {
4908+ return list -> size ;
4909+ }
4910+
4911+ static srtp_err_status_t srtp_stream_list_reserve (srtp_stream_list_t list ,
4912+ uint32_t new_capacity )
4913+ {
4914+ if (new_capacity > list -> capacity ) {
4915+ uint32_t * ssrcs ;
4916+ srtp_stream_ctx_t * * stream_ptrs ;
4917+
4918+ if (new_capacity > (UINT32_MAX - 15u ))
4919+ return srtp_err_status_alloc_fail ;
4920+
4921+ new_capacity = (new_capacity + 15u ) & ~((uint32_t )15u );
4922+
4923+ ssrcs = (uint32_t * )srtp_crypto_alloc ((size_t )new_capacity *
4924+ sizeof (uint32_t ));
4925+ if (!ssrcs )
4926+ return srtp_err_status_alloc_fail ;
4927+ stream_ptrs = (srtp_stream_ctx_t * * )srtp_crypto_alloc (
4928+ (size_t )new_capacity * sizeof (srtp_stream_ctx_t * ));
4929+ if (!stream_ptrs ) {
4930+ srtp_crypto_free (ssrcs );
4931+ return srtp_err_status_alloc_fail ;
4932+ }
4933+
4934+ if (list -> size > 0u ) {
4935+ memcpy (ssrcs , list -> ssrcs , (size_t )list -> size * sizeof (uint32_t ));
4936+ memcpy (stream_ptrs , list -> streams ,
4937+ (size_t )list -> size * sizeof (srtp_stream_ctx_t * ));
4938+ }
4939+
4940+ srtp_crypto_free (list -> ssrcs );
4941+ srtp_crypto_free (list -> streams );
4942+ list -> streams = stream_ptrs ;
4943+ list -> ssrcs = ssrcs ;
4944+
4945+ list -> capacity = new_capacity ;
4946+ }
4947+
4948+ return srtp_err_status_ok ;
4949+ }
4950+
48754951srtp_err_status_t srtp_stream_list_insert (srtp_stream_list_t list ,
48764952 srtp_stream_t stream )
48774953{
4878- /* insert at the head of the list */
4879- stream -> next = list -> data . next ;
4880- if (stream -> next != NULL ) {
4881- stream -> next -> prev = stream ;
4882- }
4883- list -> data . next = stream ;
4884- stream -> prev = & ( list -> data ) ;
4954+ uint32_t pos ;
4955+ srtp_err_status_t status = srtp_stream_list_reserve ( list , list -> size + 1u ) ;
4956+ if (status )
4957+ return status ;
4958+ pos = list -> size ++ ;
4959+ list -> ssrcs [ pos ] = stream -> ssrc ;
4960+ list -> streams [ pos ] = stream ;
48854961
48864962 return srtp_err_status_ok ;
48874963}
48884964
4889- srtp_stream_t srtp_stream_list_get (srtp_stream_list_t list , uint32_t ssrc )
4965+ static uint32_t srtp_stream_list_find (srtp_stream_list_t list , uint32_t ssrc )
48904966{
4891- /* walk down list until ssrc is found */
4892- srtp_stream_t stream = list -> data .next ;
4893- while (stream != NULL ) {
4894- if (stream -> ssrc == ssrc ) {
4895- return stream ;
4967+ #if defined(__SSE2__ )
4968+ const uint32_t * const ssrcs = list -> ssrcs ;
4969+ const __m128i mm_ssrc = _mm_set1_epi32 (ssrc );
4970+ uint32_t pos = 0u , n = (list -> size + 7u ) & ~(uint32_t )(7u );
4971+ for (uint32_t m = n & ~(uint32_t )(15u ); pos < m ; pos += 16u ) {
4972+ __m128i mm1 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos ));
4973+ __m128i mm2 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 4u ));
4974+ __m128i mm3 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 8u ));
4975+ __m128i mm4 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 12u ));
4976+ mm1 = _mm_cmpeq_epi32 (mm1 , mm_ssrc );
4977+ mm2 = _mm_cmpeq_epi32 (mm2 , mm_ssrc );
4978+ mm3 = _mm_cmpeq_epi32 (mm3 , mm_ssrc );
4979+ mm4 = _mm_cmpeq_epi32 (mm4 , mm_ssrc );
4980+ mm1 = _mm_packs_epi32 (mm1 , mm2 );
4981+ mm3 = _mm_packs_epi32 (mm3 , mm4 );
4982+ mm1 = _mm_packs_epi16 (mm1 , mm3 );
4983+ uint32_t mask = _mm_movemask_epi8 (mm1 );
4984+ if (mask ) {
4985+ #if defined(_MSC_VER )
4986+ unsigned long bit_pos ;
4987+ _BitScanForward (& bit_pos , mask );
4988+ pos += bit_pos ;
4989+ #else
4990+ pos += __builtin_ctz (mask );
4991+ #endif
4992+
4993+ goto done ;
4994+ }
4995+ }
4996+
4997+ if (pos < n ) {
4998+ __m128i mm1 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos ));
4999+ __m128i mm2 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 4u ));
5000+ mm1 = _mm_cmpeq_epi32 (mm1 , mm_ssrc );
5001+ mm2 = _mm_cmpeq_epi32 (mm2 , mm_ssrc );
5002+ mm1 = _mm_packs_epi32 (mm1 , mm2 );
5003+
5004+ uint32_t mask = _mm_movemask_epi8 (mm1 );
5005+ if (mask ) {
5006+ #if defined(_MSC_VER )
5007+ unsigned long bit_pos ;
5008+ _BitScanForward (& bit_pos , mask );
5009+ pos += bit_pos / 2u ;
5010+ #else
5011+ pos += __builtin_ctz (mask ) / 2u ;
5012+ #endif
5013+ goto done ;
48965014 }
4897- stream = stream -> next ;
5015+
5016+ pos += 8u ;
5017+ }
5018+
5019+ done :
5020+ return pos ;
5021+ #else
5022+ /* walk down list until ssrc is found */
5023+ uint32_t pos = 0u , n = list -> size ;
5024+ for (; pos < n ; ++ pos ) {
5025+ if (list -> ssrcs [pos ] == ssrc )
5026+ break ;
48985027 }
48995028
5029+ return pos ;
5030+ #endif
5031+ }
5032+
5033+ static inline srtp_stream_t srtp_stream_list_get_at (srtp_stream_list_t list ,
5034+ uint32_t pos )
5035+ {
5036+ return list -> streams [pos ];
5037+ }
5038+
5039+ srtp_stream_t srtp_stream_list_get (srtp_stream_list_t list , uint32_t ssrc )
5040+ {
5041+ uint32_t pos = srtp_stream_list_find (list , ssrc );
5042+ if (pos < list -> size )
5043+ return list -> streams [pos ];
5044+
49005045 /* we haven't found our ssrc, so return a null */
49015046 return NULL ;
49025047}
49035048
4904- void srtp_stream_list_remove (srtp_stream_list_t list ,
4905- srtp_stream_t stream_to_remove )
5049+ static void srtp_stream_list_remove_at (srtp_stream_list_t list , uint32_t pos )
49065050{
4907- ( void ) list ;
5051+ uint32_t tail_size , last_pos ;
49085052
4909- stream_to_remove -> prev -> next = stream_to_remove -> next ;
4910- if (stream_to_remove -> next != NULL ) {
4911- stream_to_remove -> next -> prev = stream_to_remove -> prev ;
5053+ last_pos = -- list -> size ;
5054+ tail_size = last_pos - pos ;
5055+ if (tail_size > 0u ) {
5056+ memmove (list -> streams + pos , list -> streams + pos + 1 ,
5057+ (size_t )tail_size * sizeof (* list -> streams ));
5058+ memmove (list -> ssrcs + pos , list -> ssrcs + pos + 1 ,
5059+ (size_t )tail_size * sizeof (* list -> ssrcs ));
49125060 }
5061+
5062+ list -> streams [last_pos ] = NULL ;
5063+ list -> ssrcs [last_pos ] = 0u ;
5064+ }
5065+
5066+ void srtp_stream_list_remove (srtp_stream_list_t list ,
5067+ srtp_stream_t stream_to_remove )
5068+ {
5069+ uint32_t pos = srtp_stream_list_find (list , stream_to_remove -> ssrc );
5070+ if (pos < list -> size )
5071+ srtp_stream_list_remove_at (list , pos );
49135072}
49145073
49155074void srtp_stream_list_for_each (srtp_stream_list_t list ,
49165075 int (* callback )(srtp_stream_t , void * ),
49175076 void * data )
49185077{
4919- srtp_stream_t stream = list -> data .next ;
4920- while (stream != NULL ) {
4921- srtp_stream_t tmp = stream ;
4922- stream = stream -> next ;
4923- if (callback (tmp , data ))
5078+ uint32_t size = list -> size ;
5079+ for (uint32_t i = 0u ; i < size ;) {
5080+ if (callback (list -> streams [i ], data ))
49245081 break ;
5082+
5083+ /* check if the callback removed the current element */
5084+ if (size == list -> size )
5085+ ++ i ;
5086+ else
5087+ size = list -> size ;
49255088 }
49265089}
49275090
0 commit comments