Guest User

Untitled

a guest
Jul 19th, 2018
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.87 KB | None | 0 0
  /// Given the kinds of the previously typed tokens, predicts the next token
  /// kind (and runners up) by running the cached CoreML model once per token.
  ///
  /// The model's LSTM state (h and c) is passed in and out explicitly: the
  /// "lstm_1_h_out"/"lstm_1_c_out" features of one run are fed back as the
  /// lstm_1_h_in/lstm_1_c_in inputs of the next run.
  ///
  /// Returns the "nextTokenProbabilities" of the final run, keeping only
  /// entries with probability > 1e-4, ordered from most to least likely.
  /// Each element is built as (kind, insertText, formatText, probability)
  /// via CSharpPredictor.kindToCompletion.
  /// Returns an empty array when ios11 is false or previousKinds is empty.
  ///
  /// Raises (failwith) when the model input cannot be built or when the
  /// model reports a prediction error; the NSError is written to Debug output.
  let predictNextToken (previousKinds : SyntaxKind[]) : Prediction[] =
      if ios11 then
          let model : MLModel = model.Value // Load the cached model
          let mutable predictions : Prediction[] = [| |]

          // Recurrent state carried between runs; null until the first run completes
          let mutable lstm_1_h : MLMultiArray = null
          let mutable lstm_1_c : MLMultiArray = null

          let inputKeys1 = [| s_prevVectorizedToken |]
          let inputKeys3 = [| s_prevVectorizedToken; s_lstm_1_h_in; s_lstm_1_c_in |]

          // Run the model for each previous token, in order
          for kindIndex, prevKind in previousKinds |> Array.indexed do

              // Convert the token to a vector for the model
              let vectorizedToken = CSharpPredictor.kindToVector prevKind

              // The first run has no recurrent state yet, so only the token is passed
              let inputKeys, inputValues =
                  if lstm_1_h <> null then
                      inputKeys3,
                      [| vectorizedToken :> NSObject; lstm_1_h :> NSObject; lstm_1_c :> NSObject |]
                  else
                      inputKeys1, [| vectorizedToken :> NSObject |]
              let inputDict =
                  NSDictionary<NSString, NSObject>.FromObjectsAndKeys (inputValues, inputKeys, System.nint inputKeys.Length)

              // A fresh error cell per token, so a failure is never missed or carried over
              let mutable inputError : NSError = null
              let inputFeatures = new MLDictionaryFeatureProvider (inputDict, &inputError)
              if inputError <> null then
                  Debug.WriteLine (inputError)
                  failwith "Failed to create model input features"

              // Run the prediction; the out NSError comes back as the second tuple item
              match model.GetPrediction (inputFeatures) with
              | _, predictionError when predictionError <> null ->
                  Debug.WriteLine (predictionError)
                  failwith "Prediction failed"
              | output, _ ->
                  lstm_1_h <- output.GetFeatureValue("lstm_1_h_out").MultiArrayValue
                  lstm_1_c <- output.GetFeatureValue("lstm_1_c_out").MultiArrayValue

                  // Only the final run's probabilities form the returned prediction
                  if kindIndex = previousKinds.Length - 1 then
                      predictions <-
                          output.GetFeatureValue("nextTokenProbabilities").DictionaryValue
                          :> Collections.Generic.IDictionary<NSObject, NSNumber>
                          |> Seq.map (fun (x : System.Collections.Generic.KeyValuePair<NSObject, NSNumber>) ->
                              string x.Key, x.Value.DoubleValue)
                          |> Seq.filter (fun (_, p) -> p > 1.0e-4)
                          |> Seq.sortByDescending snd
                          |> Seq.map (fun (tokenName, p) ->
                              let kind = CSharpPredictor.stringToSyntaxKind tokenName
                              let insertText, formatText = CSharpPredictor.kindToCompletion kind
                              kind, insertText, formatText, p)
                          |> Array.ofSeq
          predictions
      else
          [| |]
Add Comment
Please, Sign In to add comment