ML.NET NER - Mismatched state_dict sizes: expected 60, but found 126 entries.
Hello, I'm testing the NER capabilities of ML.NET, and during training I get the following error: Mismatched state_dict sizes: expected 60, but found 126 entries.
System Information:
- OS & Version: Windows 10
- ML.NET Version: ML.NET v4.0.0
- .NET Version: .NET 8.0
Description of the bug
The exception is thrown by var transformer = estimator.Fit(dataView); with the following message and stack trace:
Mismatched state_dict sizes: expected 60, but found 126 entries.
at TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
at TorchSharp.torch.nn.Module.load(String location, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
at Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase.CreateModule(IChannel ch, IDataView input)
at Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.TrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
at Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
at Microsoft.ML.TorchSharp.NasBert.NerTrainer.Trainer..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
at Microsoft.ML.TorchSharp.NasBert.NerTrainer.CreateTrainer(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
at Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.Fit(IDataView input)
at Microsoft.ML.Data.EstimatorChain`1.Fit(IDataView input)
at Program.Main(String[] args) in C:\Users\pierc\source\repos\ML_NER_TEST\Program.cs:line 64
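The trace shows the failure inside Module.load(..., strict, ...), called from NasBertTrainerBase.CreateModule while the trainer loads the pretrained NAS-BERT weights. The wording of the message suggests that a strict load compares the number of entries the freshly built module declares (60) with the number stored in the downloaded weights file (126), so the file and the module definition probably come from different model or library versions. The sketch below is a hypothetical, simplified illustration of such a strict check using plain dictionaries instead of tensors; it is not TorchSharp's code, and StrictLoad is a made-up name.

```csharp
using System;
using System.Collections.Generic;

// A module that declares two tensors, and a "checkpoint" that holds three.
var module = new Dictionary<string, float[]>
{
    ["encoder.weight"] = new float[4],
    ["encoder.bias"] = new float[2],
};
var checkpoint = new Dictionary<string, float[]>
{
    ["encoder.weight"] = new float[] { 1, 2, 3, 4 },
    ["encoder.bias"] = new float[] { 5, 6 },
    ["decoder.weight"] = new float[] { 7, 8 },
};

try
{
    StrictLoad(module, checkpoint);
}
catch (ArgumentException ex)
{
    // Prints: Mismatched state_dict sizes: expected 2, but found 3 entries.
    Console.WriteLine(ex.Message);
}

// Hypothetical, simplified strict load (NOT TorchSharp's implementation).
// "expected" = entries the module declares; "found" = entries in the stored file.
static void StrictLoad(IDictionary<string, float[]> moduleParams,
                       IDictionary<string, float[]> stored)
{
    // Strict mode: the file must supply exactly the entries the module declares.
    if (moduleParams.Count != stored.Count)
        throw new ArgumentException(
            $"Mismatched state_dict sizes: expected {moduleParams.Count}, but found {stored.Count} entries.");

    foreach (var (name, target) in moduleParams)
    {
        if (!stored.TryGetValue(name, out var source))
            throw new ArgumentException($"Missing checkpoint entry: {name}");
        if (source.Length != target.Length)
            throw new ArgumentException($"Size mismatch for {name}: {target.Length} vs {source.Length}");
        // Copy the stored values into the module's own buffer.
        Array.Copy(source, target, source.Length);
    }
}
```

In this toy version the count check fires before any name is compared, which is consistent with the real error reporting only two numbers and no parameter name.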
Sample project
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp;

namespace ML_NER_TEST
{
    public class Program
    {
        public static void Main(string[] args)
        {
            try
            {
                var context = new MLContext()
                {
                    FallbackToCpu = true,
                    GpuDeviceId = 0
                };

                var labels = context.Data.LoadFromEnumerable(
                [
                    new Label { Key = "PERSON" },      // People, including fictional.
                    new Label { Key = "NORP" },        // Nationalities or religious or political groups.
                    new Label { Key = "FAC" },         // Buildings, airports, highways, bridges, etc.
                    new Label { Key = "ORG" },         // Companies, agencies, institutions, etc.
                    new Label { Key = "GPE" },         // Countries, cities, states.
                    new Label { Key = "LOC" },         // Non-GPE locations, mountain ranges, bodies of water.
                    new Label { Key = "PRODUCT" },     // Objects, vehicles, foods, etc. (Not services.)
                    new Label { Key = "EVENT" },       // Named hurricanes, battles, wars, sports events, etc.
                    new Label { Key = "WORK_OF_ART" }, // Titles of books, songs, etc.
                    new Label { Key = "LAW" },         // Named documents made into laws.
                    new Label { Key = "LANGUAGE" },    // Any named language.
                    new Label { Key = "DATE" },        // Absolute or relative dates or periods.
                    new Label { Key = "TIME" },        // Times smaller than a day.
                    new Label { Key = "PERCENT" },     // Percentage, including "%".
                    new Label { Key = "MONEY" },       // Monetary values, including unit.
                    new Label { Key = "QUANTITY" },    // Measurements, as of weight or distance.
                    new Label { Key = "ORDINAL" },     // "first", "second", etc.
                    new Label { Key = "CARDINAL" },    // Numerals that do not fall under another type.
                    new Label { Key = "OBJECT" },      // An object; an entity might be a spoon or a soccer ball. Needs subcategories.
                ]);

                // One label per word. "COUNTRY" is not one of the keys above,
                // so the country/state words are tagged "GPE" instead.
                var dataView = context.Data.LoadFromEnumerable(
                    new List<InputTrainingData>([
                        new InputTrainingData()
                        {
                            Sentence = "Alice and Bob live in the USA",
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "GPE"]
                        },
                        new InputTrainingData()
                        {
                            Sentence = "Frank and Alice traveled along the California coast.",
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "GPE", "0"]
                        },
                    ]));

                var chain = new EstimatorChain<ITransformer>();
                var estimator = chain
                    .Append(context.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
                    .Append(context.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "Predictions"))
                    .Append(context.Transforms.Conversion.MapKeyToValue("Predictions"));

                Console.WriteLine("Training the model...");
                var transformer = estimator.Fit(dataView); // Throws the state_dict exception.
                Console.WriteLine("Model trained!");

                var transformerSchema = transformer.GetOutputSchema(dataView.Schema);
                string sentence = "Alice and Bob live in the USA";
                var engine = context.Model.CreatePredictionEngine<Input, Output>(transformer);

                Console.WriteLine("Predicting...");
                Output predictions = engine.Predict(new Input { Sentence = sentence });
                Console.WriteLine($"Predictions: {sentence} - {string.Join(", ", predictions.Predictions)}");

                transformer.Dispose();
                Console.WriteLine("Success!");
                Console.ReadLine();
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Error: {ex.Message}");
                Console.ReadLine();
            }
        }

        private class Input
        {
            public string Sentence;
            public string[] Label;
        }

        private class Output
        {
            public string[] Predictions;
        }

        public class Label
        {
            public string Key { get; set; }
        }

        private class InputTrainingData
        {
            public string Sentence;
            public string[] Label;
        }
    }
}
Additional context
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>disable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="libtorch-cpu-win-x64" Version="2.5.1" />
    <PackageReference Include="Microsoft.ML" Version="4.0.0" />
    <PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.0" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.22.0" />
    <PackageReference Include="TorchSharp" Version="0.105.0" />
  </ItemGroup>

</Project>
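This project references TorchSharp 0.105.0 and libtorch-cpu-win-x64 2.5.1 explicitly next to Microsoft.ML.TorchSharp 0.22.0, so the TorchSharp build that ends up in the output folder may not be the one Microsoft.ML.TorchSharp was built against. A minimal sketch for checking what was actually deployed: a standalone console program that reads version metadata from matching DLLs without loading them. It assumes the DLLs are copied to AppContext.BaseDirectory (the default for a console app); ListAssemblyVersions is a hypothetical helper. Assembly versions do not always equal NuGet package versions, so treat the output as a hint.

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;

// Print versions of TorchSharp / ML.NET DLLs deployed next to the executable.
foreach (var line in ListAssemblyVersions(AppContext.BaseDirectory, "TorchSharp", "Microsoft.ML"))
    Console.WriteLine(line);

// Hypothetical helper: one "file: assembly X, product Y" line per DLL in
// `directory` whose file name starts with one of `prefixes`.
static List<string> ListAssemblyVersions(string directory, params string[] prefixes)
{
    var results = new List<string>();
    foreach (var path in Directory.EnumerateFiles(directory, "*.dll").OrderBy(p => p))
    {
        var file = Path.GetFileName(path);
        if (!prefixes.Any(p => file.StartsWith(p, StringComparison.OrdinalIgnoreCase)))
            continue;

        string assemblyVersion;
        try
        {
            // Reads the assembly metadata without loading it into the process.
            assemblyVersion = AssemblyName.GetAssemblyName(path).Version?.ToString() ?? "?";
        }
        catch (BadImageFormatException)
        {
            assemblyVersion = "native"; // Not a managed assembly.
        }

        // For SDK-built assemblies ProductVersion typically mirrors the informational version.
        var productVersion = FileVersionInfo.GetVersionInfo(path).ProductVersion ?? "?";
        results.Add($"{file}: assembly {assemblyVersion}, product {productVersion}");
    }
    return results;
}
```

In the failing project, the same helper could be pasted in as a static method and called at the start of Main, before Fit, to log which versions were deployed.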
I'm getting different results from the command-line tool (installed with dotnet tool install --global mlnet-win-x64). The Visual Studio extension ML.NET Model Builder v17.18.0 appears to use something similar to the mlnet tool shown below. However, a C# project like the one above, which may use different library versions and model architectures, also throws the Mismatched state_dict exception for me. The log below hints at which libraries the mlnet tool uses. Could part of this issue be that the tooling is a release behind?
C:\Users\xxxx>mlnet text-classification --dataset "text-code.txt" --label-col 1 --text-col 0 --has-header true
Start Training
start text classification
env:path: C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5; [snip]
restore "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj" --configfile "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\NuGet.config" -r win-x64 /p:UsingToolXliff=false /p:TorchSharpVersion=0.101.5 /p:TorchSharpCudaRuntimeVersion=2.1.0.1 /p:TensorflowRuntimeVersion=2.3.1 /p:BaseIntermediateOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\obj"
publish "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj" -r win-x64 -c Release --no-self-contained -o "C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5" --no-restore /p:UsingToolXliff=false /p:TorchSharpVersion=0.101.5 /p:TorchSharpCudaRuntimeVersion=2.1.0.1 /p:TensorflowRuntimeVersion=2.3.1 /p:BaseOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\bin\" /p:BaseIntermediateOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\obj\"
start installing runtime in C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5
Determining projects to restore...
Restored C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj (in 1.5 min).
torchsharp.cpu -> C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\bin\Release\netstandard2.0\win-x64\torchsharp.cpu.dll
torchsharp.cpu -> C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\
install runtime successfully
Use train validate split with ratio: 0.1
[Source=AutoMLExperiment-ChildContext, Kind=Trace]
[Source=TorchSharpBaseTrainer; TrainModel, Kind=Trace] Starting epoch 0
For what it's worth, the mlnet text-classification ... command-line tool generates a SampleTextClassification project with the csproj file shown below.
Edit: Rolled back to TorchSharp-cpu Version="0.99.6" and it's working well for me now.
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="Microsoft.ML" Version="3.0.1" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.21.0" />
    <PackageReference Include="TorchSharp-cpu" Version="0.101.5" />
  </ItemGroup>

  <ItemGroup Label="SampleTextClassification">
    <None Include="SampleTextClassification.mlnet">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
  </ItemGroup>

</Project>
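The project generated by the mlnet tool and the failing project from the original report pin different versions of the same packages. A short sketch that lists PackageReference entries whose versions differ, or which appear in only one project, makes that comparison mechanical. It uses System.Xml.Linq and assumes SDK-style project files (no MSBuild XML namespace) with Version given as an attribute; ReadPackages and DiffPackageReferences are hypothetical helpers, and the inline XML is trimmed from the two csproj files in this thread.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;

const string failing = """
<Project Sdk="Microsoft.NET.Sdk">
  <ItemGroup>
    <PackageReference Include="Microsoft.ML" Version="4.0.0" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.22.0" />
    <PackageReference Include="TorchSharp" Version="0.105.0" />
  </ItemGroup>
</Project>
""";

const string generated = """
<Project Sdk="Microsoft.NET.Sdk">
  <ItemGroup>
    <PackageReference Include="Microsoft.ML" Version="3.0.1" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.21.0" />
    <PackageReference Include="TorchSharp-cpu" Version="0.101.5" />
  </ItemGroup>
</Project>
""";

foreach (var diff in DiffPackageReferences(failing, generated))
    Console.WriteLine(diff);
// Microsoft.ML: 4.0.0 vs 3.0.1
// Microsoft.ML.TorchSharp: 0.22.0 vs 0.21.0
// TorchSharp: 0.105.0 vs -
// TorchSharp-cpu: - vs 0.101.5

// Hypothetical helper: PackageReference Include -> Version for one project file.
static Dictionary<string, string> ReadPackages(string csprojXml) =>
    XDocument.Parse(csprojXml)
        .Descendants("PackageReference")
        .Where(e => e.Attribute("Include") != null)
        .ToDictionary(
            e => (string)e.Attribute("Include"),
            e => (string)e.Attribute("Version") ?? "(none)",
            StringComparer.OrdinalIgnoreCase);

// Hypothetical helper: one line per package whose version differs ("-" = absent).
static List<string> DiffPackageReferences(string leftXml, string rightXml)
{
    var left = ReadPackages(leftXml);
    var right = ReadPackages(rightXml);
    var diffs = new List<string>();
    var names = left.Keys
        .Union(right.Keys, StringComparer.OrdinalIgnoreCase)
        .OrderBy(n => n, StringComparer.OrdinalIgnoreCase);
    foreach (var name in names)
    {
        left.TryGetValue(name, out var l);
        right.TryGetValue(name, out var r);
        if (!string.Equals(l, r, StringComparison.OrdinalIgnoreCase))
            diffs.Add($"{name}: {l ?? "-"} vs {r ?? "-"}");
    }
    return diffs;
}
```

Reading the real files instead of inline strings only requires passing File.ReadAllText(path) for each csproj.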