forked from ACM-Research/Coding-Challenge-S22
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
63 lines (51 loc) · 2.31 KB
/
predict.py
File metadata and controls
63 lines (51 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pandas as pd
from sklearn import preprocessing
from tensorflow import keras
### Setup preprocessing pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
# Categorical feature columns that the preprocessing column transformer
# selects and one-hot encodes.
features_cat = [
    "cap-shape",
    "cap-surface",
    "cap-color",
    "bruises",
    "odor",
    "gill-attachment",
    "gill-spacing",
    "gill-size",
    "gill-color",
    "stalk-shape",
    "stalk-root",
    "stalk-surface-above-ring",
    "stalk-surface-below-ring",
    "stalk-color-above-ring",
    "stalk-color-below-ring",
    "veil-type",
    "veil-color",
    "ring-number",
    "ring-type",
    "spore-print-color",
    "population",
    "habitat",
]
# Pipeline for categorical columns: fill missing values with the constant
# string "NA", then one-hot encode. handle_unknown="ignore" keeps transform()
# from raising on categories that were not present when the encoder was fit.
transformer_cat = make_pipeline(
    SimpleImputer(fill_value="NA", strategy="constant"),
    OneHotEncoder(handle_unknown="ignore"),
)

# Route the columns named in features_cat through the categorical pipeline.
preprocessor = make_column_transformer((transformer_cat, features_cat))
# Fit the preprocessor on the mushroom data set after removing the label
# column, so the encoder learns its categories from the feature columns only.
dataToFit = pd.read_csv("./mushrooms.csv")
del dataToFit["class"]
preprocessor = preprocessor.fit(dataToFit)
# Create a one-row dataframe of input data for prediction. The columns come
# from features_cat, so the sample stays aligned with the columns the
# preprocessor was configured with instead of repeating that list by hand.
# Mushroom info (this row does not appear in the database used for
# training/testing), shown here with its leading class label "p":
# p x s n t p f c n k e e s s w w p w o p k s u
# The class label is not a model input, so only the feature values are kept.
sample_features = [
    "x", "s", "n", "t", "p", "f", "c", "n", "k", "e", "e",
    "s", "s", "w", "w", "p", "w", "o", "p", "k", "s", "u",
]
# pandas raises ValueError here if the number of values and columns differ.
X_predict = pd.DataFrame([sample_features], columns=features_cat, index=[0])

# Perform preprocessing on prediction input
X_predict = preprocessor.transform(X_predict)
# The column transformer only returns a sparse matrix while the encoded
# output's density stays below its sparse_threshold; otherwise the result is
# already a dense array, which has no toarray() method.
if hasattr(X_predict, "toarray"):
    X_predict = X_predict.toarray()
# Load the pre-trained model from its saved-model directory.
model = keras.models.load_model('saved_model/my_model')

# Run the model on the preprocessed input row.
prediction = model.predict(X_predict)


def prediction_class(x):
    """Map a model score to a label: "edible" at 0.5 or above, else "poisonous"."""
    # NOTE(review): this treats the score as leaning towards the edible class;
    # confirm against the label encoding used when the model was trained.
    if x >= 0.5:
        return "edible"
    return "poisonous"


# Show the raw output for the first (only) row together with its label.
print(prediction[0], " classification: ", prediction_class(prediction[0][0]))
# Optional debugging aid: call model.summary() to print a layer overview.
# Write a diagram of the model architecture to model_plot.png. plot_model is
# reached through the tensorflow.keras package that loaded the model, rather
# than the standalone `keras.utils.vis_utils` module, which is a different
# package and is no longer importable in newer Keras releases.
keras.utils.plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)