You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

57 lines
2.1 KiB
Python

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import statsmodels.api as sm
class OlsModel:
    """Ordinary least squares regression fitted with statsmodels.

    On construction, ``y`` is regressed on ``x`` with an intercept column
    added by ``sm.add_constant`` (statsmodels skips adding it if ``x``
    already contains a constant column). The fitted results object is kept
    in ``self.results``.

    NOTE(review): because the intercept column is added here at fit time,
    ``self.results.predict`` presumably needs exog that already contains
    that column — confirm against callers.
    """

    def __init__(self, x, y):
        # Keep the raw training data, then fit immediately.
        self.x, self.y = x, y
        self.results = self.create_model()

    def create_model(self):
        """Fit OLS of ``self.y`` on ``self.x`` plus a constant; return the results."""
        design = sm.add_constant(self.x)
        return sm.OLS(self.y, design).fit()
class MlrModel:
    """Multiple linear regression fitted with scikit-learn on a train split.

    The data is split with ``train_test_split`` and a ``LinearRegression``
    estimator is fitted on the training rows only. The held-out rows are
    kept on the instance (``x_test`` / ``y_test``) so the fit can be
    evaluated, rather than being silently discarded.

    Attributes:
        x, y: The data passed to the constructor.
        test_size, random_state: Split settings used by ``create_model``.
        x_train, y_train: Rows the estimator was fitted on.
        x_test, y_test: Held-out rows, or ``None`` when no hold-out was used.
        results: The fitted ``LinearRegression`` estimator (named
            ``results`` for symmetry with the statsmodels-based model).
    """

    def __init__(self, x, y, test_size=0.25, random_state=42):
        """Store the data and fit the model.

        Args:
            x: Feature matrix accepted by scikit-learn.
            y: Target values.
            test_size: Fraction (float) or number (int) of rows to hold
                out, as interpreted by ``train_test_split``. Pass 0 or
                None to fit on every row with no hold-out. Defaults to
                0.25, the previously hard-coded value.
            random_state: Seed for the split's shuffle. Defaults to 42,
                the previously hard-coded value.
        """
        self.x = x
        self.y = y
        self.test_size = test_size
        self.random_state = random_state
        self.x_train = self.y_train = None
        self.x_test = self.y_test = None
        self.results = self.create_model()

    def create_model(self):
        """Split the data, fit ``LinearRegression`` on the train part, return it."""
        if self.test_size:
            (self.x_train, self.x_test,
             self.y_train, self.y_test) = train_test_split(
                self.x, self.y,
                test_size=self.test_size,
                random_state=self.random_state,
            )
        else:
            # No hold-out requested: fit on all available rows.
            self.x_train, self.y_train = self.x, self.y
            self.x_test = self.y_test = None
        return LinearRegression().fit(self.x_train, self.y_train)

    def score(self):
        """Return the R^2 of the fitted model on the held-out rows.

        Raises:
            ValueError: if the model was fitted without a held-out set.
        """
        if self.x_test is None:
            raise ValueError("model was fitted without a held-out test set")
        return self.results.score(self.x_test, self.y_test)
def ols_calcutate_all(x, qufu_mean_ols_model, qufu_std_ols_model,
                      kangla_mean_ols_model, kangla_std_ols_model,
                      yanshen_mean_ols_model, yanshen_std_ols_model):
    """Predict all six properties for ``x`` with the OLS models; print and return them.

    Every model argument must expose ``results.predict`` (as ``OlsModel``
    instances do). The predictions are printed one per line with a Chinese
    label, in this order: yield (屈服), tensile (抗拉) and elongation
    (延伸率) means, then the same three standard deviations. The printed
    block is followed by an empty line.

    NOTE(review): an ``OlsModel`` is fitted on ``sm.add_constant(x)``, so
    ``x`` here presumably must already include that intercept column —
    confirm against callers.

    Args:
        x: Input passed unchanged to every model's ``results.predict``.
        qufu_mean_ols_model, qufu_std_ols_model: Yield mean / std models.
        kangla_mean_ols_model, kangla_std_ols_model: Tensile mean / std models.
        yanshen_mean_ols_model, yanshen_std_ols_model: Elongation mean / std
            models.

    Returns:
        dict mapping ``"qufu_mean"``, ``"kangla_mean"``, ``"yanshen_mean"``,
        ``"qufu_std"``, ``"kangla_std"`` and ``"yanshen_std"`` to the
        corresponding prediction. (Previously the values were only printed.)
    """
    # (result key, printed label, model) in the original print order.
    entries = (
        ("qufu_mean", "屈服均值", qufu_mean_ols_model),
        ("kangla_mean", "抗拉均值", kangla_mean_ols_model),
        ("yanshen_mean", "延伸率均值", yanshen_mean_ols_model),
        ("qufu_std", "屈服标准差", qufu_std_ols_model),
        ("kangla_std", "抗拉标准差", kangla_std_ols_model),
        ("yanshen_std", "延伸率标准差", yanshen_std_ols_model),
    )
    # Predict everything before printing, so a failing model prints nothing.
    predictions = {key: model.results.predict(x) for key, _, model in entries}
    print("".join(label + ": " + str(predictions[key]) + "\n"
                  for key, label, _ in entries))
    return predictions
def mlr_calcutate_all(x, qufu_mean_mlr_model, qufu_std_mlr_model,
                      kangla_mean_mlr_model, kangla_std_mlr_model,
                      yanshen_mean_mlr_model, yanshen_std_mlr_model):
    """Predict all six properties for ``x`` with the MLR models; print and return them.

    Every model argument must expose ``results.predict`` (as ``MlrModel``
    instances do, where ``results`` is the fitted ``LinearRegression``).
    The predictions are printed one per line with a Chinese label, in this
    order: yield (屈服), tensile (抗拉) and elongation (延伸率) means, then
    the same three standard deviations. The printed block is followed by
    an empty line.

    Args:
        x: Input passed unchanged to every model's ``results.predict``.
        qufu_mean_mlr_model, qufu_std_mlr_model: Yield mean / std models.
        kangla_mean_mlr_model, kangla_std_mlr_model: Tensile mean / std models.
        yanshen_mean_mlr_model, yanshen_std_mlr_model: Elongation mean / std
            models.

    Returns:
        dict mapping ``"qufu_mean"``, ``"kangla_mean"``, ``"yanshen_mean"``,
        ``"qufu_std"``, ``"kangla_std"`` and ``"yanshen_std"`` to the
        corresponding prediction. (Previously the values were only printed.)
    """
    # (result key, printed label, model) in the original print order.
    entries = (
        ("qufu_mean", "屈服均值", qufu_mean_mlr_model),
        ("kangla_mean", "抗拉均值", kangla_mean_mlr_model),
        ("yanshen_mean", "延伸率均值", yanshen_mean_mlr_model),
        ("qufu_std", "屈服标准差", qufu_std_mlr_model),
        ("kangla_std", "抗拉标准差", kangla_std_mlr_model),
        ("yanshen_std", "延伸率标准差", yanshen_std_mlr_model),
    )
    # Predict everything before printing, so a failing model prints nothing.
    predictions = {key: model.results.predict(x) for key, _, model in entries}
    print("".join(label + ": " + str(predictions[key]) + "\n"
                  for key, label, _ in entries))
    return predictions