diff --git a/LinearRegression/LinearRegression_scikit-learn.py b/LinearRegression/LinearRegression_scikit-learn.py index 8036d58..e82c61a 100644 --- a/LinearRegression/LinearRegression_scikit-learn.py +++ b/LinearRegression/LinearRegression_scikit-learn.py @@ -14,7 +14,7 @@ def linearRegression(): scaler = StandardScaler() scaler.fit(X) x_train = scaler.transform(X) - x_test = scaler.transform(np.array([1650,3])) + x_test = scaler.transform(np.array([[1650,3]])) # 线性模型拟合 model = linear_model.LinearRegression() @@ -39,4 +39,4 @@ def loadnpy_data(fileName): if __name__ == "__main__": - linearRegression() \ No newline at end of file + linearRegression()