MotionMatching

Load time improvement

Open jhughes2112 opened this issue 10 months ago • 3 comments

We have a 400 MB pose set, and it takes about 11 seconds to load, which has made iteration noticeably slower. I tried using memory-mapped files, but in Unity that requires permissions that could unduly trip up developers. Instead, I made a small change to load the whole file into memory and then construct the classes from RAM. That brought load time down to 7.8 seconds with minimal code changes. Note that because AnimationClip is a struct and clips are added to the PoseSet one at a time, every time the clip list grows it has to copy all of the existing clips again. A slightly more invasive change would help dramatically: set the list's capacity once nClips has been read.

` public bool Deserialize(string path, string fileName, MotionMatchingData mmData, out PoseSet poseSet) { poseSet = new PoseSet(mmData);
		// Reads "<fileName>.mmskeleton" and then "<fileName>.mmpose" from `path` into a new PoseSet.
		// Returns false if either file is missing; poseSet is assigned before any check either way.

		// --------------------
		// Read Skeleton File
		// --------------------
		string skeletonPath = Path.Combine(path, fileName + ".mmskeleton");
		if (!File.Exists(skeletonPath))
			return false;

		Skeleton skeleton = new Skeleton();
        PROFILE.BEGIN_SAMPLE_PROFILING("Skeleton File Read");
		// Load the whole file in one call, then parse it from memory instead of from disk.
		byte[] skeletonData = File.ReadAllBytes(skeletonPath);
		PROFILE.END_AND_PRINT_SAMPLE_PROFILING("Skeleton File Read");
		using (var ms = new MemoryStream(skeletonData))
		{
			using (var reader = new BinaryReader(ms, Encoding.UTF8))
			{
				// Layout: uint32 joint count, then one record per joint:
				// name (BinaryReader length-prefixed string), index, parent index, local offset (via ReadFloat3), HumanBodyBones type.
				uint nJoints = reader.ReadUInt32();
				for (int i = 0; i < nJoints; i++)
				{
					string jointName = reader.ReadString();
					uint jointIndex = reader.ReadUInt32();
					uint jointParentIndex = reader.ReadUInt32();
					float3 jointLocalOffset = ReadFloat3(reader);
					HumanBodyBones jointType = (HumanBodyBones)reader.ReadUInt32();
					skeleton.AddJoint(new Skeleton.Joint(jointName, (int)jointIndex, (int)jointParentIndex, jointLocalOffset, jointType));
				}
			}
		}
		poseSet.SetSkeletonFromFile(skeleton);

		// --------------------
		// Read Pose File
		// --------------------
		string posePath = Path.Combine(path, fileName + ".mmpose");
		if (!File.Exists(posePath))
			return false;

        PROFILE.BEGIN_SAMPLE_PROFILING("Pose File Read");
		byte[] poseData = File.ReadAllBytes(posePath);
        PROFILE.END_AND_PRINT_SAMPLE_PROFILING("Pose File Read");
		using (var ms = new MemoryStream(poseData))
		{
			using (var reader = new BinaryReader(ms, Encoding.UTF8))
			{
				// Clips: uint32 count, then (start, end, frameTime) per clip.
				// NOTE(review): the clip collection is not presized here; it presumably grows as clips are added — confirm in PoseSet.
				uint nClips = reader.ReadUInt32();
				for (int i = 0; i < nClips; i++)
				{
					uint start = reader.ReadUInt32();
					uint end = reader.ReadUInt32();
					float frameTime = reader.ReadSingle();
					poseSet.AddAnimationClipDeserialized(new PoseSet.AnimationClip((int)start, (int)end, frameTime));
				}

				// Header: total pose count, joints per pose, tag count.
				uint nPoses = reader.ReadUInt32();
				uint nJoints = reader.ReadUInt32();
				uint nTags = reader.ReadUInt32();
				// The pose file's joint count is expected to match the skeleton file.
				Debug.Assert(nJoints == skeleton.Joints.Count, "Number of joints in skeleton and pose do not match");

				// Per pose: positions, rotations, velocities and angular velocities (nJoints entries each),
				// then left/right foot-contact flags stored as uint32 where only 1 means "in contact".
				// All poses are gathered into one array and handed to the PoseSet in a single call.
				PoseVector[] poses = new PoseVector[nPoses];
				for (int i = 0; i < nPoses; i++)
				{
					PoseVector pose = new PoseVector();
					pose.JointLocalPositions = ReadFloat3Array(reader, nJoints);
					pose.JointLocalRotations = ReadQuaternionArray(reader, nJoints);
					pose.JointLocalVelocities = ReadFloat3Array(reader, nJoints);
					pose.JointLocalAngularVelocities = ReadFloat3Array(reader, nJoints);
					pose.LeftFootContact = reader.ReadUInt32() == 1u;
					pose.RightFootContact = reader.ReadUInt32() == 1u;
					poses[i] = pose;
				}
				poseSet.AddClipDeserialized(poses);

				// Tags: name, uint32 range count, then (start, end) pairs.
				for (int i = 0; i < nTags; i++)
				{
					string name = reader.ReadString();
					int nRanges = (int)reader.ReadUInt32();
					List<int> tagStarts = new List<int>(nRanges);
					List<int> tagEnds = new List<int>(nRanges);
					for (int r = 0; r < nRanges; r++)
					{
						tagStarts.Add((int)reader.ReadUInt32());
						tagEnds.Add((int)reader.ReadUInt32());
					}
					poseSet.AddTagDeserialized(name, tagStarts, tagEnds);
				}
				poseSet.ConvertTagsToNativeArrays();
			}
		}
		return true;
	}

`

jhughes2112 avatar Mar 27 '25 17:03 jhughes2112

Of course, after a little further investigation, it turns out that 99.9% of the remaining time is spent simply parsing the pose data, which makes sense: there are a lot of memory allocations happening there.

jhughes2112 avatar Mar 27 '25 17:03 jhughes2112

This revision brings load times WAY down: 1.6 seconds on my 400 MB data set.

First, make this small change to PoseSet.cs:

        /// <summary>
        /// Appends one deserialized <see cref="PoseVector"/> to <c>Poses</c>.
        /// Despite the method name, this adds a single pose, not a whole clip.
        /// </summary>
        /// <param name="pose">The pose to append.</param>
        public void AddClip(PoseVector pose) => Poses.Add(pose);
		/// <summary>
		/// Pre-sizes <c>Poses</c> so that appending <paramref name="numPoses"/> poses
		/// does not trigger repeated reallocation and copying.
		/// </summary>
		/// <param name="numPoses">Expected total pose count, typically read from a file header.</param>
		/// <exception cref="System.ArgumentOutOfRangeException">
		/// <paramref name="numPoses"/> is larger than <see cref="int.MaxValue"/>.
		/// </exception>
		public void SetPoseCapacity(uint numPoses)
		{
			// An unchecked (int) cast would wrap large header values to a negative capacity.
			if (numPoses > int.MaxValue)
				throw new System.ArgumentOutOfRangeException(nameof(numPoses), numPoses, "Pose count exceeds the maximum list capacity.");
			// List<T>.Capacity throws when set below Count, so never go below the poses already stored.
			Poses.Capacity = System.Math.Max(Poses.Count, (int)numPoses);
		}
		/// <summary>
		/// Pre-sizes <c>Clips</c> so that appending <paramref name="count"/> clips
		/// does not repeatedly regrow and copy the list.
		/// </summary>
		/// <param name="count">Expected total clip count, typically read from a file header.</param>
		/// <exception cref="System.ArgumentOutOfRangeException">
		/// <paramref name="count"/> is larger than <see cref="int.MaxValue"/>.
		/// </exception>
		public void SetClipCapacity(uint count)
		{
			// An unchecked (int) cast would wrap large header values to a negative capacity.
			if (count > int.MaxValue)
				throw new System.ArgumentOutOfRangeException(nameof(count), count, "Clip count exceeds the maximum list capacity.");
			// List<T>.Capacity throws when set below Count, so never go below the clips already stored.
			Clips.Capacity = System.Math.Max(Clips.Count, (int)count);
		}

Then replace your Deserialize function with this optimized version:

		/// <summary>
		/// Loads a pose set from <c>fileName.mmskeleton</c> and <c>fileName.mmpose</c> inside <paramref name="path"/>.
		/// Each file is read into memory in one call and then parsed from a <see cref="MemoryStream"/>.
		/// </summary>
		/// <param name="path">Directory that contains both files.</param>
		/// <param name="fileName">Base file name (no extension) shared by the skeleton and pose files.</param>
		/// <param name="mmData">Forwarded to the <see cref="PoseSet"/> constructor.</param>
		/// <param name="poseSet">Always assigned; only fully populated when the method returns <c>true</c>.</param>
		/// <returns><c>false</c> if either file is missing; otherwise <c>true</c>.</returns>
		/// <exception cref="EndOfStreamException">The file data ends before all declared values have been read.</exception>
		public bool Deserialize(string path, string fileName, MotionMatchingData mmData, out PoseSet poseSet)
		{
			poseSet = new PoseSet(mmData);

			// --------------------
			// Read Skeleton File
			// --------------------
			string skeletonPath = Path.Combine(path, fileName + ".mmskeleton");
			if (!File.Exists(skeletonPath))
				return false;

			Skeleton skeleton = new Skeleton();
			PROFILE.BEGIN_SAMPLE_PROFILING("Skeleton File Read");
			byte[] skeletonData = File.ReadAllBytes(skeletonPath);
			PROFILE.END_AND_PRINT_SAMPLE_PROFILING("Skeleton File Read");
			using (var ms = new MemoryStream(skeletonData))
			using (var reader = new BinaryReader(ms, Encoding.UTF8))
			{
				uint nJoints = reader.ReadUInt32();
				skeleton.Joints.Capacity = (int)nJoints;
				for (int i = 0; i < nJoints; i++)
				{
					// Joint record: name, index, parent index, local offset, HumanBodyBones type.
					string jointName = reader.ReadString();
					uint jointIndex = reader.ReadUInt32();
					uint jointParentIndex = reader.ReadUInt32();
					float3 jointLocalOffset = ReadFloat3(reader);
					HumanBodyBones jointType = (HumanBodyBones)reader.ReadUInt32();
					skeleton.AddJoint(new Skeleton.Joint(jointName, (int)jointIndex, (int)jointParentIndex, jointLocalOffset, jointType));
				}
			}
			poseSet.SetSkeletonFromFile(skeleton);

			// --------------------
			// Read Pose File
			// --------------------
			string posePath = Path.Combine(path, fileName + ".mmpose");
			if (!File.Exists(posePath))
				return false;

			PROFILE.BEGIN_SAMPLE_PROFILING("Pose File Read");
			byte[] poseData = File.ReadAllBytes(posePath);
			PROFILE.END_AND_PRINT_SAMPLE_PROFILING("Pose File Read");
			using (var ms = new MemoryStream(poseData))
			using (var reader = new BinaryReader(ms, Encoding.UTF8))
			{
				// Clips: count, then (start, end, frameTime) per clip. Presize so the list never regrows.
				uint nClips = reader.ReadUInt32();
				poseSet.SetClipCapacity(nClips);
				for (int i = 0; i < nClips; i++)
				{
					uint start = reader.ReadUInt32();
					uint end = reader.ReadUInt32();
					float frameTime = reader.ReadSingle();
					poseSet.AddAnimationClipDeserialized(new PoseSet.AnimationClip((int)start, (int)end, frameTime));
				}

				uint nPoses = reader.ReadUInt32();
				uint nJoints = reader.ReadUInt32();
				uint nTags = reader.ReadUInt32();
				Debug.Assert(nJoints == skeleton.Joints.Count, "Number of joints in skeleton and pose do not match");

				// One scratch buffer, reused for every joint array of every pose. It is sized for the largest
				// record (4 floats per joint for rotations); float3 arrays only use a prefix of it.
				int jointCount = (int)nJoints;
				byte[] scratch = new byte[jointCount * 4 * sizeof(float)];

				PROFILE.BEGIN_SAMPLE_PROFILING("Parse Poses");
				poseSet.SetPoseCapacity(nPoses);
				for (int i = 0; i < nPoses; i++)
				{
					// Read order must match the file layout exactly.
					PoseVector pose = new PoseVector();
					pose.JointLocalPositions = ReadFloat3ArrayBuffered(reader, scratch, jointCount);
					pose.JointLocalRotations = ReadQuaternionArrayBuffered(reader, scratch, jointCount);
					pose.JointLocalVelocities = ReadFloat3ArrayBuffered(reader, scratch, jointCount);
					pose.JointLocalAngularVelocities = ReadFloat3ArrayBuffered(reader, scratch, jointCount);
					// Foot contacts are stored as uint32; only the value 1 means "in contact".
					pose.LeftFootContact = reader.ReadUInt32() == 1u;
					pose.RightFootContact = reader.ReadUInt32() == 1u;
					poseSet.AddClip(pose);
				}
				PROFILE.END_AND_PRINT_SAMPLE_PROFILING("Parse Poses");

				// Tags: name, range count, then (start, end) pairs.
				for (int i = 0; i < nTags; i++)
				{
					string name = reader.ReadString();
					int nRanges = (int)reader.ReadUInt32();
					List<int> tagStarts = new List<int>(nRanges);
					List<int> tagEnds = new List<int>(nRanges);
					for (int r = 0; r < nRanges; r++)
					{
						tagStarts.Add((int)reader.ReadUInt32());
						tagEnds.Add((int)reader.ReadUInt32());
					}
					poseSet.AddTagDeserialized(name, tagStarts, tagEnds);
				}
				poseSet.ConvertTagsToNativeArrays();
			}
			return true;
		}

		// Fills the first byteCount bytes of buffer from reader. Throws on truncated data rather than
		// silently leaving bytes from the previous read in the buffer.
		private static void FillExactly(BinaryReader reader, byte[] buffer, int byteCount)
		{
			int offset = 0;
			while (offset < byteCount)
			{
				int read = reader.Read(buffer, offset, byteCount - offset);
				if (read == 0)
					throw new EndOfStreamException("Unexpected end of pose data while reading joint arrays.");
				offset += read;
			}
		}

		// Reads `count` float3 values (3 floats each) from reader, staging the raw bytes in `scratch`.
		private static float3[] ReadFloat3ArrayBuffered(BinaryReader reader, byte[] scratch, int count)
		{
			int byteCount = count * 3 * sizeof(float);
			FillExactly(reader, scratch, byteCount);
			// MemoryMarshal.Cast reinterprets bytes in host byte order, whereas BinaryReader always decodes
			// little-endian. NOTE(review): this assumes a little-endian host — confirm for every target platform.
			ReadOnlySpan<float> floats = MemoryMarshal.Cast<byte, float>(scratch.AsSpan(0, byteCount));
			float3[] result = new float3[count];
			for (int j = 0; j < count; j++)
			{
				int k = j * 3;
				result[j] = new float3(floats[k], floats[k + 1], floats[k + 2]);
			}
			return result;
		}

		// Reads `count` quaternions (4 floats each, passed to the constructor in file order),
		// staging the raw bytes in `scratch`.
		private static quaternion[] ReadQuaternionArrayBuffered(BinaryReader reader, byte[] scratch, int count)
		{
			int byteCount = count * 4 * sizeof(float);
			FillExactly(reader, scratch, byteCount);
			// Same little-endian host assumption as ReadFloat3ArrayBuffered.
			ReadOnlySpan<float> floats = MemoryMarshal.Cast<byte, float>(scratch.AsSpan(0, byteCount));
			quaternion[] result = new quaternion[count];
			for (int j = 0; j < count; j++)
			{
				int k = j * 4;
				result[j] = new quaternion(floats[k], floats[k + 1], floats[k + 2], floats[k + 3]);
			}
			return result;
		}

jhughes2112 avatar Mar 27 '25 17:03 jhughes2112

Thank you very much for the improvements!! :D I don't have much time these days, but I will integrate the changes as soon as possible!

JLPM22 avatar Apr 02 '25 10:04 JLPM22