Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
    """Create a TensorProto.

    make_tensor_proto accepts "values" of a python scalar, a python list, a
    numpy ndarray, or a numpy scalar.

    If "values" is a python scalar or a python list, make_tensor_proto first
    converts it to a numpy ndarray. If dtype is None, the conversion tries its
    best to infer the right numpy data type (float64 results are narrowed to
    float32, and int64 results are narrowed to int32 when no value changes).
    Otherwise, the resulting numpy array has a data type compatible with the
    given dtype.

    In either case, the numpy ndarray (either the caller provided or the auto
    converted) must have a type compatible with dtype. make_tensor_proto then
    converts the numpy array to a tensor proto.

    If "shape" is None, the resulting tensor proto represents the numpy array
    precisely. Otherwise, "shape" specifies the tensor's shape and the numpy
    array can not have more elements than what "shape" specifies.

    Args:
      values: Values to put in the TensorProto.
      dtype: Optional tensor_pb2 DataType value (anything accepted by
        `dtypes.as_dtype`).
      shape: List of integers representing the dimensions of tensor.
      verify_shape: Boolean that enables verification of a shape of values.

    Returns:
      A TensorProto. Depending on the type, it may contain data in the
      "tensor_content" attribute, which is not directly useful to Python
      programs. To access the values you should convert the proto back to a
      numpy ndarray with tensor_util.MakeNdarray(proto).

    Raises:
      TypeError: if unsupported types are provided, if dtype is incompatible
        with the converted values, or if verify_shape is True and the shape of
        values is not equal to `shape`.
      ValueError: if arguments have inappropriate values: None values, a
        non-dense (ragged) nested list, more elements than `shape` allows, or
        packed tensor content of 2GB or more.
    """
    if dtype:
        dtype = dtypes.as_dtype(dtype)

    is_quantized = dtype in [dtypes.qint8, dtypes.quint8, dtypes.qint16,
                             dtypes.quint16, dtypes.qint32]

    # We first convert value to a numpy array or scalar.
    if isinstance(values, (np.ndarray, np.generic)):
        nparray = values.astype(dtype.as_numpy_dtype) if dtype else values
    else:
        if values is None:
            raise ValueError("None values not supported.")

        # If dtype is provided, force the numpy array to that type if possible.
        if dtype and dtype.is_numpy_compatible:
            np_dt = dtype.as_numpy_dtype
        else:
            np_dt = None

        # A shape with zero elements short-circuits conversion. `shape` is
        # checked against None explicitly: np.prod(None) only "works" by
        # returning None, and np.prod(None, dtype=np.int64) raises. The int64
        # accumulator avoids overflow where the default int is 32 bits.
        if shape is not None and np.prod(shape, dtype=np.int64) == 0:
            nparray = np.empty(shape, dtype=np_dt)
        else:
            _AssertCompatible(values, dtype)
            nparray = np.array(values, dtype=np_dt)
            # We need to pass in quantized values as tuples, so don't apply
            # the shape check to them.
            dense_dims = _GetDenseDimensions(values)
            if list(nparray.shape) != dense_dims and not is_quantized:
                raise ValueError(
                    "Argument must be a dense tensor: %s"
                    " - got shape %s, but wanted %s." %
                    (values, list(nparray.shape), dense_dims))

        if dtype is None:
            # python/numpy default float type is float64. We prefer float32.
            if nparray.dtype == np.float64:
                nparray = nparray.astype(np.float32)
            # python/numpy default int type is int64. We prefer int32, but do
            # not down cast if it leads to precision loss.
            elif nparray.dtype == np.int64:
                downcasted_array = nparray.astype(np.int32)
                if np.array_equal(downcasted_array, nparray):
                    nparray = downcasted_array

    # If dtype is provided, it must be compatible with what numpy
    # conversion says.
    numpy_dtype = dtypes.as_dtype(nparray.dtype)
    if numpy_dtype is None:
        raise TypeError("Unrecognized data type: %s" % nparray.dtype)

    # If dtype was specified and is a quantized type, we convert
    # numpy_dtype back into the quantized version.
    if is_quantized:
        numpy_dtype = dtype

    if dtype is not None and (not hasattr(dtype, "base_dtype") or
                              dtype.base_dtype != numpy_dtype.base_dtype):
        raise TypeError("Incompatible types: %s vs. %s" % (dtype, nparray.dtype))

    # If shape is not given, get the shape from the numpy array.
    if shape is None:
        shape = nparray.shape
        is_same_size = True
        shape_size = nparray.size
    else:
        shape = [int(dim) for dim in shape]
        # int64 accumulator: the platform default int can overflow for large
        # shapes, and np.prod([]) would otherwise be the float 1.0.
        shape_size = np.prod(shape, dtype=np.int64)
        is_same_size = shape_size == nparray.size

        if verify_shape and nparray.shape != tuple(shape):
            raise TypeError("Expected Tensor's shape: %s, got %s." %
                            (tuple(shape), nparray.shape))

        if nparray.size > shape_size:
            raise ValueError(
                "Too many elements provided. Needed at most %d, but received %d" %
                (shape_size, nparray.size))

    tensor_proto = tensor_pb2.TensorProto(
        dtype=numpy_dtype.as_datatype_enum,
        tensor_shape=tensor_shape.as_shape(shape).as_proto())

    # Packed path: when the array exactly fills the shape and the dtype is one
    # of _TENSOR_CONTENT_TYPES, store the raw buffer in tensor_content.
    if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
        if nparray.nbytes >= (1 << 31):
            raise ValueError(
                "Cannot create a tensor proto whose content is larger than 2GB.")
        # tobytes() is the non-deprecated spelling of tostring(); both emit
        # the buffer in C (row-major) order.
        tensor_proto.tensor_content = nparray.tobytes()
        return tensor_proto

    # If we were not given values as a numpy array, compute the proto_values
    # from the given values directly, to avoid numpy trimming nulls from the
    # strings. Since values could be a list of strings, or a multi-dimensional
    # list of lists that might or might not correspond to the given shape,
    # we flatten it conservatively.
    if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
        proto_values = _FlattenToStrings(values)
        tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])
        return tensor_proto

    # TensorFlow expects C order (a.k.a., eigen row major).
    proto_values = nparray.ravel()

    append_fn = GetNumpyAppendFn(proto_values.dtype)
    if append_fn is None:
        raise TypeError("Element type not supported in TensorProto: %s" %
                        numpy_dtype.name)
    append_fn(tensor_proto, proto_values)

    return tensor_proto
Add Comment
Please, Sign In to add comment