/**
 * Unit tests for the Tensor class: construction, factory helpers,
 * shape/dtype operations, element access, byte sizing, serialization
 * round-trips, and error handling on invalid access.
 */
import { describe, it, expect } from 'vitest';
import { Tensor } from '../tensor';

describe('Tensor', () => {
  describe('creation', () => {
    it('should create tensor from nested array', () => {
      const t = Tensor.from([[1, 2, 3], [4, 5, 6]]);
      expect(t.shape).toEqual([2, 3]);
      expect(t.size).toBe(6);
      expect(t.ndim).toBe(2);
    });

    it('should create tensor from flat array', () => {
      const t = Tensor.from([1, 2, 3, 4]);
      expect(t.shape).toEqual([4]);
      expect(t.size).toBe(4);
      expect(t.ndim).toBe(1);
    });

    it('should create tensor from Float32Array', () => {
      const data = new Float32Array([1.0, 2.0, 3.0]);
      const t = Tensor.from(data);
      expect(t.dtype).toBe('fp32');
      expect(t.size).toBe(3);
    });

    it('should create tensor from Float64Array', () => {
      const data = new Float64Array([1.0, 2.0, 3.0]);
      const t = Tensor.from(data);
      expect(t.dtype).toBe('fp64');
    });
  });

  describe('factory methods', () => {
    it('should create zeros tensor', () => {
      const t = Tensor.zeros([2, 3]);
      expect(t.shape).toEqual([2, 3]);
      expect(t.size).toBe(6);
      expect(Array.from(t.data).every((v) => v === 0)).toBe(true);
    });

    it('should create ones tensor', () => {
      const t = Tensor.ones([3, 2]);
      expect(t.shape).toEqual([3, 2]);
      expect(Array.from(t.data).every((v) => v === 1)).toBe(true);
    });

    it('should create random tensor', () => {
      const t = Tensor.random([10, 10]);
      expect(t.shape).toEqual([10, 10]);
      expect(t.size).toBe(100);
      // Values should be in [0, 1)
      expect(Array.from(t.data).every((v) => v >= 0 && v < 1)).toBe(true);
    });

    it('should create randn tensor', () => {
      const t = Tensor.randn([100]);
      expect(t.shape).toEqual([100]);
      // Check approximate normal distribution properties.
      const values = Array.from(t.data);
      // Seed the sum with 0 so reduce is well-defined for any length.
      const mean = values.reduce((a, b) => a + b, 0) / values.length;
      // Assuming randn samples N(0, 1), the standard error of the mean of
      // 100 samples is 0.1, so 0.5 is a loose (~5 sigma) bound.
      expect(Math.abs(mean)).toBeLessThan(0.5); // Should be close to 0
    });

    it('should support different dtypes', () => {
      expect(Tensor.zeros([2, 2], 'fp64').dtype).toBe('fp64');
      expect(Tensor.zeros([2, 2], 'fp32').dtype).toBe('fp32');
      expect(Tensor.zeros([2, 2], 'int8').dtype).toBe('int8');
    });
  });

  describe('operations', () => {
    it('should reshape tensor', () => {
      const t = Tensor.from([1, 2, 3, 4, 5, 6]);
      const reshaped = t.reshape([2, 3]);
      expect(reshaped.shape).toEqual([2, 3]);
      expect(reshaped.size).toBe(6);
      // Reshape must keep element order, not just the element count.
      expect(reshaped.toArray()).toEqual([[1, 2, 3], [4, 5, 6]]);
    });

    it('should throw on invalid reshape', () => {
      const t = Tensor.from([1, 2, 3, 4]);
      expect(() => t.reshape([2, 3])).toThrow();
    });

    it('should convert dtype', () => {
      const t = Tensor.from([1.5, 2.5, 3.5], 'fp32');
      const converted = t.to('fp64');
      expect(converted.dtype).toBe('fp64');
      expect(converted.data).toBeInstanceOf(Float64Array);
      // 1.5, 2.5 and 3.5 are exactly representable in fp32, so conversion
      // to fp64 must reproduce them exactly.
      expect(Array.from(converted.data)).toEqual([1.5, 2.5, 3.5]);
    });

    it('should get element at index', () => {
      const t = Tensor.from([[1, 2, 3], [4, 5, 6]]);
      expect(t.get(0, 0)).toBe(1);
      expect(t.get(0, 2)).toBe(3);
      expect(t.get(1, 1)).toBe(5);
    });

    it('should set element at index', () => {
      const t = Tensor.from([[1, 2], [3, 4]]);
      t.set(99, 0, 1);
      expect(t.get(0, 1)).toBe(99);
    });

    it('should convert to array', () => {
      const t = Tensor.from([[1, 2], [3, 4]]);
      const arr = t.toArray();
      expect(arr).toEqual([[1, 2], [3, 4]]);
    });
  });

  describe('properties', () => {
    it('should calculate byteSize correctly', () => {
      const fp32 = Tensor.zeros([10], 'fp32');
      expect(fp32.byteSize).toBe(40); // 10 * 4 bytes

      const fp64 = Tensor.zeros([10], 'fp64');
      expect(fp64.byteSize).toBe(80); // 10 * 8 bytes
    });
  });

  describe('serialization', () => {
    it('should serialize and deserialize tensor', () => {
      const original = Tensor.from([[1, 2, 3], [4, 5, 6]]);
      const serialized = original.serialize();
      expect(serialized).toHaveProperty('data');
      expect(serialized).toHaveProperty('shape');
      expect(serialized).toHaveProperty('dtype');
      expect(serialized.shape).toEqual([2, 3]);

      const restored = Tensor.deserialize(serialized);
      expect(restored.shape).toEqual(original.shape);
      expect(restored.dtype).toBe(original.dtype);
      expect(Array.from(restored.data)).toEqual(Array.from(original.data));
    });

    it('should preserve dtype during serialization', () => {
      const fp64Tensor = Tensor.zeros([3], 'fp64');
      // A value with more significant digits than fp32 can hold, so a lossy
      // fp32 round-trip would be detected below.
      fp64Tensor.set(1.234567890123, 0);
      const serialized = fp64Tensor.serialize();
      const restored = Tensor.deserialize(serialized);
      expect(restored.dtype).toBe('fp64');
      expect(restored.get(0)).toBe(1.234567890123);
    });
  });

  describe('edge cases', () => {
    it('should handle scalar-like tensor', () => {
      const t = Tensor.from([42]);
      expect(t.shape).toEqual([1]);
      expect(t.size).toBe(1);
    });

    it('should handle large tensors', () => {
      const t = Tensor.zeros([1000, 1000]);
      expect(t.size).toBe(1000000);
    });

    it('should throw on out-of-bounds access', () => {
      const t = Tensor.from([1, 2, 3]);
      expect(() => t.get(5)).toThrow();
    });

    it('should throw on wrong number of indices', () => {
      const t = Tensor.from([[1, 2], [3, 4]]);
      expect(() => t.get(0)).toThrow();
    });
  });
});