Guest User

Untitled

a guest
Dec 8th, 2019
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.20 KB | None | 0 0
  1. const std = @import("std");
  2.  
  3. pub fn Constant(comptime T: type) type {
  4. return struct {
  5. value: T,
  6. };
  7. }
  8.  
  9. pub fn Operation(comptime T: type) type {
  10. return struct {
  11. left: Tensor(T),
  12. right: Tensor(T),
  13. };
  14. }
  15.  
  16. pub fn Tensor(comptime T: type) type {
  17. return union(enum) {
  18. constant: *const Constant(T),
  19. operation: *const Operation(T),
  20. };
  21. }
  22.  
  23. pub fn Graph(comptime T: type) type {
  24. return struct {
  25. constants: std.ArrayList(Constant(T)),
  26. operations: std.ArrayList(Operation(T)),
  27.  
  28. pub const elementType: type = T;
  29.  
  30. pub fn init(allocator: *std.mem.Allocator) Graph(T) {
  31. return .{
  32. .constants = std.ArrayList(Constant(T)).init(allocator),
  33. .operations = std.ArrayList(Operation(T)).init(allocator),
  34. };
  35. }
  36. };
  37. }
  38.  
  39. pub fn constant(graph: var, value: var) !Tensor(@typeOf(graph.*).elementType) {
  40. const T = @typeOf(graph.*).elementType;
  41. const c = Constant(T){ .value = value };
  42. try graph.constants.append(c);
  43. return Tensor(T){ .constant = &c };
  44. }
  45.  
  46. test "constant" {
  47. var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
  48. defer arena.deinit();
  49. const allocator = &arena.allocator;
  50.  
  51. var graph = Graph(f64).init(allocator);
  52. const x = try constant(&graph, 5);
  53. const y = try constant(&graph, 10);
  54. std.testing.expectEqual(x.constant.value, 5);
  55. std.testing.expectEqual(y.constant.value, 10);
  56. }
  57.  
  58. pub fn add(graph: var, x: var, y: @typeOf(x)) !@typeOf(x) {
  59. const T = @typeOf(graph.*).elementType;
  60. const o = Operation(T){ .left = x, .right = y };
  61. try graph.operations.append(o);
  62. return Tensor(T){ .operation = &o };
  63. }
  64.  
  65. test "add" {
  66. var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
  67. defer arena.deinit();
  68. const allocator = &arena.allocator;
  69.  
  70. var graph = Graph(f64).init(allocator);
  71. const x = try constant(&graph, 5);
  72. const y = try constant(&graph, 10);
  73. const z = try add(&graph, x, y);
  74. std.testing.expectEqual(x.constant, z.operation.left.constant);
  75. std.testing.expectEqual(y.constant, z.operation.right.constant);
  76. }
Add Comment
Please, Sign In to add comment