// Unit tests for the synor_compute `Tensor` type: construction, factory
// methods, shape operations, reductions, element-wise arithmetic,
// activations, JSON serialization, and value semantics.
//
// NOTE(review): the error-path tests assume the library reports bad
// arguments with `ArgumentError` or a subclass of it (`RangeError`, and
// `IndexError` from typed-list bounds checks, both match). A bare `isA()`
// infers `isA<dynamic>()` and matches any thrown object, so it cannot tell
// a validation error from an unrelated crash. Confirm these error types
// against the Tensor implementation.
import 'dart:typed_data';

import 'package:flutter_test/flutter_test.dart';
import 'package:synor_compute/synor_compute.dart';

void main() {
  group('Tensor Creation', () {
    test('creates tensor with shape and data', () {
      final tensor = Tensor(
        shape: [2, 3],
        data: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
      );

      expect(tensor.shape, equals([2, 3]));
      expect(tensor.size, equals(6));
      expect(tensor.ndim, equals(2));
    });

    test('throws on data-shape mismatch', () {
      expect(
        () => Tensor(
          shape: [2, 3],
          data: [1.0, 2.0, 3.0], // Only 3 elements for shape [2, 3]
        ),
        throwsA(isA<ArgumentError>()),
      );
    });

    test('creates tensor from typed data', () {
      final data = Float64List.fromList([1.0, 2.0, 3.0, 4.0]);
      final tensor = Tensor.fromTypedData(
        shape: [2, 2],
        data: data,
      );

      expect(tensor.shape, equals([2, 2]));
      expect(tensor.size, equals(4));
    });
  });

  group('Tensor Factory Methods', () {
    test('creates zeros tensor', () {
      final tensor = Tensor.zeros([3, 4]);

      expect(tensor.shape, equals([3, 4]));
      expect(tensor.size, equals(12));
      expect(tensor.data.every((v) => v == 0.0), isTrue);
    });

    test('creates ones tensor', () {
      final tensor = Tensor.ones([2, 2]);

      expect(tensor.shape, equals([2, 2]));
      expect(tensor.data.every((v) => v == 1.0), isTrue);
    });

    test('creates full tensor with value', () {
      final tensor = Tensor.full([3, 3], 5.0);

      expect(tensor.data.every((v) => v == 5.0), isTrue);
    });

    test('creates random tensor', () {
      final tensor = Tensor.rand([10, 10]);

      expect(tensor.shape, equals([10, 10]));
      expect(tensor.size, equals(100));
      // Values should be in [0, 1)
      expect(tensor.data.every((v) => v >= 0 && v < 1), isTrue);
    });

    test('creates randn tensor with normal distribution', () {
      // Unseeded, so the tolerances must absorb sampling noise. For 1000
      // standard-normal samples, the sample mean has a standard error of
      // about 0.032, so 0.2 is more than 6 sigma away.
      final tensor = Tensor.randn([1000]);

      // Mean should be close to 0, std close to 1
      expect(tensor.mean().abs(), lessThan(0.2));
      expect(tensor.std(), closeTo(1.0, 0.2));
    });

    test('creates identity matrix', () {
      final tensor = Tensor.eye(3);

      expect(tensor.shape, equals([3, 3]));
      // Check every element: ones on the diagonal, zeros everywhere else.
      for (var i = 0; i < 3; i++) {
        for (var j = 0; j < 3; j++) {
          expect(tensor.at([i, j]), equals(i == j ? 1.0 : 0.0));
        }
      }
    });

    test('creates linspace tensor', () {
      final tensor = Tensor.linspace(0.0, 10.0, 11);

      expect(tensor.shape, equals([11]));
      expect(tensor[0], equals(0.0));
      expect(tensor[10], equals(10.0));
      expect(tensor[5], equals(5.0));
    });

    test('creates arange tensor', () {
      final tensor = Tensor.arange(0.0, 5.0, step: 1.0);

      expect(tensor.shape, equals([5]));
      expect(tensor[0], equals(0.0));
      expect(tensor[4], equals(4.0));
    });
  });

  group('Tensor Operations', () {
    test('reshapes tensor', () {
      final tensor = Tensor(
        shape: [6],
        data: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
      );
      final reshaped = tensor.reshape([2, 3]);

      expect(reshaped.shape, equals([2, 3]));
      expect(reshaped.size, equals(6));
    });

    test('throws on invalid reshape', () {
      final tensor = Tensor.zeros([4]);

      // 4 elements cannot be viewed as 2 x 3 = 6 elements.
      expect(
        () => tensor.reshape([2, 3]),
        throwsA(isA<ArgumentError>()),
      );
    });

    test('flattens tensor', () {
      final tensor = Tensor.zeros([2, 3, 4]);
      final flat = tensor.flatten();

      expect(flat.shape, equals([24]));
    });

    test('transposes 2D tensor', () {
      final tensor = Tensor(
        shape: [2, 3],
        data: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
      );
      final transposed = tensor.transpose();

      expect(transposed.shape, equals([3, 2]));
      expect(transposed.at([0, 0]), equals(1.0));
      expect(transposed.at([0, 1]), equals(4.0));
    });

    test('accesses element at index for 1D tensor', () {
      final tensor = Tensor(shape: [4], data: [1.0, 2.0, 3.0, 4.0]);

      expect(tensor[0], equals(1.0));
      expect(tensor[3], equals(4.0));
    });

    test('accesses element at multi-dimensional index', () {
      final tensor = Tensor(
        shape: [2, 3],
        data: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
      );

      expect(tensor.at([0, 0]), equals(1.0));
      expect(tensor.at([0, 2]), equals(3.0));
      expect(tensor.at([1, 1]), equals(5.0));
    });

    test('throws on out-of-bounds index', () {
      final tensor = Tensor.zeros([2, 3]);

      // RangeError and IndexError are both ArgumentErrors in dart:core.
      expect(
        () => tensor.at([2, 0]),
        throwsA(isA<ArgumentError>()),
      );
    });
  });

  group('Tensor Reductions', () {
    test('calculates sum', () {
      final tensor = Tensor(shape: [4], data: [1.0, 2.0, 3.0, 4.0]);

      expect(tensor.sum(), equals(10.0));
    });

    test('calculates mean', () {
      final tensor = Tensor(shape: [4], data: [1.0, 2.0, 3.0, 4.0]);

      expect(tensor.mean(), equals(2.5));
    });

    test('calculates std', () {
      final tensor = Tensor(shape: [4], data: [1.0, 2.0, 3.0, 4.0]);

      // Population std (divide by n): sqrt((2.25 + 0.25 + 0.25 + 2.25) / 4)
      // = sqrt(1.25) ≈ 1.118. The sample std (n - 1) would be ≈ 1.291.
      expect(tensor.std(), closeTo(1.118, 0.001));
    });

    test('finds min', () {
      final tensor = Tensor(shape: [4], data: [3.0, 1.0, 4.0, 2.0]);

      expect(tensor.min(), equals(1.0));
    });

    test('finds max', () {
      final tensor = Tensor(shape: [4], data: [3.0, 1.0, 4.0, 2.0]);

      expect(tensor.max(), equals(4.0));
    });

    test('finds argmin', () {
      final tensor = Tensor(shape: [4], data: [3.0, 1.0, 4.0, 2.0]);

      expect(tensor.argmin(), equals(1));
    });

    test('finds argmax', () {
      final tensor = Tensor(shape: [4], data: [3.0, 1.0, 4.0, 2.0]);

      expect(tensor.argmax(), equals(2));
    });
  });

  group('Tensor Element-wise Operations', () {
    test('adds tensors', () {
      final a = Tensor(shape: [3], data: [1.0, 2.0, 3.0]);
      final b = Tensor(shape: [3], data: [4.0, 5.0, 6.0]);
      final result = a.add(b);

      expect(result.data, equals(Float64List.fromList([5.0, 7.0, 9.0])));
    });

    test('subtracts tensors', () {
      final a = Tensor(shape: [3], data: [5.0, 7.0, 9.0]);
      final b = Tensor(shape: [3], data: [1.0, 2.0, 3.0]);
      final result = a.sub(b);

      expect(result.data, equals(Float64List.fromList([4.0, 5.0, 6.0])));
    });

    test('multiplies tensors', () {
      final a = Tensor(shape: [3], data: [2.0, 3.0, 4.0]);
      final b = Tensor(shape: [3], data: [1.0, 2.0, 3.0]);
      final result = a.mul(b);

      expect(result.data, equals(Float64List.fromList([2.0, 6.0, 12.0])));
    });

    test('divides tensors', () {
      final a = Tensor(shape: [3], data: [6.0, 8.0, 9.0]);
      final b = Tensor(shape: [3], data: [2.0, 4.0, 3.0]);
      final result = a.div(b);

      expect(result.data, equals(Float64List.fromList([3.0, 2.0, 3.0])));
    });

    test('adds scalar', () {
      final tensor = Tensor(shape: [3], data: [1.0, 2.0, 3.0]);
      final result = tensor.addScalar(10.0);

      expect(result.data, equals(Float64List.fromList([11.0, 12.0, 13.0])));
    });

    test('multiplies by scalar', () {
      final tensor = Tensor(shape: [3], data: [1.0, 2.0, 3.0]);
      final result = tensor.mulScalar(2.0);

      expect(result.data, equals(Float64List.fromList([2.0, 4.0, 6.0])));
    });

    test('throws on shape mismatch', () {
      // Same element count (6) but different shapes, so a size-only check
      // would not catch this.
      final a = Tensor.zeros([2, 3]);
      final b = Tensor.zeros([3, 2]);

      expect(() => a.add(b), throwsA(isA<ArgumentError>()));
    });
  });

  group('Tensor Activations', () {
    test('applies relu', () {
      final tensor = Tensor(shape: [5], data: [-2.0, -1.0, 0.0, 1.0, 2.0]);
      final result = tensor.relu();

      expect(
        result.data,
        equals(Float64List.fromList([0.0, 0.0, 0.0, 1.0, 2.0])),
      );
    });

    test('applies sigmoid', () {
      final tensor = Tensor(shape: [1], data: [0.0]);
      final result = tensor.sigmoid();

      expect(result[0], closeTo(0.5, 0.001));
    });

    test('applies tanh', () {
      final tensor = Tensor(shape: [1], data: [0.0]);
      final result = tensor.tanh();

      expect(result[0], closeTo(0.0, 0.001));
    });

    test('applies softmax on 1D tensor', () {
      final tensor = Tensor(shape: [3], data: [1.0, 2.0, 3.0]);
      final result = tensor.softmax();

      expect(result.sum(), closeTo(1.0, 0.001));
      expect(result[2], greaterThan(result[1]));
      expect(result[1], greaterThan(result[0]));
    });
  });

  group('Tensor Serialization', () {
    test('serializes to JSON', () {
      final tensor = Tensor(
        shape: [2, 2],
        data: [1.0, 2.0, 3.0, 4.0],
      );
      final json = tensor.toJson();

      expect(json['shape'], equals([2, 2]));
      expect(json['dtype'], equals('float64'));
      // toJson encodes the buffer as base64 text (see the base64 round-trip
      // test below). A bare isA() would match anything, including null.
      expect(json['data'], isA<String>());
    });

    test('deserializes from JSON with base64 data', () {
      final original = Tensor(
        shape: [2, 3],
        data: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
      );
      final json = original.toJson();
      final restored = Tensor.fromJson(json);

      expect(restored.shape, equals(original.shape));
      expect(restored.dtype, equals(original.dtype));
      for (var i = 0; i < original.size; i++) {
        expect(restored.data[i], closeTo(original.data[i], 0.0001));
      }
    });

    test('deserializes from JSON with list data', () {
      final json = {
        'shape': [2, 2],
        'data': [
          [1.0, 2.0],
          [3.0, 4.0]
        ],
        'dtype': 'float64',
      };
      final tensor = Tensor.fromJson(json);

      expect(tensor.shape, equals([2, 2]));
      expect(tensor.at([0, 0]), equals(1.0));
      expect(tensor.at([1, 1]), equals(4.0));
    });

    test('converts to nested list', () {
      final tensor = Tensor(
        shape: [2, 3],
        data: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
      );
      final nested = tensor.toNestedList();

      expect(
        nested,
        equals([
          [1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0],
        ]),
      );
    });
  });

  group('Tensor Properties', () {
    test('calculates nbytes correctly', () {
      final tensor = Tensor.zeros([10]);

      // Float64List uses 8 bytes per element
      expect(tensor.nbytes, equals(80));
    });

    test('equality works correctly', () {
      final a = Tensor(shape: [2], data: [1.0, 2.0]);
      final b = Tensor(shape: [2], data: [1.0, 2.0]);
      final c = Tensor(shape: [2], data: [1.0, 3.0]);

      expect(a, equals(b));
      expect(a, isNot(equals(c)));
    });

    test('toString provides useful info', () {
      final small = Tensor(shape: [2], data: [1.0, 2.0]);
      final large = Tensor.zeros([100, 100]);

      expect(small.toString(), contains('Tensor'));
      expect(large.toString(), contains('shape'));
    });
  });
}