mirror of
https://github.com/simtactics/servo.git
synced 2025-03-15 08:21:22 +00:00
231 lines
No EOL
6.9 KiB
Text
231 lines
No EOL
6.9 KiB
Text
{
|
|
"cells": [
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"// Microsoft.ML.FastTree is versioned in lockstep with Microsoft.ML; pin both packages\n",
"// to the same version so the FastTree trainers match the core ML.NET runtime.\n",
"#r \"nuget:Microsoft.ML,1.5.1\"\n",
"#r \"nuget:Microsoft.ML.FastTree,1.5.1\""
],
"outputs": []
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"source": [
|
|
"using System;\n",
|
|
"using System.Collections.Generic;\n",
|
|
"using System.IO;\n",
|
|
"using System.Linq;\n",
|
|
"using Microsoft.ML;\n",
|
|
"using Microsoft.ML.Data;"
|
|
],
|
|
"outputs": []
|
|
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"/// <summary>One row of data: a need, its percentage, and its priority label.</summary>\n",
"class ModelInput\n",
"{\n",
"    // Source column 0: name of the need (one-hot encoded as a categorical feature).\n",
"    [LoadColumn(0)]\n",
"    [ColumnName(\"Need\")]\n",
"    public string Need { get; set; }\n",
"\n",
"    // Source column 1: numeric feature.\n",
"    [LoadColumn(1)]\n",
"    [ColumnName(\"Percentage\")]\n",
"    public float Percentage { get; set; }\n",
"\n",
"    // Source column 2: the label the model learns to predict.\n",
"    [LoadColumn(2)]\n",
"    [ColumnName(\"Priority\")]\n",
"    public string Priority { get; set; }\n",
"}\n",
"\n",
"/// <summary>Result row produced by the trained pipeline.</summary>\n",
"public class ModelOutput\n",
"{\n",
"    // Bound to the pipeline's \"PredictedLabel\" column instead of a column named after the property.\n",
"    [ColumnName(\"PredictedLabel\")]\n",
"    public string Prediction { get; set; }\n",
"\n",
"    // Per-class scores emitted by the multiclass trainer.\n",
"    public float[] Score { get; set; }\n",
"}"
],
"outputs": []
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"IEstimator<ITransformer> BuildTrainingPipeline(MLContext context)\n",
"{\n",
"    var transforms = context.Transforms;\n",
"\n",
"    // Featurization: key-encode the string label \"Priority\", one-hot encode the\n",
"    // categorical \"Need\" column, and concatenate it with \"Percentage\" into \"Features\".\n",
"    // The cache checkpoint caches the featurized rows so the trainer does not recompute them.\n",
"    var featurization = transforms.Conversion.MapValueToKey(\"Priority\", \"Priority\")\n",
"        .Append(transforms.Categorical.OneHotEncoding(new[] { new InputOutputColumnPair(\"Need\", \"Need\") }))\n",
"        .Append(transforms.Concatenate(\"Features\", new[] { \"Need\", \"Percentage\" }))\n",
"        .AppendCacheCheckpoint(context);\n",
"\n",
"    // Multiclass classification as one-versus-all over binary FastTree models,\n",
"    // followed by mapping the predicted key back to the original \"Priority\" value.\n",
"    var binaryTrainer = context.BinaryClassification.Trainers.FastTree(labelColumnName: \"Priority\", featureColumnName: \"Features\");\n",
"    var classifier = context.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: \"Priority\")\n",
"        .Append(transforms.Conversion.MapKeyToValue(\"PredictedLabel\", \"PredictedLabel\"));\n",
"\n",
"    return featurization.Append(classifier);\n",
"}"
],
"outputs": []
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"/// <summary>Fits the full training pipeline on the given data and returns the trained model.</summary>\n",
"/// <remarks><paramref name=\"context\"/> is not used by this method.</remarks>\n",
"ITransformer TrainModel(MLContext context, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)\n",
"{\n",
"    Console.WriteLine(\"=============== Training model ===============\");\n",
"    ITransformer trained = trainingPipeline.Fit(trainingDataView);\n",
"    Console.WriteLine(\"=============== End of training process ===============\");\n",
"\n",
"    return trained;\n",
"}"
],
"outputs": []
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"/// <summary>\n",
"/// Saves the trained model and its input schema as a .zip file under ./models\n",
"/// (resolved against the current working directory).\n",
"/// </summary>\n",
"void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)\n",
"{\n",
"    var path = Path.GetFullPath(Path.Combine(\"models\", modelRelativePath));\n",
"\n",
"    // Model.Save creates the target file but not missing parent folders, so make\n",
"    // sure ./models (and any subfolder in modelRelativePath) exists first.\n",
"    Directory.CreateDirectory(Path.GetDirectoryName(path));\n",
"\n",
"    Console.WriteLine(\"=============== Saving the model ===============\");\n",
"    mlContext.Model.Save(mlModel, modelInputSchema, path);\n",
"    Console.WriteLine(\"The model is saved to {0}\", path);\n",
"}"
],
"outputs": []
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"/// <summary>\n",
"/// Loads the trained model from ./models/Servo.zip (relative to the working directory)\n",
"/// and builds a prediction engine for it.\n",
"/// </summary>\n",
"/// <exception cref=\"FileNotFoundException\">The model file has not been created yet.</exception>\n",
"PredictionEngine<ModelInput, ModelOutput> CreatePredictionEngine()\n",
"{\n",
"    var modelFullPath = Path.GetFullPath(Path.Combine(\"models\", \"Servo.zip\"));\n",
"\n",
"    // Fail with an actionable message instead of a bare stream error when\n",
"    // the model was never trained and saved.\n",
"    if (!File.Exists(modelFullPath))\n",
"    {\n",
"        throw new FileNotFoundException(\n",
"            \"Trained model not found. Train the pipeline and save it as models/Servo.zip before creating a prediction engine.\",\n",
"            modelFullPath);\n",
"    }\n",
"\n",
"    var mlContext = new MLContext();\n",
"\n",
"    // The input schema returned by Load is not needed to build the engine.\n",
"    var mlModel = mlContext.Model.Load(modelFullPath, out _);\n",
"    return mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);\n",
"}"
],
"outputs": []
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"/// <summary>Scores each input with the saved model and prints the predicted priority and class scores.</summary>\n",
"/// <exception cref=\"ArgumentNullException\"><paramref name=\"inputs\"/> is null.</exception>\n",
"void Predictions(ModelInput[] inputs)\n",
"{\n",
"    if (inputs == null)\n",
"    {\n",
"        throw new ArgumentNullException(nameof(inputs));\n",
"    }\n",
"\n",
"    // Nothing to score: return before touching the model file on disk.\n",
"    if (inputs.Length == 0)\n",
"    {\n",
"        return;\n",
"    }\n",
"\n",
"    // Load the model once and reuse a single engine for every input instead of\n",
"    // re-reading the .zip per item; dispose the engine when done.\n",
"    using (var engine = CreatePredictionEngine())\n",
"    {\n",
"        foreach (var input in inputs)\n",
"        {\n",
"            var result = engine.Predict(input);\n",
"\n",
"            // Only the exact label \"TRUE\" marks a priority; \"FALSE\", null, or any\n",
"            // other predicted label is reported as false.\n",
"            var priority = string.Equals(result.Prediction, \"TRUE\", StringComparison.Ordinal);\n",
"\n",
"            Console.WriteLine($\"Need: {input.Need}{Environment.NewLine}\" +\n",
"                $\"Percentage: {input.Percentage}{Environment.NewLine}\" +\n",
"                $\"Predicted Priority value {priority}{Environment.NewLine}\" +\n",
"                $\"Predicted Priority scores: [{string.Join(\",\", result.Score)}]\");\n",
"        }\n",
"    }\n",
"}"
],
"outputs": []
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"// Hand-written inputs to score. Priority is left unset because it is the value the model predicts.\n",
"ModelInput[] sampleData =\n",
"{\n",
"    new ModelInput { Need = \"Social\", Percentage = 27F },\n",
"    new ModelInput { Need = \"Bladder\", Percentage = 76F },\n",
"    new ModelInput { Need = \"Energy\", Percentage = 55F },\n",
"    new ModelInput { Need = \"Social\", Percentage = 20F },\n",
"    new ModelInput { Need = \"Fun\", Percentage = 14F },\n",
"};"
],
"outputs": []
},
|
|
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": [
"// Score every sample and print its predicted priority and per-class scores.\n",
"Predictions(inputs: sampleData);"
],
"outputs": []
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".NET (C#)",
|
|
"language": "C#",
|
|
"name": ".net-csharp"
|
|
},
|
|
"language_info": {
|
|
"file_extension": ".cs",
|
|
"mimetype": "text/x-csharp",
|
|
"name": "C#",
|
|
"pygments_lexer": "csharp",
|
|
"version": "8.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
} |