ML.NET NER - Mismatched state_dict sizes: expected 60, but found 126 entries.

Open piercarlo62 opened this issue 1 year ago • 2 comments

Hello, I'm testing the NER capabilities of ML.NET, and training fails with the following error: Error: Mismatched state_dict sizes: expected 60, but found 126 entries.


System Information:

  • OS & Version: Windows 10
  • ML.NET Version: ML.NET v4.0.0
  • .NET Version: .NET 8.0

Description of the bug: var transformer = estimator.Fit(dataView); throws the following exception:

Mismatched state_dict sizes: expected 60, but found 126 entries.
   at TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
   at TorchSharp.torch.nn.Module.load(String location, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
   at Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase.CreateModule(IChannel ch, IDataView input)
   at Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.TrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
   at Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
   at Microsoft.ML.TorchSharp.NasBert.NerTrainer.Trainer..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
   at Microsoft.ML.TorchSharp.NasBert.NerTrainer.CreateTrainer(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
   at Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.Fit(IDataView input)
   at Microsoft.ML.Data.EstimatorChain`1.Fit(IDataView input)
   at Program.Main(String[] args) in C:\Users\pierc\source\repos\ML_NER_TEST\Program.cs:line 64
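To check whether mismatched package versions are involved (as later comments in this thread suspect), it can help to print which TorchSharp and ML.NET assemblies the process actually loaded. The following is a minimal sketch, not part of ML.NET or TorchSharp; the class name LoadedAssemblyReport and the name-prefix filter are hypothetical. It only reports assemblies that are already loaded, so it is most informative when called from the catch block after Fit has failed.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical diagnostic helper; not part of ML.NET or TorchSharp.
public static class LoadedAssemblyReport
{
    // Returns one line per loaded assembly whose simple name starts with any of
    // the given prefixes (case-insensitive), sorted ordinally.
    public static List<string> Describe(params string[] prefixes)
    {
        var lines = new List<string>();
        foreach (Assembly assembly in AppDomain.CurrentDomain.GetAssemblies())
        {
            AssemblyName name = assembly.GetName();
            if (name.Name == null ||
                !prefixes.Any(p => name.Name.StartsWith(p, StringComparison.OrdinalIgnoreCase)))
            {
                continue;
            }

            // The informational version often carries the NuGet package version,
            // which can differ from the four-part assembly version.
            string informational = assembly
                .GetCustomAttribute<AssemblyInformationalVersionAttribute>()?.InformationalVersion;

            lines.Add($"{name.Name} {name.Version} ({informational ?? "no informational version"})");
        }

        lines.Sort(StringComparer.Ordinal);
        return lines;
    }
}
```

For example, adding foreach (string line in LoadedAssemblyReport.Describe("TorchSharp", "Microsoft.ML")) Console.WriteLine(line); to the catch block would list the loaded TorchSharp and Microsoft.ML.* assemblies with both version numbers. The informational version is included because the four-part assembly version does not always match the NuGet package version.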

Sample Projects

using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp;

namespace ML_NER_TEST
{
    public class Program
    {
        public static void Main(string[] args)
        {
            try
            {
                var context = new MLContext()
                {
                    FallbackToCpu = true,
                    GpuDeviceId = 0
                };

                var labels = context.Data.LoadFromEnumerable(
                    [
                        new Label { Key = "PERSON" },       // People, including fictional.
                        new Label { Key = "NORP" },         // Nationalities or religious or political groups.
                        new Label { Key = "FAC" },          // Buildings, airports, highways, bridges, etc.
                        new Label { Key = "ORG" },          // Companies, agencies, institutions, etc.
                        new Label { Key = "GPE" },          // Countries, cities, states.
                        new Label { Key = "LOC" },          // Non-GPE locations, mountain ranges, bodies of water.
                        new Label { Key = "PRODUCT" },      // Objects, vehicles, foods, etc. (Not services.)
                        new Label { Key = "EVENT" },        // Named hurricanes, battles, wars, sports events, etc.
                        new Label { Key = "WORK_OF_ART" },  // Titles of books, songs, etc.
                        new Label { Key = "LAW" },          // Named documents made into laws.
                        new Label { Key = "LANGUAGE" },     // Any named language.
                        new Label { Key = "DATE" },         // Absolute or relative dates or periods.
                        new Label { Key = "TIME" },         // Times smaller than a day.
                        new Label { Key = "PERCENT" },      // Percentage, including "%".
                        new Label { Key = "MONEY" },        // Monetary values, including unit.
                        new Label { Key = "QUANTITY" },     // Measurements, as of weight or distance.
                        new Label { Key = "ORDINAL" },      // "first", "second", etc.
                        new Label { Key = "CARDINAL" },     // Numerals that do not fall under another type.
                        new Label { Key = "OBJECT" },       // A physical object, such as a spoon or a soccer ball. Needs subcategories.
                    ]);

                var dataView = context.Data.LoadFromEnumerable(
                    new List<InputTrainingData>([
                        new InputTrainingData()
                        {
                            // One label per space-separated word; "GPE" is used because "COUNTRY" is not in the label set above.
                            Sentence = "Alice and Bob live in the USA",
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "GPE"]
                        },
                        new InputTrainingData()
                        {
                            Sentence = "Frank and Alice traveled along the California coast.",
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "GPE", "0"]
                        },
                    ]));

                var chain = new EstimatorChain<ITransformer>();

                var estimator = chain.Append(context.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
                   .Append(context.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "Predictions"))
                   .Append(context.Transforms.Conversion.MapKeyToValue("Predictions"));

                Console.WriteLine("Training the model...");

                var transformer = estimator.Fit(dataView);

                Console.WriteLine("Model trained!");

                var transformerSchema = transformer.GetOutputSchema(dataView.Schema);

                string sentence = "Alice and Bob live in the USA";
                var engine = context.Model.CreatePredictionEngine<Input, Output>(transformer);

                Console.WriteLine("Predicting...");

                Output predictions = engine.Predict(new Input { Sentence = sentence });

                Console.WriteLine($"Predictions: {sentence} - {string.Join(", ", predictions.Predictions)}");

                transformer.Dispose();
                Console.WriteLine("Success!");
                Console.ReadLine();
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Error: {ex.Message}");
                Console.ReadLine();
            }
        }
        private class Input
        {
            public string Sentence;
            public string[] Label;
        }
        private class Output
        {
            public string[] Predictions;
        }
        public class Label
        {
            public string Key { get; set; }
        }
        private class InputTrainingData
        {
            public string Sentence;
            public string[] Label;
        }
    }
}

Additional context

<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>disable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="libtorch-cpu-win-x64" Version="2.5.1" />
    <PackageReference Include="Microsoft.ML" Version="4.0.0" />
    <PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.0" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.22.0" />
    <PackageReference Include="TorchSharp" Version="0.105.0" />
  </ItemGroup>

</Project>

piercarlo62, Dec 22 '24 08:12

I'm getting different results from the command-line tool (installed with dotnet tool install --global mlnet-win-x64). The Visual Studio extension ML.NET Model Builder v17.18.0 appears to use something similar to the mlnet tool shown below. However, a C# project like the one above, which may use different libraries and model architectures, also gives me a Mismatched state_dict exception. The log below hints at which libraries the mlnet tool uses. I wonder whether part of this issue is that the tooling is one release behind.

C:\Users\xxxx>mlnet text-classification --dataset "text-code.txt" --label-col 1 --text-col 0 --has-header true
Start Training
start text classification
env:path: C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5; [snip]
restore "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj" --configfile "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\NuGet.config" -r win-x64 /p:UsingToolXliff=false /p:TorchSharpVersion=0.101.5 /p:TorchSharpCudaRuntimeVersion=2.1.0.1 /p:TensorflowRuntimeVersion=2.3.1 /p:BaseIntermediateOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\obj"
publish "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj" -r win-x64 -c Release --no-self-contained -o "C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5" --no-restore /p:UsingToolXliff=false /p:TorchSharpVersion=0.101.5 /p:TorchSharpCudaRuntimeVersion=2.1.0.1 /p:TensorflowRuntimeVersion=2.3.1 /p:BaseOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\bin\" /p:BaseIntermediateOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\obj\"
start installing runtime in C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5
Determining projects to restore...
Restored C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj (in 1.5 min).
torchsharp.cpu -> C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\bin\Release\netstandard2.0\win-x64\torchsharp.cpu.dll
torchsharp.cpu -> C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\
install runtime successfully
Use train validate split with ratio: 0.1
[Source=AutoMLExperiment-ChildContext, Kind=Trace]
[Source=TorchSharpBaseTrainer; TrainModel, Kind=Trace] Starting epoch 0

dha125, Feb 07 '25 19:02

For what it's worth, the mlnet text-classification ... command-line tool generates a SampleTextClassification project with the csproj file shown below.

Edit: Rolled back to TorchSharp-cpu Version="0.99.6" and it's working well for me now.

<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>
  <ItemGroup>
    <PackageReference Include="Microsoft.ML" Version="3.0.1" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.21.0" />
    <PackageReference Include="TorchSharp-cpu" Version="0.101.5" />
  </ItemGroup>
  <ItemGroup Label="SampleTextClassification">
    <None Include="SampleTextClassification.mlnet">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
  </ItemGroup>
</Project>

dha125, Feb 08 '25 16:02
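The failing project and the project generated by the mlnet tool reference different packages and versions, and a small comparison makes those differences explicit. The sketch below assumes SDK-style project files that put Include and Version attributes on PackageReference, as both project files in this thread do; CsprojPackages, Read and Diff are hypothetical names, and versions set elsewhere (a child Version element, Directory.Packages.props) are not handled.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;

// Hypothetical helper for comparing NuGet package versions between two project files.
public static class CsprojPackages
{
    // Maps each PackageReference's Include attribute to its Version attribute
    // ("" when Version is absent). Elements are matched by local name, so an
    // MSBuild XML namespace does not matter; a later duplicate overwrites an earlier one.
    public static Dictionary<string, string> Read(string csprojXml)
    {
        var packages = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        foreach (XElement element in XDocument.Parse(csprojXml).Descendants()
                     .Where(e => e.Name.LocalName == "PackageReference"))
        {
            string id = (string)element.Attribute("Include");
            if (id != null)
            {
                packages[id] = (string)element.Attribute("Version") ?? "";
            }
        }
        return packages;
    }

    // Lists every package whose version differs between the two projects or that
    // appears in only one of them, sorted case-insensitively by package id.
    public static List<string> Diff(string leftXml, string rightXml)
    {
        Dictionary<string, string> left = Read(leftXml);
        Dictionary<string, string> right = Read(rightXml);
        var lines = new List<string>();
        foreach (string id in left.Keys.Union(right.Keys, StringComparer.OrdinalIgnoreCase)
                     .OrderBy(k => k, StringComparer.OrdinalIgnoreCase))
        {
            bool inLeft = left.TryGetValue(id, out string leftVersion);
            bool inRight = right.TryGetValue(id, out string rightVersion);
            if (inLeft && inRight && leftVersion == rightVersion)
            {
                continue;
            }
            lines.Add($"{id}: {(inLeft ? leftVersion : "(missing)")} vs {(inRight ? rightVersion : "(missing)")}");
        }
        return lines;
    }
}
```

Applied to the two project files in this thread, Diff would report libtorch-cpu-win-x64, Microsoft.ML (4.0.0 vs 3.0.1), Microsoft.ML.Tokenizers, Microsoft.ML.TorchSharp (0.22.0 vs 0.21.0), TorchSharp and TorchSharp-cpu. Because the comparison is keyed on package id, the TorchSharp plus libtorch-cpu-win-x64 pair and the single TorchSharp-cpu package appear as unrelated entries, so the TorchSharp versions (0.105.0 vs 0.101.5) still have to be compared by eye.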