Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/// Given the previously seen tokens, predict the next token (and runners up).
///
/// Each kind in `previousKinds` is fed, in order, through the cached recurrent
/// CoreML model. The LSTM state outputs ("lstm_1_h_out" / "lstm_1_c_out") of one
/// step are passed back in as inputs to the next step. Only the final step's
/// "nextTokenProbabilities" output is used to build the result. Entries with
/// probability <= 1e-4 are dropped, and the rest are sorted by descending
/// probability. Each entry is mapped to (kind, insertText, formatText, probability).
///
/// Returns an empty array when `ios11` is false or `previousKinds` is empty.
/// Raises (via `failwith`) if the model input cannot be built or if the
/// prediction reports an error.
let predictNextToken (previousKinds : SyntaxKind[]) : Prediction[] =
    // NOTE(review): `ios11` presumably gates on CoreML availability (iOS 11+) — confirm.
    if ios11 then
        let model : MLModel = model.Value // Load the cached model
        let mutable predictions : Prediction[] = [| |]
        // RNNs require external memory: the LSTM state returned by one step
        // is fed back in as input to the next step.
        let mutable lstm_1_h : MLMultiArray = null
        let mutable lstm_1_c : MLMultiArray = null
        // Input key sets: the first step has no LSTM state yet, later steps do.
        let inputKeys1 = [| s_prevVectorizedToken |]
        let inputKeys3 = [| s_prevVectorizedToken; s_lstm_1_h_in; s_lstm_1_c_in |]
        let mutable error : NSError = null
        // Run the model for each previous token
        for kindIndex, prevKind in previousKinds |> Array.indexed do
            // Convert the token to a vector for the model
            let vectorizedToken = CSharpPredictor.kindToVector prevKind
            // The first run doesn't include the memory
            let inputKeys, inputValues =
                if lstm_1_h <> null then
                    inputKeys3,
                    [| vectorizedToken :> NSObject
                       lstm_1_h :> NSObject
                       lstm_1_c :> NSObject |]
                else
                    inputKeys1, [| vectorizedToken :> NSObject |]
            let inputDict =
                NSDictionary<NSString, NSObject>.FromObjectsAndKeys (inputValues, inputKeys, System.nint inputKeys.Length)
            // `use` releases the native feature provider at the end of each
            // iteration rather than leaving it for the GC finalizer.
            use inputFeatures = new MLDictionaryFeatureProvider (inputDict, &error)
            // The constructor reports failure through its out-parameter; surface
            // it here instead of letting it turn into a vaguer prediction failure.
            if error <> null then
                Debug.WriteLine (error)
                failwith "Failed to create model input features"
            // Run the prediction
            match model.GetPrediction (inputFeatures) with
            | _, error when error <> null ->
                Debug.WriteLine (error)
                failwith "Prediction failed"
            | output, _ ->
                // Carry the LSTM state forward to the next step.
                lstm_1_h <- output.GetFeatureValue("lstm_1_h_out").MultiArrayValue
                lstm_1_c <- output.GetFeatureValue("lstm_1_c_out").MultiArrayValue
                // If this is the last prediction, store the results
                if kindIndex = previousKinds.Length - 1 then
                    predictions <-
                        output.GetFeatureValue("nextTokenProbabilities").DictionaryValue
                        :> Collections.Generic.IDictionary<NSObject, NSNumber>
                        |> Seq.map (fun (x : System.Collections.Generic.KeyValuePair<NSObject, NSNumber>) ->
                            string x.Key, x.Value.DoubleValue)
                        // Drop negligible probabilities, most likely first.
                        |> Seq.filter (fun (_, p) -> p > 1.0e-4)
                        |> Seq.sortBy (fun (_, p) -> -p)
                        |> Seq.map (fun (tokenName, p) ->
                            let kind = CSharpPredictor.stringToSyntaxKind tokenName
                            let insertText, formatText = CSharpPredictor.kindToCompletion kind
                            kind, insertText, formatText, p)
                        |> Array.ofSeq
        predictions
    else
        [| |]
Add Comment
Please, Sign In to add comment