CarlosWGama

RN - Classificação Binária

Feb 8th, 2020
189
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import pandas as pd
  2. import numpy as np
  3. from sklearn.preprocessing import LabelEncoder, OneHotEncoder
  4. from sklearn.compose import ColumnTransformer
  5. from keras.models import Sequential
  6. from keras.layers import Dense
  7. from keras.utils import np_utils
  8.  
  9. csv = pd.read_csv('planos.csv', sep=',')
  10. csv = csv.drop(columns=['nome'])
  11.  
  12. #Ajusta
  13. le = LabelEncoder()
  14. csv['estado civil'] = le.fit_transform(csv['estado civil']) #Estado Civil (0-casado|1-solteiro|2-viuvo)
  15. csv['genero'] = le.fit_transform(csv['genero']) #Gênero (0-feminino|1-masculino)
  16. csv['risco'] = le.fit_transform(csv['risco']) #Risco (0-alto|1-baixo|2-medio)
  17. dados = csv.values
  18.  
  19. #Separa
  20. atributos = dados[:,0:3]
  21. classificadores = dados[:,3]
  22. classificadores = np_utils.to_categorical(classificadores)
  23.  
  24. #Ajusta os atributos para classificações binários
  25. ct = ColumnTransformer([('binarios', OneHotEncoder(), [2])], remainder='passthrough')
  26. atributos = ct.fit_transform(atributos)
  27.  
  28. #Criando o modelo
  29. modelo = Sequential()
  30. modelo.add(Dense(units=5, activation='relu'))
  31. modelo.add(Dense(units=5, activation='relu'))
  32. modelo.add(Dense(units=3, activation='softmax')) #A soma de todos não ultrapassa 1
  33.  
  34. modelo.compile(optimizer='adam',  loss = 'categorical_crossentropy', metrics = ['categorical_accuracy'])
  35. modelo.fit(atributos, classificadores, batch_size=50, epochs=500)
  36.  
  37. #Identificando a classificação de um novo usuário
  38. #Estado Civil (0-casado|1-solteiro|2-viuvo)
  39. #Gênero (0-feminino|1-masculino)
  40. novos = np.array([
  41.     [80, 2, 1],
  42.     [27, 0, 0],
  43.     [35, 1, 1]
  44. ])
  45. novos = ct.transform(novos)
  46.  
  47. resultado = modelo.predict(novos)
  48.  
  49. #Ordem Alfabética - Risco (0-alto|1-baixo|2-medio)
  50. print(resultado)
RAW Paste Data