diff --git a/pyAudioAnalysis/audioTrainTest.py b/pyAudioAnalysis/audioTrainTest.py
index 32bea38c6..7b7c5c224 100644
--- a/pyAudioAnalysis/audioTrainTest.py
+++ b/pyAudioAnalysis/audioTrainTest.py
@@ -17,6 +17,55 @@
 import plotly.graph_objs as go
 import sklearn.metrics
 
+# Mitigation for arbitrary code execution when loading untrusted pickles.
+import io
+import builtins
+import pickle
+
+# Names from the builtins module that a pickle is allowed to reference.
+safe_builtins = {
+    'range',
+    'complex',
+    'set',
+    'frozenset',
+    'slice',
+}
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+    """Unpickler that refuses any global not listed in safe_builtins."""
+
+    def find_class(self, module, name):
+        # Only allow safe classes from builtins.
+        if module == "builtins" and name in safe_builtins:
+            return getattr(builtins, name)
+        # Forbid everything else.
+        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
+                                     (module, name))
+
+
+def restricted_loads(s):
+    """Helper analogous to pickle.loads() that uses RestrictedUnpickler."""
+    return RestrictedUnpickler(io.BytesIO(s)).load()
+
+
+def restricted_check(fo):
+    """Vet every pickle in the open binary file ``fo``, then rewind it.
+
+    Each object is decoded with RestrictedUnpickler and discarded, so a
+    forbidden global raises pickle.UnpicklingError before any unrestricted
+    loader sees the data. The file position is restored afterwards so the
+    caller's own load calls still start from the same object.
+    """
+    start = fo.tell()
+    while True:
+        try:
+            RestrictedUnpickler(fo).load()
+        except EOFError:
+            # Ran out of input: no further pickles in the stream.
+            break
+    fo.seek(start)
+
 
 shortTermWindow = 0.050
 shortTermStep = 0.050
@@ -502,6 +551,9 @@ def feature_extraction_train_regression(folder_name, mid_window, mid_step,
 
 def load_model_knn(knn_model_name, is_regression=False):
     with open(knn_model_name, "rb") as fo:
+        # Vet every pickle in the file before the unrestricted loads below;
+        # restricted_check() rewinds fo so those loads still see the data.
+        restricted_check(fo)
         features = cPickle.load(fo)
         labels = cPickle.load(fo)
         mean = cPickle.load(fo)
@@ -540,6 +592,8 @@ def load_model(model_name, is_regression=False):
                             is regression or not
     """
     with open(model_name + "MEANS", "rb") as fo:
+        # Vet the whole file first; fo is rewound for the loads below.
+        restricted_check(fo)
         mean = cPickle.load(fo)
         std = cPickle.load(fo)
         if not is_regression:
@@ -554,6 +608,10 @@ def load_model(model_name, is_regression=False):
     std = np.array(std)
 
     with open(model_name, 'rb') as fid:
+        # NOTE(review): presumably this file holds a pickled estimator
+        # object; such a pickle references non-builtin globals and is
+        # rejected by the allow-list. Confirm, and widen it deliberately.
+        restricted_check(fid)
         svm_model = cPickle.load(fid)
 
     if is_regression: