Advertisement
Guest User

Untitled

a guest
Jul 17th, 2019
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.16 KB | None | 0 0
  def my_input_fn(data_file, num_epochs, batch_size, shuffle=True):
      """Build a batched dataset of (features, label) pairs from a CSV file.

      Args:
          data_file: Path or file pattern of the CSV data to read.
          num_epochs: Number of times the data is repeated.
          batch_size: Number of records combined into one batch.
          shuffle: Whether records are shuffled. Defaults to True, which is
              also the default of ``make_csv_dataset``, so callers that do
              not pass it see unchanged behaviour.

      Returns:
          The dataset produced by ``tf.data.experimental.make_csv_dataset``,
          with column 'int4' split off as the label.
      """
      return tf.data.experimental.make_csv_dataset(
          data_file,
          batch_size=batch_size,
          column_names=_CSV_COLUMNS,  # ['int1', 'int2', 'int3', 'int4']
          label_name='int4',
          na_value="?",  # "?" in the file is read as a missing value
          num_epochs=num_epochs,
          # Forwarded so callers can request a fixed record order.
          shuffle=shuffle,
          ignore_errors=True)
  11.  
  # Zero-argument input functions with the file and batching options bound.
  train_inpf = functools.partial(
      my_input_fn, train_file, num_epochs=2, shuffle=True, batch_size=32)
  test_inpf = functools.partial(
      my_input_fn, test_file, num_epochs=1, shuffle=False, batch_size=1)

  # One categorical feature column per input column, each with an int64
  # vocabulary taken from column_uniques_lists.
  # NOTE(review): presumably column_uniques_lists maps a column name to the
  # list of distinct values seen in that column -- confirm where it is built.
  my_categorical_columns = [
      tf.feature_column.categorical_column_with_vocabulary_list(
          name, column_uniques_lists[name], dtype=tf.int64)
      for name in ('int1', 'int2', 'int3')
  ]
  # Keep the individual names bound for any code that refers to them.
  col1, col2, col3 = my_categorical_columns

  classifier = tf.estimator.LinearClassifier(
      feature_columns=my_categorical_columns,
      # One class per entry listed for the label column 'int4'.
      n_classes=len(column_uniques_lists['int4']),
      # Raw string: "\S" and "\m" are not valid escape sequences, so the
      # plain literal triggered an invalid-escape warning. The path value
      # itself is unchanged.
      model_dir=r'.\SaveLC\model_dir')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement